Skip to content

Commit 31c0467

Browse files
int3 and pytorchmergebot
authored and committed
Add Triton CPU as an Inductor backend (#133408)
The goal is to use Inductor-generated kernels to stress test the new Triton CPU backend. Differential Revision: [D63298968](https://our.internmc.facebook.com/intern/diff/D63298968) Pull Request resolved: #133408 Approved by: https://github.com/jansel, https://github.com/blaine-rister, https://github.com/malfet
1 parent 68579ef commit 31c0467

34 files changed

Lines changed: 452 additions & 255 deletions

.ci/docker/build.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,7 @@ case "$image" in
379379
GCC_VERSION=11
380380
CONDA_CMAKE=yes
381381
HALIDE=yes
382+
TRITON=yes
382383
;;
383384
pytorch-linux-focal-linter)
384385
# TODO: Use 3.9 here because of this issue https://github.com/python/mypy/issues/13627.

test/distributed/_composable/fsdp/test_fully_shard_compile.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
ModelArgs,
3131
Transformer,
3232
)
33-
from torch.utils._triton import has_triton
33+
from torch.testing._internal.inductor_utils import HAS_GPU
3434

3535

3636
log = logging.getLogger(__name__)
@@ -48,7 +48,7 @@ def _is_fallback_op_in_snodes(snodes, op):
4848

4949

