@@ -476,14 +476,20 @@ def _sfdp_params_check(match):
     return True


-def _sfdp_scale_factor_check(scale_factor_op):
+def _sfdp_extra_check(scale_factor_op, disable_cuda=False):
     def fn(match):
         scale_factor_node = filter_nodes(match.nodes, scale_factor_op)[0]
         # Note: args[1] of the scale_factor_node is always the scale_factor for the current patterns.
         scale_factor = scale_factor_node.args[1]
         # make sure the scale_factor is a float/int. SymInt?
         if not isinstance(scale_factor, (float, int)):
             return False
+        if (
+            disable_cuda
+            and "query" in match.kwargs
+            and "cuda" in str(match.kwargs["query"].meta["val"].device)
+        ):
+            return False
         return _sfdp_params_check(match)

     return fn
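The renamed `_sfdp_extra_check` keeps the original scale-factor validation and adds an optional device gate: when `disable_cuda` is set, a match is rejected if the `query` node's fake-tensor metadata reports a CUDA device. A minimal sketch of that gate in isolation, using a hypothetical `fake_query_node` stand-in for the pattern matcher's real FX node (only the `.meta["val"].device` attribute matters for this check):

```python
from types import SimpleNamespace

import torch


def fake_query_node(device_str):
    # Hypothetical stand-in for an FX node: meta["val"] is normally a fake
    # tensor, but the gate below only reads its .device attribute.
    return SimpleNamespace(meta={"val": SimpleNamespace(device=torch.device(device_str))})


def cuda_gate(match_kwargs, disable_cuda=False):
    # Mirrors the new branch in _sfdp_extra_check: reject the match when the
    # gate is enabled and the query tensor lives on a CUDA device.
    if (
        disable_cuda
        and "query" in match_kwargs
        and "cuda" in str(match_kwargs["query"].meta["val"].device)
    ):
        return False
    return True


print(cuda_gate({"query": fake_query_node("cpu")}, disable_cuda=True))   # True: kept
print(cuda_gate({"query": fake_query_node("cuda")}, disable_cuda=True))  # False: rejected
print(cuda_gate({"query": fake_query_node("cuda")}))                     # True: gate off by default
```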
@@ -555,28 +561,28 @@ def _get_sfdp_patterns():
             _sfdp_replacement_1,
             [g(), g(), g(), c()],
             {},
-            _sfdp_scale_factor_check(aten.div.Tensor),
+            _sfdp_extra_check(aten.div.Tensor),
         ),
         (
             _sfdp_pattern_2,
             _sfdp_replacement_2,
             [g(), g(), g(), c()],
             {},
-            _sfdp_scale_factor_check(aten.mul.Tensor),
+            _sfdp_extra_check(aten.mul.Tensor),
         ),
         (
             _sfdp_pattern_3,
             _sfdp_replacement_3,
             [g(), g(), g(), c()],
             d,
-            _sfdp_scale_factor_check(aten.div.Tensor),
+            _sfdp_extra_check(aten.div.Tensor),
         ),
         (
             _sfdp_pattern_4,
             _sfdp_replacement_4,
             [g(), g(), g(), c()],
             d,
-            _sfdp_scale_factor_check(aten.mul.Tensor),
+            _sfdp_extra_check(aten.mul.Tensor),
         ),
         (
             _sfdp_pattern_5,
@@ -625,14 +631,14 @@ def _get_sfdp_patterns():
             _sfdp_replacement_11,
             [g(), g(), g(), c()],
             {},
-            _sfdp_scale_factor_check(aten.div.Tensor),
+            _sfdp_extra_check(aten.div.Tensor),
         ),
         (
             _sfdp_pattern_12,
             _sfdp_replacement_12,
             [g(), g(), g(), c()],
             d,
-            _sfdp_scale_factor_check(aten.div.Tensor),
+            _sfdp_extra_check(aten.div.Tensor),
         ),
         (
             _sfdp_pattern_13,
@@ -646,28 +652,29 @@ def _get_sfdp_patterns():
             _sfdp_replacement_14,
             [g(), g(), g(), m(), c()],
             {},
-            _sfdp_scale_factor_check(aten.div.Tensor),
+            _sfdp_extra_check(aten.div.Tensor),
         ),
         (
             _sfdp_pattern_15,
             _sfdp_replacement_15,
             [g(), g(), g(), m(), c()],
             {},
-            _sfdp_scale_factor_check(aten.div.Tensor),
+            _sfdp_extra_check(aten.div.Tensor),
         ),
+        # TODO: Enable CUDA after solving Bert accuracy issue of calling efficient attention
         (
             _sfdp_pattern_16,
             _sfdp_replacement_16,
             [g(), g(), g(), m(), c()],
             d,
-            _sfdp_scale_factor_check(aten.div.Tensor),
+            _sfdp_extra_check(aten.div.Tensor, disable_cuda=True),
         ),
         (
             _sfdp_pattern_17,
             _sfdp_replacement_17,
             [g(), g(), g(), m(), c()],
             d,
-            _sfdp_scale_factor_check(aten.div.Tensor),
+            _sfdp_extra_check(aten.div.Tensor),
         ),
     ]:
         # XXX: when adding a new pattern, re-run `gen_attention_patterns` so the pattern
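Since `disable_cuda` defaults to `False`, the rename is behavior-preserving everywhere except the pattern-16 entry, which opts out of CUDA matches until the BERT accuracy issue with efficient attention is resolved. A hedged sketch of the two call shapes (the stub below abbreviates the real function above; `aten` is `torch.ops.aten`, as elsewhere in the file):

```python
import torch

aten = torch.ops.aten


def _sfdp_extra_check(scale_factor_op, disable_cuda=False):
    # Abbreviated stub of the function defined above, shown only to
    # illustrate the two call shapes used in the registration table.
    def fn(match):
        ...
    return fn


# Every entry except pattern 16 keeps the default and matches on any device:
default_check = _sfdp_extra_check(aten.div.Tensor)

# Pattern 16 alone skips CUDA matches for now (see the TODO in the diff):
pattern_16_check = _sfdp_extra_check(aten.div.Tensor, disable_cuda=True)
```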