Skip to content

Commit 2eac6ac

Browse files
committed
Remove 256 tile size as the performance regresses in some shapes
Signed-off-by: nithinsubbiah <nithinsubbiah@gmail.com>
1 parent: de17fee · commit: 2eac6ac

1 file changed

Lines changed: 6 additions & 6 deletions

File tree

torch/_inductor/template_heuristics/triton.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1554,17 +1554,17 @@ def __init__(self) -> None:
15541554
(torch.float32, 64): ROCmFlexConfig(128, 32, 1, 4),
15551555
(torch.float32, 128): ROCmFlexConfig(128, 32, 1, 4),
15561556
(torch.float32, 256): ROCmFlexConfig(64, 16, 1, 4),
1557-
(torch.bfloat16, 64): ROCmFlexConfig(128, 64, 2, 8),
1558-
(torch.bfloat16, 128): ROCmFlexConfig(128, 64, 2, 8),
1559-
(torch.bfloat16, 256): ROCmFlexConfig(32, 64, 2, 8),
1557+
(torch.bfloat16, 64): ROCmFlexConfig(128, 64, 2, 4),
1558+
(torch.bfloat16, 128): ROCmFlexConfig(128, 64, 2, 4),
1559+
(torch.bfloat16, 256): ROCmFlexConfig(32, 64, 2, 4),
15601560
(torch.float16, 64): ROCmFlexConfig(128, 64, 2, 8),
15611561
(torch.float16, 128): ROCmFlexConfig(128, 64, 2, 8),
15621562
(torch.float16, 256): ROCmFlexConfig(32, 64, 2, 4),
15631563
}
15641564

15651565
self.flex_attn_fwd_autotune_configs: list[FlexConfig] = [
15661566
ROCmFlexConfig(BLOCK1, BLOCK2, 1, w)
1567-
for BLOCK1 in [16, 64, 128, 256]
1567+
for BLOCK1 in [16, 64, 128]
15681568
for BLOCK2 in [16, 32, 64, 128]
15691569
for w in [4, 8]
15701570
]
@@ -1590,7 +1590,7 @@ def __init__(self) -> None:
15901590

15911591
self.exhaustive_flex_attn_fwd_configs: list[FlexConfig] = [
15921592
ROCmFlexConfig(BLOCK_M, BLOCK_N, num_stages, num_warps, mfma, wpeu)
1593-
for BLOCK_M in [16, 32, 64, 128, 256]
1593+
for BLOCK_M in [16, 32, 64, 128]
15941594
for BLOCK_N in [32, 64, 128]
15951595
for num_stages in [1, 2]
15961596
for num_warps in [2, 4, 8]
@@ -1735,7 +1735,7 @@ def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfi
17351735
if dtype == torch.float32:
17361736
default_config = ROCmFlexConfig(64, 64, 1, 4)
17371737
else:
1738-
default_config = ROCmFlexConfig(128, 64, 2, 8)
1738+
default_config = ROCmFlexConfig(128, 64, 2, 4)
17391739
default_config = self.default_flex_config.get(
17401740
(dtype, head_dim), default_config
17411741
)

0 commit comments

Comments (0)