5050
class TestFullyShardCompileCompute(FSDPTest):
51-
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
51+
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
5252
@skip_if_lt_x_gpu(2)
5353
def test_disable_compiling_hooks(self):
5454
self.run_subtests(
@@ -529,14 +529,14 @@ def input_creation_fn():
529529
return model_init_fn, input_creation_fn
530530

531531
@skipIfRocm
532-
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
532+
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
533533
def test_simple_mlp_fullgraph_backend_aot_eager(self):
534534
self._test_traceable_fsdp(
535535
*self._create_simple_mlp_factory_fns(), "aot_eager", fullgraph=True
536536
)
537537

538538
@skipIfRocm
539-
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
539+
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
540540
def test_simple_mlp_fullgraph_backend_aot_eager_decomp_partition(self):
541541
self._test_traceable_fsdp(
542542
*self._create_simple_mlp_factory_fns(),
@@ -545,7 +545,7 @@ def test_simple_mlp_fullgraph_backend_aot_eager_decomp_partition(self):
545545
)
546546

547547
@skipIfRocm
548-
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
548+
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
549549
def test_simple_mlp_fullgraph_backend_inductor(self):
550550
self._test_traceable_fsdp(
551551
*self._create_simple_mlp_factory_fns(), "inductor", fullgraph=True
@@ -613,7 +613,7 @@ def input_creation_fn():
613613
return model_init_fn, input_creation_fn
614614

615615
@skipIfRocm
616-
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
616+
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
617617
def test_nested_fully_shard_backend_aot_eager(self):
618618
for fullgraph in [True, False]:
619619
self._test_traceable_fsdp(
@@ -623,7 +623,7 @@ def test_nested_fully_shard_backend_aot_eager(self):
623623
)
624624

625625
@skipIfRocm
626-
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
626+
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
627627
def test_nested_fully_shard_backend_aot_eager_decomp_partition(self):
628628
for fullgraph in [True, False]:
629629
self._test_traceable_fsdp(
@@ -633,7 +633,7 @@ def test_nested_fully_shard_backend_aot_eager_decomp_partition(self):
633633
)
634634

635635
@skipIfRocm
636-
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
636+
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
637637
def test_nested_fully_shard_backend_inductor_fullgraph_True(self):
638638
for fullgraph in [True]:
639639
with self._reinplace_all_gather_with_optional_checks(
@@ -729,7 +729,7 @@ def test_nested_fully_shard_backend_inductor_fullgraph_True(self):
729729
file_check.run(bwd_code)
730730

731731
@skipIfRocm
732-
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
732+
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
733733
def test_nested_fully_shard_backend_inductor_fullgraph_False(self):
734734
_, triton_codes = run_and_get_code(
735735
lambda: self._test_traceable_fsdp(
@@ -806,7 +806,7 @@ def _sdpa_with_graph_break(*args, **kwargs):
806806
return contextlib.nullcontext()
807807

808808
@skipIfRocm
809-
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
809+
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
810810
def test_transformer_backend_aot_eager(self):
811811
for fullgraph, all_requires_grad in itertools.product(
812812
[True, False], [True, False]
@@ -823,7 +823,7 @@ def test_transformer_backend_aot_eager(self):
823823
)
824824

825825
@skipIfRocm
826-
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
826+
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
827827
# TODO: native_dropout has worse accuracy after decomp, need to figure out why
828828
@torch._inductor.config.patch(fallback_random=True)
829829
def test_transformer_backend_aot_eager_decomp_partition(self):
@@ -840,7 +840,7 @@ def test_transformer_backend_aot_eager_decomp_partition(self):
840840
)
841841

842842
@skipIfRocm
843-
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
843+
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
844844
# TODO: native_dropout causes CUDA IMA error, need to figure out why
845845
@torch._inductor.config.patch(fallback_random=True)
846846
def test_transformer_backend_inductor_fullgraph_True(self):
@@ -943,7 +943,7 @@ def test_transformer_backend_inductor_fullgraph_True(self):
943943
file_check.run(bwd_code)
944944

945945
@skipIfRocm
946-
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
946+
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
947947
# TODO: native_dropout causes CUDA IMA error, need to figure out why
948948
@torch._inductor.config.patch(fallback_random=True)
949949
def test_transformer_backend_inductor_fullgraph_False(self):

test/distributed/_composable/fully_shard/test_fully_shard_compile.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
TransformerWithSharedParams,
1919
)
2020
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
21-
from torch.utils._triton import has_triton
21+
from torch.testing._internal.inductor_utils import HAS_GPU
2222

2323

2424
if not dist.is_available():
@@ -38,7 +38,7 @@ class TestCompile(FSDPTest):
3838
def world_size(self) -> int:
3939
return torch.cuda.device_count()
4040

41-
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
41+
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
4242
@skip_if_lt_x_gpu(2)
4343
def test_compile(self):
4444
self.run_subtests(

test/distributed/_composable/test_replicate_with_compiler.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
)
3434
from torch.testing._internal.common_utils import run_tests, skipIfRocm
3535
from torch.testing._internal.distributed.fake_pg import FakeStore
36-
from torch.utils._triton import has_triton
36+
from torch.testing._internal.inductor_utils import HAS_GPU
3737
from torch.utils.checkpoint import checkpoint
3838

3939

@@ -216,21 +216,21 @@ def test_compile_cpu_no_sync(self):
216216
]
217217
self._test_compile(use_gpu=False, no_sync=True)
218218

219-
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
219+
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
220220
@skip_if_rocm_multiprocess
221221
@skip_if_lt_x_gpu(2)
222222
@torch._inductor.config.patch(reorder_for_locality=False)
223223
def test_compile_gpu(self):
224224
self._test_compile(use_gpu=True, no_sync=False, checkpoint=False)
225225

226-
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
226+
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
227227
@skip_if_rocm_multiprocess
228228
@skip_if_lt_x_gpu(2)
229229
@torch._inductor.config.patch(reorder_for_locality=False)
230230
def test_compile_gpu_ac(self):
231231
self._test_compile(use_gpu=True, no_sync=False, checkpoint=True)
232232

233-
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
233+
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
234234
@skip_if_rocm_multiprocess
235235
@skip_if_lt_x_gpu(2)
236236
def test_compile_bf16(self):
@@ -244,7 +244,7 @@ def setup(model, compiled_replicate_model, compiled_ddp_model) -> None:
244244

245245
self._test_compile(use_gpu=True, no_sync=False, setup_func=setup)
246246

247-
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
247+
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
248248
@skip_if_rocm_multiprocess
249249
@skip_if_lt_x_gpu(2)
250250
def test_compile_fp16(self):
@@ -261,7 +261,7 @@ def setup(model, compiled_replicate_model, compiled_ddp_model) -> None:
261261
use_gpu=True, no_sync=False, setup_func=setup, no_inductor=True
262262
)
263263

264-
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
264+
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
265265
@skip_if_rocm_multiprocess
266266
@skip_if_lt_x_gpu(2)
267267
def test_compile_backward_only(self):
@@ -385,7 +385,7 @@ def setUp(self):
385385
def tearDown(self):
386386
dist.destroy_process_group()
387387

388-
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
388+
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
389389
@skipIfRocm
390390
def test_ddp_tp(self):
391391
ref_model = Net()

test/distributed/_tensor/test_dtensor_compile.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
with_comms,
4747
)
4848
from torch.testing._internal.distributed.fake_pg import FakeStore
49-
from torch.utils._triton import has_triton
49+
from torch.testing._internal.inductor_utils import HAS_GPU
5050
from torch.utils.checkpoint import checkpoint
5151

5252

@@ -439,7 +439,7 @@ def fn(x):
439439
tmp_dt._local_tensor.stride(), tmp_dt_fake._local_tensor.stride()
440440
)
441441

442-
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
442+
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
443443
def test_dtensor_contiguous_dtensor_noncontiguous_local_as_tangent(self):
444444
# Partial -> Shard on an unbalanced tensor results in:
445445
# - A contiguous DTensor
@@ -515,7 +515,7 @@ def fw_hook(module, inp, out):
515515
out_test = opt_mod(dt)
516516
self.assertEqual(out_ref, out_test)
517517

