Skip to content

Commit d667ffe

Browse files
mstankov-amdpytorchmergebot
authored andcommitted
[ROCm][CI] Fix failing FP8 tests on RDNA4 (pytorch#174873)
## Summary This PR fixes FP8 inductor test failures that occur on AMD RDNA4 GPUs when testing matrix multiplications with small M dimensions (M < 16). ## Problem On gfx120x GPUs, FP8 scaled matrix multiplication tests fail with: - 92.4% NaN outputs when M < BLOCK_M (typically 16) - Large numerical mismatches between eager and compiled results - Only occurs in `max-autotune` mode **Root cause:** Autotuned Triton kernels on gfx120x generate incorrect tensor indexing for small M values, using partial indices instead of full computed indices in load/store operations. ## Solution - Added GPU-specific compile mode selection for small M values - gfx120x with M < 16: use `compile_mode="default"` - All other cases: use `compile_mode="max-autotune"` Pull Request resolved: pytorch#174873 Approved by: https://github.com/jeffdaily
1 parent fc90fdf commit d667ffe

1 file changed

Lines changed: 24 additions & 2 deletions

File tree

test/inductor/test_fp8.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1034,9 +1034,20 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias):
10341034
w_inverse_scale,
10351035
bias,
10361036
)
1037+
1038+
# On gfx120x, autotuned kernels have issues with small M
1039+
compile_mode = "max-autotune"
1040+
if (
1041+
torch.version.hip is not None
1042+
and M < 16
1043+
and torch.cuda.is_available()
1044+
and "gfx120" in torch.cuda.get_device_properties(0).gcnArchName
1045+
):
1046+
compile_mode = "default"
1047+
10371048
with config.patch({"triton.enable_persistent_tma_matmul": persistent_matmul}):
10381049
linear_compiled = torch.compile(
1039-
linear, backend="inductor", mode="max-autotune"
1050+
linear, backend="inductor", mode=compile_mode
10401051
)
10411052
y_compiled = linear_compiled(
10421053
x_fp8,
@@ -1334,9 +1345,20 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias):
13341345
w_inverse_scale,
13351346
bias,
13361347
)
1348+
1349+
# On gfx120x, autotuned kernels have issues with small M
1350+
compile_mode = "max-autotune"
1351+
if (
1352+
torch.version.hip is not None
1353+
and M < 16
1354+
and torch.cuda.is_available()
1355+
and "gfx120" in torch.cuda.get_device_properties(0).gcnArchName
1356+
):
1357+
compile_mode = "default"
1358+
13371359
with config.patch({"triton.enable_persistent_tma_matmul": persistent_matmul}):
13381360
linear_compiled = torch.compile(
1339-
linear, backend="inductor", mode="max-autotune"
1361+
linear, backend="inductor", mode=compile_mode
13401362
)
13411363
y_compiled = linear_compiled(
13421364
x_fp8,

0 commit comments

Comments
 (0)