For the `torch.outer` operator, a graph break happens whenever `vec2` has shape 1. The issue is not seen when `vec2` has a non-1 shape.
If we compile with `dynamic=False` in `torch.compile`, then it works fine.
res: torch.Size([2, 1])
Traceback (most recent call last):
File "/tmp/outer.py", line 20, in <module>
res = compile_fn(
File "/tmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1514, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/tmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1523, in _call_impl
return forward_call(*args, **kwargs)
File "/tmp/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
return fn(*args, **kwargs)
File "/tmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1514, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/tmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1523, in _call_impl
return forward_call(*args, **kwargs)
File "/tmp/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 655, in catch_errors
return callback(frame, cache_entry, hooks, frame_state)
File "/tmp/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 383, in _convert_frame_assert
compiled_product = _compile(
File "/tmp/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 646, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/tmp/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
r = func(*args, **kwargs)
File "/tmp/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 562, in compile_inner
out_code = transform_code_object(code, transform)
File "/tmp/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
transformations(instructions, code_options)
File "/tmp/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 151, in _fn
return fn(*args, **kwargs)
File "/tmp/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 527, in transform
tracer.run()
File "/tmp/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2128, in run
super().run()
File "/tmp/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
and self.step()
File "/tmp/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
getattr(self, inst.opname)(inst)
File "/tmp/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
return inner_fn(self, inst)
File "/tmp/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1252, in CALL_FUNCTION_EX
self.call_function(fn, argsvars.items, kwargsvars.items)
File "/tmp/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 652, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/tmp/lib/python3.10/site-packages/torch/_dynamo/variables/torch.py", line 585, in call_function
unimplemented("out variants with resizing on graph inputs")
File "/tmp/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 193, in unimplemented
raise Unsupported(msg)
torch._dynamo.exc.Unsupported: out variants with resizing on graph inputs
import torch
from torch import nn


class CustomModel(nn.Module):
    """Thin wrapper that forwards a keyword-argument dict straight to torch.outer.

    Exists only to reproduce a Dynamo graph break: torch.outer is called with
    an ``out=`` tensor while the second vector has shape (1,).
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        # `inputs` is a dict unpacked as torch.outer(input=..., vec2=..., out=...).
        return torch.outer(**inputs)


# fullgraph=True makes Dynamo raise on any graph break instead of silently
# falling back to eager, surfacing the reported Unsupported error.
compile_fn = torch.compile(CustomModel(), fullgraph=True)

# Every pair has vec2 == 1 — the case that reportedly triggers
# "out variants with resizing on graph inputs".
shapes = [(2,1), (6,1), (4,1)]
for shape in shapes:
    vec1, vec2 = shape
    input_tensor1 = torch.randn(vec1)
    input_tensor2 = torch.randn(vec2)
    # Pre-allocated output tensor already sized to the expected (vec1, vec2)
    # result, so no resizing should be needed by the out= variant.
    out_tensor = torch.empty(shape)
    res = compile_fn(
        {"input": input_tensor1, "vec2": input_tensor2, "out": out_tensor}
    )
    print(f"res: {res.shape}")
print("Test passed!")
🐛 Describe the bug
For the `torch.outer` operator, a graph break happens whenever `vec2` has shape 1. The issue is not seen when `vec2` has a non-1 shape.
If we compile with `dynamic=False` in `torch.compile`, then it works fine.
Error logs
Error:
Minified repro
Versions
[pip3] numpy==1.26.4
[pip3] torch==2.2.0a0
[pip3] torchaudio==2.2.0
[pip3] torchdata==0.7.1
[pip3] torchmetrics==1.2.1
[pip3] torchtext==0.17.0
[pip3] torchvision==0.17.0
cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519