
Commit 7591c0d

disable sfdp_pattern_16 CUDA
1 parent 511ba22 commit 7591c0d

2 files changed: 18 additions & 14 deletions


test/inductor/test_fused_attention.py

Lines changed: 0 additions & 3 deletions
@@ -746,9 +746,6 @@ class SDPAPatternRewriterCudaTests(TestSDPAPatternRewriterTemplate):
     test_sdpa_rewriter_15_cuda = functools.partialmethod(
         TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_15
     )
-    test_sdpa_rewriter_16_cuda = functools.partialmethod(
-        TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_16
-    )
     test_sdpa_rewriter_17_cuda = functools.partialmethod(
         TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_17
     )
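
The CUDA variants in this test file are stamped out with functools.partialmethod, so removing the assignment for pattern 16 drops that test case without touching the shared template. A minimal sketch of the mechanism; Template, device, and the method names here are toy stand-ins for the real test classes, not the commit's code:

    import functools


    class Template:
        def _test_rewriter_15(self):
            # The template method reads the device from the subclass.
            print(f"checking pattern 15 on {self.device}")


    class CudaTests(Template):
        device = "cuda"
        # Each assignment re-exposes a template method under a name the
        # test runner discovers; deleting one assignment (as this commit
        # does for pattern 16) removes that test case.
        test_rewriter_15_cuda = functools.partialmethod(Template._test_rewriter_15)


    CudaTests().test_rewriter_15_cuda()  # prints: checking pattern 15 on cuda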

torch/_inductor/fx_passes/fuse_attention.py

Lines changed: 18 additions & 11 deletions
@@ -476,14 +476,20 @@ def _sfdp_params_check(match):
     return True


-def _sfdp_scale_factor_check(scale_factor_op):
+def _sfdp_extra_check(scale_factor_op, disable_cuda=False):
     def fn(match):
         scale_factor_node = filter_nodes(match.nodes, scale_factor_op)[0]
         # Note: args[1] of the scale_factor_node is always the scale_factor for the current patterns.
         scale_factor = scale_factor_node.args[1]
         # make sure the scale_factor a float/int. SymInt?
         if not isinstance(scale_factor, (float, int)):
             return False
+        if (
+            disable_cuda
+            and "query" in match.kwargs
+            and "cuda" in str(match.kwargs["query"].meta["val"].device)
+        ):
+            return False
         return _sfdp_params_check(match)

     return fn
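
The renamed _sfdp_extra_check keeps the existing scale-factor validation and adds a device gate: when disable_cuda is set and the matched query lives on a CUDA device, the rewrite is rejected. A standalone sketch of just that gate; FakeVal and FakeNode are illustrative stand-ins for the pattern matcher's real match objects:

    import torch


    class FakeVal:
        def __init__(self, device):
            self.device = torch.device(device)


    class FakeNode:
        # Mimics an FX node whose meta["val"] carries a fake tensor
        # with a .device attribute.
        def __init__(self, device):
            self.meta = {"val": FakeVal(device)}


    def cuda_gate(match_kwargs, disable_cuda):
        # Mirrors the new check: reject the match when disable_cuda is
        # set and the query tensor is placed on a CUDA device.
        return not (
            disable_cuda
            and "query" in match_kwargs
            and "cuda" in str(match_kwargs["query"].meta["val"].device)
        )


    print(cuda_gate({"query": FakeNode("cpu")}, disable_cuda=True))    # True: match kept
    print(cuda_gate({"query": FakeNode("cuda")}, disable_cuda=True))   # False: match skipped
    print(cuda_gate({"query": FakeNode("cuda")}, disable_cuda=False))  # True: match kept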
@@ -555,28 +561,28 @@ def _get_sfdp_patterns():
             _sfdp_replacement_1,
             [g(), g(), g(), c()],
             {},
-            _sfdp_scale_factor_check(aten.div.Tensor),
+            _sfdp_extra_check(aten.div.Tensor),
         ),
         (
             _sfdp_pattern_2,
             _sfdp_replacement_2,
             [g(), g(), g(), c()],
             {},
-            _sfdp_scale_factor_check(aten.mul.Tensor),
+            _sfdp_extra_check(aten.mul.Tensor),
         ),
         (
             _sfdp_pattern_3,
             _sfdp_replacement_3,
             [g(), g(), g(), c()],
             d,
-            _sfdp_scale_factor_check(aten.div.Tensor),
+            _sfdp_extra_check(aten.div.Tensor),
         ),
         (
             _sfdp_pattern_4,
             _sfdp_replacement_4,
             [g(), g(), g(), c()],
             d,
-            _sfdp_scale_factor_check(aten.mul.Tensor),
+            _sfdp_extra_check(aten.mul.Tensor),
         ),
         (
             _sfdp_pattern_5,
@@ -625,14 +631,14 @@ def _get_sfdp_patterns():
             _sfdp_replacement_11,
             [g(), g(), g(), c()],
             {},
-            _sfdp_scale_factor_check(aten.div.Tensor),
+            _sfdp_extra_check(aten.div.Tensor),
         ),
         (
             _sfdp_pattern_12,
             _sfdp_replacement_12,
             [g(), g(), g(), c()],
             d,
-            _sfdp_scale_factor_check(aten.div.Tensor),
+            _sfdp_extra_check(aten.div.Tensor),
         ),
         (
             _sfdp_pattern_13,
@@ -646,28 +652,29 @@ def _get_sfdp_patterns():
             _sfdp_replacement_14,
             [g(), g(), g(), m(), c()],
             {},
-            _sfdp_scale_factor_check(aten.div.Tensor),
+            _sfdp_extra_check(aten.div.Tensor),
         ),
         (
             _sfdp_pattern_15,
             _sfdp_replacement_15,
             [g(), g(), g(), m(), c()],
             {},
-            _sfdp_scale_factor_check(aten.div.Tensor),
+            _sfdp_extra_check(aten.div.Tensor),
         ),
+        # TODO: Enable CUDA after solving Bert accuracy issue of calling efficient attention
         (
             _sfdp_pattern_16,
             _sfdp_replacement_16,
             [g(), g(), g(), m(), c()],
             d,
-            _sfdp_scale_factor_check(aten.div.Tensor),
+            _sfdp_extra_check(aten.div.Tensor, disable_cuda=True),
         ),
         (
             _sfdp_pattern_17,
             _sfdp_replacement_17,
             [g(), g(), g(), m(), c()],
             d,
-            _sfdp_scale_factor_check(aten.div.Tensor),
+            _sfdp_extra_check(aten.div.Tensor),
         ),
     ]:
         # XXX: when adding a new pattern, re-run `gen_attention_patterns` so the pattern
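
Judging from the entries above, each tuple unpacks as (pattern, replacement, example-input factories, scalar workaround, extra check), and the loop registers each pair with its predicate attached. A hypothetical sketch of such a registration loop; register and the toy functions below are made-up stand-ins, not the inductor pattern matcher's actual API:

    def register(pattern, replacement, example_inputs, workaround, extra_check):
        # Made-up registry call: a real pattern matcher would store the
        # pair and consult extra_check(match) before rewriting.
        print(f"registered {pattern.__name__} gated by {extra_check.__name__}")


    def toy_pattern(q, k, v): ...
    def toy_replacement(q, k, v): ...
    def toy_check(match): return False  # e.g. pattern 16 on a CUDA query


    for pattern, replacement, args, workaround, extra_check in [
        (toy_pattern, toy_replacement, [], {}, toy_check),
    ]:
        register(pattern, replacement, args, workaround, extra_check)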
