Skip to content

Commit 8a3afde

Browse files
committed
test
1 parent e4b126f commit 8a3afde

4 files changed

Lines changed: 23 additions & 16 deletions

File tree

examples/apps/flux_demo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def main(args):
272272
parser.add_argument(
273273
"--fp4_mha",
274274
action="store_true",
275-
help="Use NVFP4_FP8_MHA_CONFIG config instead of NVFP4_FP8_MHA_CONFIG",
275+
help="Use NVFP4_FP8_MHA_CONFIG config instead of NVFP4_DEFAULT_CFG",
276276
)
277277
parser.add_argument(
278278
"--low_vram_mode",

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,8 +220,8 @@ def _populate_trt_builder_config(
220220
if version.parse(trt.__version__) >= version.parse("8.2"):
221221
builder_config.profiling_verbosity = (
222222
trt.ProfilingVerbosity.DETAILED
223-
# if self._debugger_config and self._debugger_config.save_engine_profile
224-
# else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
223+
if self._debugger_config and self._debugger_config.save_engine_profile
224+
else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
225225
)
226226

227227
if version.parse(trt.__version__) >= version.parse("8.6"):

tools/llm/torchtrt_ext/register_sdpa.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from typing import Callable, Sequence, Tuple
55

66
import torch
7-
from sdpa_converter import *
87
from torch_tensorrt.dynamo._settings import CompilationSettings
98
from torch_tensorrt.dynamo.conversion.aten_ops_converters import args_bounds_check
109
from torch_tensorrt.dynamo.lowering import TORCH_TRT_DECOMPOSITIONS
@@ -15,15 +14,19 @@
1514
clean_up_graph_after_modifications,
1615
)
1716

17+
from .sdpa_converter import *
18+
1819
logger = logging.getLogger(__name__)
1920

2021
# Remove decompositions for aten.scaled_dot_product_attention, aten._scaled_dot_product_efficient_attention, aten._scaled_dot_product_flash_attention
2122
# This is because we want to have SDPA as a standalone operator in the graph and invoke the custom converter for it.
22-
TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten.scaled_dot_product_attention.default)
23+
TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten.scaled_dot_product_attention.default, None)
24+
TORCH_TRT_DECOMPOSITIONS.pop(
25+
torch.ops.aten._scaled_dot_product_efficient_attention.default, None
26+
)
2327
TORCH_TRT_DECOMPOSITIONS.pop(
24-
torch.ops.aten._scaled_dot_product_efficient_attention.default
28+
torch.ops.aten._scaled_dot_product_flash_attention.default, None
2529
)
26-
TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten._scaled_dot_product_flash_attention.default)
2730

2831
REPLACEABLE_ATEN_OPS = {
2932
torch.ops.aten._scaled_dot_product_efficient_attention.default,
@@ -59,6 +62,7 @@ def replace_variants_of_sdpa(
5962
elif len(node.args) == 5:
6063
query, key, value, attn_mask, is_causal = node.args
6164
dropout_p = 0.0
65+
6266
else:
6367
raise ValueError(
6468
f"Unexpected number of arguments for {node.target} in the graph"
@@ -71,6 +75,8 @@ def replace_variants_of_sdpa(
7175
query, key, value, dropout_p, is_causal, return_debug_mask = (
7276
node.args
7377
)
78+
if len(node.args) == 5:
79+
query, key, value, dropout_p, is_causal = node.args
7480
elif len(node.args) == 3:
7581
query, key, value = node.args
7682
dropout_p = 0.0
@@ -79,20 +85,21 @@ def replace_variants_of_sdpa(
7985
raise ValueError(
8086
f"Unexpected number of arguments for {node.target} in the graph"
8187
)
82-
if attn_mask is not None:
83-
logger.warning(
84-
f"This current version of SDPA converter does not support attn_mask for {node.target} in the graph. Ignoring it and using is_causal=True configuration."
85-
)
86-
87-
modified_input_args = (query, key, value, None, dropout_p, is_causal)
8888

89+
logger.warning(
90+
f"This current version of SDPA converter only supports attn_mask = None, dropout_p = 0.0 and is_causal = True configuration. This could cause issues with accuracy for models with different configurations."
91+
)
92+
modified_input_args = (query, key, value, None, dropout_p, True)
8993
# Create a new node with torch.nn.functional.scaled_dot_product_attention
9094
# The input args is (query, key, value, is_causal). kwargs has scale
9195
with gm.graph.inserting_after(node):
9296
new_node = gm.graph.call_function(
9397
torch.nn.functional.scaled_dot_product_attention,
9498
args=modified_input_args,
95-
kwargs={"scale": node.kwargs.get("scale", None)},
99+
kwargs={
100+
"scale": node.kwargs.get("scale", None),
101+
"use_fp32_acc": settings.use_fp32_acc,
102+
},
96103
)
97104

98105
# Deep copy encounters RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). So we use copy instead.
@@ -113,7 +120,7 @@ def replace_variants_of_sdpa(
113120
# Clean up the graph
114121
clean_up_graph_after_modifications(gm)
115122

116-
logger.info(
123+
logger.debug(
117124
"Replaced variants of scaled_dot_product_attention with torch.nn.functional.scaled_dot_product_attention"
118125
)
119126
return gm

tools/perf/Flux/flux_perf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def main(args):
7373
parser.add_argument(
7474
"--fp4_mha",
7575
action="store_true",
76-
help="Use NVFP4_FP8_MHA_CONFIG config instead of NVFP4_FP8_MHA_CONFIG",
76+
help="Use NVFP4_FP8_MHA_CONFIG config instead of NVFP4_DEFAULT_CFG",
7777
)
7878
parser.add_argument(
7979
"--low_vram_mode",

0 commit comments

Comments (0)