44from typing import Callable , Sequence , Tuple
55
66import torch
7- from sdpa_converter import *
87from torch_tensorrt .dynamo ._settings import CompilationSettings
98from torch_tensorrt .dynamo .conversion .aten_ops_converters import args_bounds_check
109from torch_tensorrt .dynamo .lowering import TORCH_TRT_DECOMPOSITIONS
1514 clean_up_graph_after_modifications ,
1615)
1716
17+ from .sdpa_converter import *
18+
1819logger = logging .getLogger (__name__ )
1920
2021# Remove decompositions for aten.scaled_dot_product_attention, aten._scaled_dot_product_efficient_attention, aten._scaled_dot_product_flash_attention
2122# This is because we want to have SDPA as a standalone operator in the graph and invoke the custom converter for it.
22- TORCH_TRT_DECOMPOSITIONS .pop (torch .ops .aten .scaled_dot_product_attention .default )
23+ TORCH_TRT_DECOMPOSITIONS .pop (torch .ops .aten .scaled_dot_product_attention .default , None )
24+ TORCH_TRT_DECOMPOSITIONS .pop (
25+ torch .ops .aten ._scaled_dot_product_efficient_attention .default , None
26+ )
2327TORCH_TRT_DECOMPOSITIONS .pop (
24- torch .ops .aten ._scaled_dot_product_efficient_attention .default
28+ torch .ops .aten ._scaled_dot_product_flash_attention .default , None
2529)
26- TORCH_TRT_DECOMPOSITIONS .pop (torch .ops .aten ._scaled_dot_product_flash_attention .default )
2730
2831REPLACEABLE_ATEN_OPS = {
2932 torch .ops .aten ._scaled_dot_product_efficient_attention .default ,
@@ -59,6 +62,7 @@ def replace_variants_of_sdpa(
5962 elif len (node .args ) == 5 :
6063 query , key , value , attn_mask , is_causal = node .args
6164 dropout_p = 0.0
65+
6266 else :
6367 raise ValueError (
6468 f"Unexpected number of arguments for { node .target } in the graph"
@@ -71,6 +75,8 @@ def replace_variants_of_sdpa(
7175 query , key , value , dropout_p , is_causal , return_debug_mask = (
7276 node .args
7377 )
78+ if len (node .args ) == 5 :
79+ query , key , value , dropout_p , is_causal = node .args
7480 elif len (node .args ) == 3 :
7581 query , key , value = node .args
7682 dropout_p = 0.0
@@ -79,20 +85,21 @@ def replace_variants_of_sdpa(
7985 raise ValueError (
8086 f"Unexpected number of arguments for { node .target } in the graph"
8187 )
82- if attn_mask is not None :
83- logger .warning (
84- f"This current version of SDPA converter does not support attn_mask for { node .target } in the graph. Ignoring it and using is_causal=True configuration."
85- )
86-
87- modified_input_args = (query , key , value , None , dropout_p , is_causal )
8888
89+ logger .warning (
90+ f"This current version of SDPA converter only supports attn_mask = None, dropout_p = 0.0 and is_causal = True configuration. This could cause issues with accuracy for models with different configurations."
91+ )
92+ modified_input_args = (query , key , value , None , dropout_p , True )
8993 # Create a new node with torch.nn.functional.scaled_dot_product_attention
9094 # The input args is (query, key, value, is_causal). kwargs has scale
9195 with gm .graph .inserting_after (node ):
9296 new_node = gm .graph .call_function (
9397 torch .nn .functional .scaled_dot_product_attention ,
9498 args = modified_input_args ,
95- kwargs = {"scale" : node .kwargs .get ("scale" , None )},
99+ kwargs = {
100+ "scale" : node .kwargs .get ("scale" , None ),
101+ "use_fp32_acc" : settings .use_fp32_acc ,
102+ },
96103 )
97104
98105 # Deep copy encounters RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). So we use copy instead.
@@ -113,7 +120,7 @@ def replace_variants_of_sdpa(
113120 # Clean up the graph
114121 clean_up_graph_after_modifications (gm )
115122
116- logger .info (
123+ logger .debug (
117124 "Replaced variants of scaled_dot_product_attention with torch.nn.functional.scaled_dot_product_attention"
118125 )
119126 return gm
0 commit comments