We need to carefully update the partitioning logic to account for broadcasts. Right now we can only fuse broadcasts that result in a single output shape. Consider the following Python multiple-output tests:
def test_broadcasting_multiple_output_shape(self):
    def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
        o = x + 12
        o1 = o + y  # broadcasts o to y's shape (2, 32, 32)
        o2 = o + z  # broadcasts o to z's shape (4, 32, 32)
        oo = o1.sum() + o2.sum()
        return oo
    t_jit = torch.jit.script(t)
    x = torch.randn(32, 32, dtype=torch.float, device="cuda")
    y = torch.randn(2, 32, 32, dtype=torch.float, device="cuda")
    z = torch.randn(4, 32, 32, dtype=torch.float, device="cuda")
    # call twice so the profiling executor can record shapes before fusing
    jit_o = t_jit(x, y, z)
    jit_o = t_jit(x, y, z)
    o = t(x, y, z)
    self.assertEqual(o, jit_o)
    # Currently cannot fuse this
    self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GROUP)
def test_broadcasting_multiple_output(self):
    def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
        o = x + 12
        o1 = o + y  # broadcasts o to y's shape (4, 32, 32)
        o2 = o + z  # broadcasts o to z's shape (4, 32, 32)
        oo = o1.sum() + o2.sum()
        return oo
    t_jit = torch.jit.script(t)
    x = torch.randn(32, 32, dtype=torch.float, device="cuda")
    y = torch.randn(4, 32, 32, dtype=torch.float, device="cuda")
    z = torch.randn(4, 32, 32, dtype=torch.float, device="cuda")
    # call twice so the profiling executor can record shapes before fusing
    jit_o = t_jit(x, y, z)
    jit_o = t_jit(x, y, z)
    o = t(x, y, z)
    self.assertEqual(o, jit_o)
    # Currently cannot fuse this
    self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GROUP)
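From the fuser's perspective, both of these tests produce outputs whose sizes we cannot verify to be constant. In the first test the two broadcasts of `o` produce different shapes; in the second the shapes happen to match, but nothing in the graph proves that `y` and `z` share a leading dimension. This makes it hard to derive a general parallelization strategy, so we need to carefully prevent fusion on this type of graph. With our current scheduling logic, such a graph results (correctly) in a computeAt error.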
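To make the restriction concrete, here is a minimal pure-Python sketch of a conservative partitioning check. It is not the fuser's actual partitioning code: `broadcast_shape`, `shapes_provably_equal`, and `can_fuse_broadcast_consumers` are hypothetical names, and modeling a size the fuser cannot verify as `None` is an assumption made for illustration. The check only allows fusing a producer with its broadcasting consumers when all of them yield a single, provably identical output shape.

```python
from typing import Optional, Sequence, Tuple

# Assumption: a dimension is either a concrete size (int) or None for a
# symbolic size that the hypothetical partitioner cannot prove at compile time.
Dim = Optional[int]
Shape = Tuple[Dim, ...]


def broadcast_shape(a: Shape, b: Shape) -> Shape:
    """Right-align two shapes and combine them with PyTorch/NumPy broadcasting rules."""
    ndim = max(len(a), len(b))
    a = (1,) * (ndim - len(a)) + tuple(a)
    b = (1,) * (ndim - len(b)) + tuple(b)
    out = []
    for da, db in zip(a, b):
        if da == 1:
            out.append(db)
        elif db == 1:
            out.append(da)
        elif da is not None and db is not None:
            if da != db:
                raise ValueError(f"shapes {a} and {b} are not broadcastable")
            out.append(da)
        elif da is None and db is None:
            out.append(None)  # two symbolic sizes: result stays unknown
        else:
            # One concrete size n != 1 and one symbolic size: the result must be n.
            out.append(da if da is not None else db)
    return tuple(out)


def shapes_provably_equal(shapes: Sequence[Shape]) -> bool:
    """True only if every shape is fully concrete and all shapes are identical."""
    if any(d is None for s in shapes for d in s):
        return False
    return len(set(shapes)) <= 1


def can_fuse_broadcast_consumers(producer: Shape, operands: Sequence[Shape]) -> bool:
    """Hypothetical rule: fuse a producer with all of its broadcasting consumers
    only when they result in a single, provably identical output shape."""
    outputs = [broadcast_shape(producer, other) for other in operands]
    return shapes_provably_equal(outputs)


if __name__ == "__main__":
    # test_broadcasting_multiple_output_shape: two distinct output shapes
    print(can_fuse_broadcast_consumers((32, 32), [(2, 32, 32), (4, 32, 32)]))  # False
    # test_broadcasting_multiple_output: leading dims of y and z not verifiable
    print(can_fuse_broadcast_consumers((32, 32), [(None, 32, 32), (None, 32, 32)]))  # False
    # a single broadcasting consumer always yields a single output shape
    print(can_fuse_broadcast_consumers((32, 32), [(4, 32, 32)]))  # True
```

Treating an unverifiable size as "not equal" is the conservative choice: it may reject graphs that would have been safe to fuse, but it never admits a graph whose outputs could end up with different shapes at runtime.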