Commit b7abe8e

BoyuanFeng authored and pytorchmergebot committed
[BugFix] fix conditions to apply tma (#174480)
Summary: This diff fixes a wrong condition introduced in D92015529 for applying TMA. This should fix S618103.

Test Plan: Test in mast job.

Differential Revision: D92533457

Pull Request resolved: #174480

Approved by: https://github.com/Microve
1 parent ef9b83b commit b7abe8e

1 file changed: torch/_inductor/codegen/triton.py (6 additions, 4 deletions)
@@ -2233,11 +2233,13 @@ def can_use_tma(
         return True
     if not (
         (
-            V.graph.get_current_device_or_throw().type == "cuda"
-            and torch.cuda.get_device_capability()[0] >= 9
-            and config.assume_aligned_inputs
+            (
+                V.graph.get_current_device_or_throw().type == "cuda"
+                and torch.cuda.get_device_capability()[0] >= 9
+                and config.assume_aligned_inputs
+            )
+            or V.graph.get_current_device_or_throw().type == "xpu"
         )
-        or V.graph.get_current_device_or_throw().type == "xpu"
         and config.triton.use_tensor_descriptor
         and has_triton_stable_tma_api()
         # For CUDA The base ptr needs to be aligned
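The bug fixed here is an operator-precedence mistake: in Python, `and` binds tighter than `or`, so without the extra parentheses the `use_tensor_descriptor` and stable-TMA checks were attached only to the `xpu` branch, and the CUDA branch could enable TMA on its own. A minimal sketch of the before/after grouping (function and parameter names are illustrative, not the real Inductor variables):

```python
# Pre-fix grouping: (cuda and cap_ok and aligned) or (xpu and desc and stable).
# On CUDA, the tensor-descriptor and stable-API checks were silently skipped.
def can_use_tma_before(is_cuda, cap_ok, aligned, is_xpu, use_desc, stable_api):
    return (is_cuda and cap_ok and aligned) or is_xpu and use_desc and stable_api


# Post-fix grouping: ((cuda and cap_ok and aligned) or xpu) and desc and stable.
# The device check is isolated; the remaining checks apply to both devices.
def can_use_tma_after(is_cuda, cap_ok, aligned, is_xpu, use_desc, stable_api):
    return ((is_cuda and cap_ok and aligned) or is_xpu) and use_desc and stable_api


# Example: a capability>=9 CUDA device with aligned inputs but the tensor
# descriptor disabled. The old condition wrongly allowed TMA; the new one
# respects the config flag.
args = dict(is_cuda=True, cap_ok=True, aligned=True,
            is_xpu=False, use_desc=False, stable_api=True)
print(can_use_tma_before(**args))  # True  (bug: TMA applied anyway)
print(can_use_tma_after(**args))   # False (fix: descriptor flag respected)
```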
