Skip to content

Commit 0f6ce45

Browse files
sijiacpytorchmergebot
authored andcommitted
[Inductor] handle AMD special launch options (#124146)
Summary: `matrix_instr_nonkdim` and `waves_per_eu` are AMD specific launch configs that can't be treated as fn input args Test Plan: HIP_VISIBLE_DEVICES=7 numactl --cpunodebind=1 --membind=1 buck2 run mode/{opt,amd-gpu} -c fbcode.triton_backend=amd -c fbcode.enable_gpu_sections=true -c fbcode.rocm_arch=mi300 //hammer/modules/sequential/encoders/tests:hstu_bench -- --torch-compile=True the E2E works well on the magic model Differential Revision: D56165438 Pull Request resolved: #124146 Approved by: https://github.com/aakhundov
1 parent 4dc1608 commit 0f6ce45

1 file changed

Lines changed: 14 additions & 0 deletions

File tree

torch/_inductor/triton_heuristics.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,13 @@ def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: Optional[int]
300300
"""Ahead of time compile a given autotuner config."""
301301
compile_meta = copy.deepcopy(self.triton_meta)
302302
for k, v in cfg.kwargs.items():
303+
if torch.version.hip is not None:
304+
if k == "matrix_instr_nonkdim":
305+
compile_meta["matrix_instr_nonkdim"] = v
306+
continue
307+
if k == "waves_per_eu":
308+
compile_meta["waves_per_eu"] = v
309+
continue
303310
compile_meta["constants"][self.fn.arg_names.index(k)] = v
304311
compile_meta["num_warps"] = cfg.num_warps
305312
compile_meta["num_stages"] = cfg.num_stages
@@ -340,6 +347,13 @@ def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: Optional[int]
340347
"num_stages": compile_meta["num_stages"],
341348
"debug": compile_meta["debug"],
342349
}
350+
if torch.version.hip is not None:
351+
if "waves_per_eu" in compile_meta:
352+
options["waves_per_eu"] = compile_meta["waves_per_eu"]
353+
if "matrix_instr_nonkdim" in compile_meta:
354+
options["matrix_instr_nonkdim"] = compile_meta[
355+
"matrix_instr_nonkdim"
356+
]
343357
compile_kwargs = {
344358
"target": target,
345359
"options": options,

0 commit comments

Comments
 (0)