# Owner(s): ["module: dynamo"]
import functools
import unittest
from unittest.mock import patch
import torch
from torch._C import FileCheck
from torch._dispatch.python import enable_python_dispatcher
import torch._dynamo
import torch._dynamo.test_case
from torch._dynamo.utils import same
from torch._dynamo.testing import CompileCounter
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_distributed import (
    DynamoDistributedSingleProcTestCase,
    DynamoDistributedMultiProcTestCase,
    _dynamo_dist_per_rank_init,
    requires_nccl,
    skip_if_lt_x_gpu
)
from torch._inductor.compile_fx import compile_fx as inductor_compile_fx
from torch._inductor.utils import has_triton, run_and_get_triton_code
import torch._dynamo.logging

# Importing this module registers the Python implementations of the functional
# collective ops; without it, calls fall through to the no-op C++ kernel.
import torch.distributed._functional_collectives


@requires_nccl()
class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
    """
    Run correctness checks in the multi-proc runner; mark each test with the minimum number of GPUs it requires.
    """
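    # (tag, ranks, group_size) are the extra kwargs every collective in this file
    # takes; with an empty tag and the full list of ranks they describe the whole world.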
    def get_world_trs(self):
        return {
            "tag": "",
            "ranks": list(range(self.world_size)),
            "group_size": self.world_size,
        }

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    # TODO: inductor's background compile threads cause hangs at exit, interacting with the distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    def test_allreduce_inductor(self):
        """
        This matmul/cat/allreduce pattern is one we aim to optimize.
        """

        def matmul_cat_col(a, b, c, d, e, f, *, tag, ranks, group_size):
            x = torch.matmul(a, b)
            y = torch.matmul(c, d)
            z = torch.cat((x, y))
            ar = torch.ops.aten.all_reduce(z, "sum", tag, ranks, group_size)
            g = torch.matmul(e, f)
            ar = torch.ops.aten.wait_tensor(ar)
            out = torch.add(ar, g.repeat(2, 1))
            return (out, )

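        # Helper: trace the function to an FX graph with make_fx and hand it directly
        # to inductor's compile_fx, bypassing dynamo entirely.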
        def compile(func, example_inputs):
            graph = make_fx(func)(*example_inputs)
            return inductor_compile_fx(graph, example_inputs)

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):

            matmul_cat_col = functools.partial(
                matmul_cat_col,
                **self.get_world_trs(),
            )
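            # Give each rank different input values so the allreduced result actually
            # differs from what any single rank would compute locally.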
            inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 6

            # Non-ideal: the python dispatcher has to be enabled at the user level in order
            # to construct a torch_dispatch subclass inside the python-registered collective ops.
            with enable_python_dispatcher():
                eager_out = matmul_cat_col(*inputs)
                compiled_matmul_cat_col = compile(matmul_cat_col, inputs)
                inductor_out = compiled_matmul_cat_col(*inputs)
                assert same(eager_out, inductor_out, tol=0.001)


@requires_nccl()
class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
    """
    Prefer single-proc test runner for basic tests as it is easier to work with.
    """
    def get_world_trs(self, world_size=1):
        return {
            "tag": "",
            "ranks": list(range(world_size)),
            "group_size": world_size,
        }

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_inductor_single_op(self):
        torch._inductor.config.debug = True

        def func(inp, *, tag, ranks, group_size):
            ar = torch.ops.aten.all_reduce(inp, "sum", tag, ranks, group_size)
            ar = torch.ops.aten.wait_tensor(ar)
            return ar

        inputs = torch.ones(4, 4, device="cuda")

        with enable_python_dispatcher():
            compiled = torch.compile(func)
            out = compiled(inputs, **self.get_world_trs())
            code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
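            # The generated code should allocate a fresh buffer, copy the input into it,
            # issue dist.all_reduce on that buffer, wait on the returned work handle,
            # and return the result.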
            FileCheck() \
                .check("buf0 = empty_strided") \
                .check("buf0.copy_(arg0_1)") \
                .check("buf0_work = dist.all_reduce(buf0") \
                .check("buf0_work.wait()") \
                .check("return (buf1, )") \
                .run(code)
            correct = func(inputs, **self.get_world_trs())
            assert same(out, correct)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_inductor_steal_buffer(self):
        """
        It's OK, and optimal, for the inductor allreduce to mutate the buffer of an
        intermediate that isn't going to be used again.
        """
        torch._inductor.config.debug = True

        def func(inp, *, tag, ranks, group_size):
            x = inp + 1
            ar = torch.ops.aten.all_reduce(x, "sum", tag, ranks, group_size)
            ar = torch.ops.aten.wait_tensor(ar)
            # ensure other is not incorrectly aliasing ar's buffer
            other = torch.ones_like(inp) + 22
            return ar, other

        inputs = torch.ones(4, 4, device="cuda")

        with enable_python_dispatcher():
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
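            # x is dead after feeding the collective, so the allreduce should steal its
            # buffer in place (no copy_), while `other` gets its own freshly allocated buffer.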
            FileCheck() \
                .check("buf1 = buf0; del buf0 # reuse") \
                .check_not("buf1.copy_(") \
                .check("buf1_work = dist.all_reduce(buf1") \
                .check("buf1_work.wait()") \
                .check("buf2 = buf1") \
                .check("buf3 = empty_strided") \
                .check("return (buf2, buf3") \
                .run(code)
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            assert same(out, correct)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_inductor_doesnt_mutate_shared(self):
        """
        Make sure an intermediate that is going to be reused isn't mutated unless it is copied first.
        """
        torch._inductor.config.debug = True

        def func(inp, *, tag, ranks, group_size):
            x = inp + 1
            ar = torch.ops.aten.all_reduce(x, "sum", tag, ranks, group_size)
            y = x + 2
            ar = torch.ops.aten.wait_tensor(ar)
            # ensure other is not incorrectly aliasing ar's buffer
            other = torch.ones_like(inp) + 22
            return ar, y, other

        inputs = torch.ones(4, 4, device="cuda")

        with enable_python_dispatcher():
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
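            # The checks assert that x and y land in separate buffers (written by one fused
            # kernel), so the allreduce can reuse x's buffer in place without any copy_ and
            # without clobbering y.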
            FileCheck() \
                .check("buf0 = empty_strided(") \
                .check("buf2 = empty_strided") \
                .check("triton__0.run(arg0_1, buf0, buf2") \
                .check_not("copy_(") \
                .check("buf1 = buf0; del buf0 # reuse") \
                .check("buf1_work = dist.all_reduce(buf1") \
                .check("buf1_work.wait()") \
                .check("buf3 = buf1") \
                .check("return (buf3, buf2, buf4") \
                .run(code)
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            assert same(out, correct)

    def test_dynamo_trace_allreduce(self):
        def func(inp, *, tag, ranks, group_size):
            ar = torch.ops.aten.all_reduce(inp, "sum", tag, ranks, group_size)
            return ar

        inputs = torch.ones(4, 4, device="cuda")
        counter = CompileCounter()
        with enable_python_dispatcher():
            compiled = torch.compile(func, backend=counter)
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
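            # Expect a single dynamo frame containing exactly one op: the traced all_reduce.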
            assert counter.frame_count == 1
            assert counter.op_count == 1
            assert same(out, correct)

    def test_backwards(self):
        """
        It's probably not that common to need backwards support for collectives.

        However, I wanted to at least see if it was possible to support it as a design goal.
        """
        def func(inp, *, tag, ranks, group_size):
            ar = torch.ops.aten.all_reduce(inp, "sum", tag, ranks, group_size)
            return ar

        input = torch.ones(4, 4, device="cuda", requires_grad=True)
        with enable_python_dispatcher():
            # TODO implement backwards
            with self.assertRaisesRegex(RuntimeError, "derivative for aten::all_reduce is not implemented"):
                compiled = torch.compile(func, backend="aot_eager")  # inductor bug with single-op allreduce graph
                out = compiled(input, **self.get_world_trs())
                out.sum().backward()

                # unreachable until a backward formula for all_reduce exists; the eager
                # backward would otherwise raise the same error outside the assertRaises block
                correct_input = input.clone().detach().requires_grad_()
                correct = func(correct_input, **self.get_world_trs())
                correct.sum().backward()
                assert same(out, correct)
                assert same(input.grad, correct_input.grad)

    def test_meta(self):
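        # No communication happens on meta tensors; the op only needs to propagate
        # the shape (and dtype) to its output.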
        x = torch.rand((2, 3, 4), device="meta")
        out = torch.ops.aten.all_reduce(x, "sum", **self.get_world_trs())
        assert x.size() == out.size()


| 232 | + |
| 233 | +if __name__ == "__main__": |
| 234 | + from torch._dynamo.test_case import run_tests |
| 235 | + |
| 236 | + run_tests() |