518-
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
518+
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
519519
def test_dtensor_different_gradient_placement(self):
520520
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
521521

@@ -647,7 +647,7 @@ def forward(self, primals_1):
647647
return (sin_1, primals_1, wait_tensor)""",
648648
)
649649

650-
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
650+
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
651651
def test_dtensor_partial_placement_graph_output(self):
652652
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
653653

@@ -665,7 +665,7 @@ def fn(x):
665665
out_dt = torch.matmul(tmp_dt, y_dt)
666666
out_dt.sum().backward()
667667

668-
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
668+
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
669669
@skip_if_lt_x_gpu(1)
670670
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
671671
@patch.object(torch._inductor.config, "compile_threads", 1)

test/distributed/fsdp/test_fsdp_use_orig_params.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
TEST_WITH_DEV_DBG_ASAN,
4444
TestCase,
4545
)
46-
from torch.utils._triton import has_triton
46+
from torch.testing._internal.inductor_utils import HAS_GPU
4747

4848

4949
if not dist.is_available():
@@ -218,7 +218,7 @@ def _get_sharding_strategy_from_str(
218218
raise ValueError(f"Invalid string: {sharding_strategy_str}")
219219
return sharding_strategy
220220

221-
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
221+
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
222222
@skip_if_lt_x_gpu(2)
223223
def test_fsdp_compile(self):
224224
self.run_subtests(

test/distributed/tensor/parallel/test_micro_pipeline_tp.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
)
3838
from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule
3939
from torch.testing._internal.distributed.fake_pg import FakeStore
40-
from torch.utils._triton import has_triton
40+
from torch.testing._internal.inductor_utils import HAS_GPU
4141

4242

4343
def _make_post_grad_fx(f, *inps):
@@ -78,7 +78,7 @@ def setUp(self):
7878
def tearDown(self):
7979
dist.destroy_process_group()
8080

81-
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
81+
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
8282
@fresh_inductor_cache()
8383
def test_find_all_gather_patterns(self):
8484
group = dist.group.WORLD
@@ -129,7 +129,7 @@ def func(inp: torch.Tensor) -> torch.Tensor:
129129
torch.ops.aten.view.dtype,
130130
)
131131

132-
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
132+
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
133133
@fresh_inductor_cache()
134134
def test_find_reduce_scatter_patterns(self):
135135
group = dist.group.WORLD
@@ -168,7 +168,7 @@ def func(inp: torch.Tensor) -> torch.Tensor:
168168
self.assertEqual(reduce_scatters[1].reduce_op, "avg")
169169
self.assertEqual(reduce_scatters[1].scatter_dim, 1)
170170

171-
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
171+
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
172172
@fresh_inductor_cache()
173173
def test_get_unexposed_collectives(self):
174174
group = dist.group.WORLD
@@ -193,7 +193,7 @@ def func(inp: torch.Tensor) -> torch.Tensor:
193193
["all_gather_into_tensor", "reduce_scatter_tensor"],
194194
)
195195

196-
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
196+
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
197197
@parametrize("A_dims", [2, 3])
198198
@parametrize("gather_dim", [0, 1, 2])
199199
@fresh_inductor_cache()
@@ -231,7 +231,7 @@ def func(A_shard: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
231231
self.assertNotIn("all_gather_into_tensor", code)
232232

233233
@runOnRocmArch(MI300_ARCH)
234-
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
234+
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
235235
@parametrize("A_dims", [2, 3])
236236
@parametrize("gather_dim", [0, 1, 2])
237237
@fresh_inductor_cache()
@@ -299,7 +299,7 @@ def func(
299299
self.assertIn("fused_all_gather_scaled_matmul", code)
300300
self.assertNotIn("all_gather_into_tensor", code)
301301

302-
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
302+
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
303303
@parametrize("A_dims", [2, 3])
304304
@parametrize("scatter_dim", [0, 1, 2])
305305
@fresh_inductor_cache()
@@ -328,7 +328,7 @@ def func(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
328328
self.assertNotIn("reduce_scatter_tensor", code)
329329

330330
@runOnRocmArch(MI300_ARCH)
331-
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
331+
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
332332
@parametrize("A_dims", [2, 3])
333333
@parametrize("scatter_dim", [0, 1, 2])
334334
@fresh_inductor_cache()
@@ -381,7 +381,7 @@ def func(
381381
self.assertIn("fused_scaled_matmul_reduce_scatter", code)
382382
self.assertNotIn("reduce_scatter_tensor", code)
383383

384-
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
384+
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
385385
@parametrize("shard_dim", [0, 1])
386386
@fresh_inductor_cache()
387387
def test_dtensor_seq_par(self, shard_dim: int):

0 commit comments

Comments
 (0)