Commit a8cbf70

wconstab authored and pytorchmergebot committed

Inductor support for aten::all_reduce (pytorch#93111)

Pull Request resolved: pytorch#93111
Approved by: https://github.com/jansel, https://github.com/wanchaol

1 parent 5d1e9fd

6 files changed: 404 additions & 13 deletions
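In short: with this commit, a graph containing the new traceable collective ops can be compiled end to end by Inductor. A minimal sketch of the user-facing pattern the tests below exercise (assuming a default process group is already initialized with the NCCL backend; the empty tag plus ranks/group_size identify that group):

    import torch
    from torch._dispatch.python import enable_python_dispatcher
    # importing this registers the Python implementations of the collective ops
    import torch.distributed._functional_collectives

    def allreduce_fn(x, *, tag, ranks, group_size):
        ar = torch.ops.aten.all_reduce(x, "sum", tag, ranks, group_size)
        return torch.ops.aten.wait_tensor(ar)

    with enable_python_dispatcher():
        compiled = torch.compile(allreduce_fn)  # Inductor is the default backend
        out = compiled(torch.ones(4, 4, device="cuda"), tag="", ranks=[0], group_size=1)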

.ci/pytorch/test.sh

Lines changed: 1 addition & 1 deletion

@@ -249,7 +249,7 @@ test_dynamo_shard() {
 test_inductor_distributed() {
   # this runs on both single-gpu and multi-gpu instance. It should be smart about skipping tests that aren't supported
   # when the required # of gpus aren't available
-  PYTORCH_TEST_WITH_INDUCTOR=0 python test/run_test.py --include distributed/test_dynamo_distributed --verbose
+  PYTORCH_TEST_WITH_INDUCTOR=0 python test/run_test.py --include distributed/test_dynamo_distributed distributed/test_traceable_collectives --verbose
   assert_git_not_dirty
 }

.github/labeler.yml

Lines changed: 1 addition & 0 deletions

@@ -16,6 +16,7 @@
 - torch/_subclasses/fake_utils.py
 - torch/_subclasses/meta_utils.py
 - test/distributed/test_dynamo_distributed.py
+- test/distributed/test_traceable_collectives.py
 - functorch/_src/partitioners.py
 - functorch/_src/aot_autograd.py

test/distributed/test_traceable_collectives.py

Lines changed: 236 additions & 0 deletions

@@ -0,0 +1,236 @@
# Owner(s): ["module: dynamo"]
import functools
import unittest
from unittest.mock import patch
import torch
from torch._C import FileCheck
from torch._dispatch.python import enable_python_dispatcher
import torch._dynamo
import torch._dynamo.test_case
from torch._dynamo.utils import same
from torch._dynamo.testing import CompileCounter
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_distributed import (
    DynamoDistributedSingleProcTestCase,
    DynamoDistributedMultiProcTestCase,
    _dynamo_dist_per_rank_init,
    requires_nccl,
    skip_if_lt_x_gpu
)
from torch._inductor.compile_fx import compile_fx as inductor_compile_fx
from torch._inductor.utils import has_triton, run_and_get_triton_code
import torch._dynamo.logging

# If you don't remember to import this, the op isn't registered and it falls
# through to the no-op C++ kernel that exists only as a registration stub
import torch.distributed._functional_collectives

@requires_nccl()
class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
    """
    Run correctness checks in multi-proc runner, mark with minimum # GPUs to run under
    """
    def get_world_trs(self):
        return {
            "tag": "",
            "ranks": list(range(self.world_size)),
            "group_size": self.world_size,
        }

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    def test_allreduce_inductor(self):
        """
        matmul/cat/allreduce is a pattern we aim to optimize.
        """

        def matmul_cat_col(a, b, c, d, e, f, *, tag, ranks, group_size):
            x = torch.matmul(a, b)
            y = torch.matmul(c, d)
            z = torch.cat((x, y))
            ar = torch.ops.aten.all_reduce(z, "sum", tag, ranks, group_size)
            g = torch.matmul(e, f)
            ar = torch.ops.aten.wait_tensor(ar)
            out = torch.add(ar, g.repeat(2, 1))
            return (out,)

        def compile(func, example_inputs):
            graph = make_fx(func)(*example_inputs)
            return inductor_compile_fx(graph, example_inputs)

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):

            matmul_cat_col = functools.partial(
                matmul_cat_col,
                **self.get_world_trs(),
            )
            inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 6

            # non-ideally, we seem to need to enable this at user level in order to
            # construct a torch_dispatch subclass inside the py-registered collective ops
            with enable_python_dispatcher():
                eager_out = matmul_cat_col(*inputs)
                compiled_matmul_cat_col = compile(matmul_cat_col, inputs)
                inductor_out = compiled_matmul_cat_col(*inputs)
                assert same(eager_out, inductor_out, tol=0.001)

@requires_nccl()
class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
    """
    Prefer single-proc test runner for basic tests as it is easier to work with.
    """
    def get_world_trs(self, world_size=1):
        return {
            "tag": "",
            "ranks": list(range(world_size)),
            "group_size": world_size,
        }
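    # A note on the world spec (editorial, not part of the diff): a ProcessGroup object
    # can't be baked into a traced graph, so the ops take a (tag, ranks, group_size)
    # triple instead, and the generated wrapper resolves it at run time via
    # _find_or_create_pg_by_ranks_and_tag(tag, ranks, group_size)
    # (see torch/_inductor/ir.py below).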
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_inductor_single_op(self):
        torch._inductor.config.debug = True

        def func(inp, *, tag, ranks, group_size):
            ar = torch.ops.aten.all_reduce(inp, "sum", tag, ranks, group_size)
            ar = torch.ops.aten.wait_tensor(ar)
            return ar

        inputs = torch.ones(4, 4, device="cuda")

        with enable_python_dispatcher():
            compiled = torch.compile(func)
            out = compiled(inputs, **self.get_world_trs())
            code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
            FileCheck() \
                .check("buf0 = empty_strided") \
                .check("buf0.copy_(arg0_1)") \
                .check("buf0_work = dist.all_reduce(buf0") \
                .check("buf0_work.wait()") \
                .check("return (buf1, )") \
                .run(code)
            correct = func(inputs, **self.get_world_trs())
            assert same(out, correct)
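    # What the checks above pin down (editorial note): `buf0.copy_(arg0_1)` appears because
    # the graph input may have other users, so codegen copies it into a fresh buffer before
    # the in-place collective; `buf0_work` is the async c10d Work handle produced by
    # dist.all_reduce(..., async_op=True); and `buf1 = buf0` is the Wait node's symbolic
    # output aliasing the collective's buffer (see Wait.codegen in torch/_inductor/ir.py below).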
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_inductor_steal_buffer(self):
        """
        It's ok (and optimal) if the inductor allreduce mutates the buffer of an
        intermediate that isn't going to be used again.
        """
        torch._inductor.config.debug = True

        def func(inp, *, tag, ranks, group_size):
            x = inp + 1
            ar = torch.ops.aten.all_reduce(x, "sum", tag, ranks, group_size)
            ar = torch.ops.aten.wait_tensor(ar)
            # ensure other is not incorrectly aliasing ar's buffer
            other = torch.ones_like(inp) + 22
            return ar, other

        inputs = torch.ones(4, 4, device="cuda")

        with enable_python_dispatcher():
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
            FileCheck() \
                .check("buf1 = buf0; del buf0 # reuse") \
                .check_not("buf1.copy_(") \
                .check("buf1_work = dist.all_reduce(buf1") \
                .check("buf1_work.wait()") \
                .check("buf2 = buf1") \
                .check("buf3 = empty_strided") \
                .check("return (buf2, buf3") \
                .run(code)
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            assert same(out, correct)
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_inductor_doesnt_mutate_shared(self):
        """
        Make sure that an intermediate that's going to be reused isn't mutated unless copied.
        """
        torch._inductor.config.debug = True

        def func(inp, *, tag, ranks, group_size):
            x = inp + 1
            ar = torch.ops.aten.all_reduce(x, "sum", tag, ranks, group_size)
            y = x + 2
            ar = torch.ops.aten.wait_tensor(ar)
            # ensure other is not incorrectly aliasing ar's buffer
            other = torch.ones_like(inp) + 22
            return ar, y, other

        inputs = torch.ones(4, 4, device="cuda")

        with enable_python_dispatcher():
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
            FileCheck() \
                .check("buf0 = empty_strided(") \
                .check("buf2 = empty_strided") \
                .check("triton__0.run(arg0_1, buf0, buf2") \
                .check_not("copy_(") \
                .check("buf1 = buf0; del buf0 # reuse") \
                .check("buf1_work = dist.all_reduce(buf1") \
                .check("buf1_work.wait()") \
                .check("buf3 = buf1") \
                .check("return (buf3, buf2, buf4") \
                .run(code)
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            assert same(out, correct)
    def test_dynamo_trace_allreduce(self):
        def func(inp, *, tag, ranks, group_size):
            ar = torch.ops.aten.all_reduce(inp, "sum", tag, ranks, group_size)
            return ar

        inputs = torch.ones(4, 4, device="cuda")
        counter = CompileCounter()
        with enable_python_dispatcher():
            compiled = torch.compile(func, backend=counter)
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            assert counter.frame_count == 1
            assert counter.op_count == 1
            assert same(out, correct)
    def test_backwards(self):
        """
        It's probably not that common to need backwards support for collectives.

        However, I wanted to at least see if it was possible to support it as a design goal.
        """
        def func(inp, *, tag, ranks, group_size):
            ar = torch.ops.aten.all_reduce(inp, "sum", tag, ranks, group_size)
            return ar

        input = torch.ones(4, 4, device="cuda", requires_grad=True)
        with enable_python_dispatcher():
            # TODO implement backwards
            with self.assertRaisesRegex(RuntimeError, "derivative for aten::all_reduce is not implemented"):
                compiled = torch.compile(func, backend="aot_eager")  # inductor bug with single-op allreduce graph
                out = compiled(input, **self.get_world_trs())
                out.sum().backward()

                correct_input = input.clone().detach().requires_grad_()
                correct = func(correct_input, **self.get_world_trs())
                correct.sum().backward()
                assert same(out, correct)
                assert same(input.grad, correct_input.grad)
    def test_meta(self):
        x = torch.rand((2, 3, 4), device="meta")
        out = torch.ops.aten.all_reduce(x, "sum", **self.get_world_trs())
        assert x.size() == out.size()


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()

torch/_inductor/ir.py

Lines changed: 106 additions & 2 deletions
@@ -78,8 +78,6 @@
 Tensors backed by views add one more indirection to the IR.
 TensorBox -> View -> StorageBox -> Buffer
 In these cases, the underlying StorageBox/Buffer will be shared with the pre-view TensorBox.
-
-For metadata mutation (e.g. as_strided_) we swing the TensorBox pointer.
 """
@@ -4202,3 +4200,109 @@ def debug_str(self, name="block"):
            "",
            code.strip().replace("def forward(", f"def {name}("),
        )

class Wait(ExternKernel):
    """
    Wait should not be used by itself. It should always be constructed in tandem
    with a collective op that produces a work to wait on.
    """

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
    ):
        super().__init__(None, layout, inputs, constant_args)
        self.name = V.graph.register_buffer(self)

    def should_allocate(self):
        return False

    def codegen(self, wrapper):
        (input_collective,) = [t.codegen_reference() for t in self.inputs]
        work = f"{input_collective}_work"  # hacky way to name work objs..
        wrapper.writeline(f"{work}.wait()")

        # wait op still needs to produce a 'buffer' that represents the tensor output.
        # this is a symbolic gesture, and it gets handled by WrapperCodegen.
        # codegen outputs a '# reuse' line that assigns the input buffer here ('input_collective')
        # to a new name (`self.get_name()`) and `del`s the old name.
        wrapper.writeline(f"{self.get_name()} = {input_collective}")

    @classmethod
    def create(cls, collective_op: "TensorBox"):
        return Wait(
            layout=collective_op.get_layout(),
            inputs=[collective_op],
        )

    def get_alias_names(self):
        # Signal to codegen that our output buffer isn't safe to reuse
        return [self.inputs[0].codegen_reference()]
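# Editorial sketch, not part of this diff: a lowering is expected to pair the two nodes,
# roughly Wait.create(AllReduce.create(x, "sum", tag, ranks, group_size)); the actual
# lowering registration lives in one of the changed files omitted from this view.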

class AllReduce(ExternKernel):
    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
    ):
        super().__init__(None, layout, inputs, constant_args)
        self.name = V.graph.register_buffer(self)

    def should_allocate(self):
        return True

    @classmethod
    def create(
        cls, x: "TensorBox", reduce_op: str, tag: str, ranks: List[int], group_size: int
    ):
        x = cls.realize_input(x)

        # is there a difference between literally using x.data.layout below, vs
        # creating a new one that has the same properties?
        new_layout = FlexibleLayout(x.get_device(), x.get_dtype(), x.get_size())

        # AllReduce returns a 'work' object. But Inductor's scheduler doesn't need to know
        # about that, and we just pretend for scheduling purposes that the work obj is a 1-elem tensor.
        # Nobody should consume the output of AllReduce except 'Wait', which we control here.
        return AllReduce(
            layout=new_layout,
            inputs=[x],
            constant_args=[reduce_op, tag, ranks, group_size],
        )

    def codegen(self, wrapper):
        wrapper.add_import_once("import torch.distributed as dist")
        wrapper.add_import_once(
            "from torch.distributed._functional_collectives import _str_to_reduce_op"
        )
        wrapper.add_import_once(
            "from torch.distributed.distributed_c10d import _find_or_create_pg_by_ranks_and_tag"
        )

        # extract references to our args in string form for codegen output
        (input_name,) = [t.codegen_reference() for t in self.inputs]
        output_name = self.get_name()
        reduce_op, tag, ranks, group_size = self.constant_args

        # TODO: avoid more than one ref of the same pg (even though they are cached inside the api)
        wrapper.writeline(
            f"{output_name}_pg = _find_or_create_pg_by_ranks_and_tag('{tag}', {ranks}, {group_size})"
        )

        # We must copy our input buffer sometimes, but the scheduler will help us find opportunities
        # to reuse the input buffer. (This requires no other users of the input buffer.)
        if not wrapper.did_reuse(self, self.inputs[0]):
            wrapper.writeline(f"{output_name}.copy_({input_name})")

        # At this point, output_name points to a buffer that is either
        # (1) the input buffer, which we're allowed to inplace modify
        # (2) a freshly allocated buffer, which we've copied the input into above
        wrapper.writeline(
            f"{output_name}_work = dist.all_reduce({output_name}, async_op=True,"
            f" group={output_name}_pg, op=_str_to_reduce_op('{str(reduce_op)}'))"
        )