
[compile]outer operator with out= causes graph break in dynamic shape compilation when vec2 shape is 1 #120482

@jthakurH

Description

🐛 Describe the bug

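For torch.outer called with out=, Dynamo hits a graph break whenever vec2 has size 1. Because the repro uses fullgraph=True, the graph break is raised as an error. The issue does not occur when vec2 has any size other than 1, and compiling with dynamic=False in torch.compile works fine.
In the log below, the first call succeeds (res: torch.Size([2, 1])). The error is raised on the second call, after the input shapes change.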
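A minimal sketch of the dynamic=False case described above: it issues the same torch.outer call with out= on the same shapes as the minified repro. Three assumptions are not from the report. It uses a plain function instead of the repro's nn.Module. outer_into is a hypothetical helper name. backend="eager" skips Inductor code generation, which is safe here because the reported error is raised during Dynamo tracing, before any backend runs. Per the report, this static-shape configuration compiles without the error.

```python
import torch


def outer_into(vec1, vec2, out):
    # Hypothetical helper: the same torch.outer call as in the minified repro.
    return torch.outer(vec1, vec2, out=out)


# dynamic=False: Dynamo specializes on every input shape and recompiles for each
# new shape instead of tracing with symbolic sizes.
# backend="eager" is an assumption used only to skip Inductor code generation.
static_fn = torch.compile(outer_into, fullgraph=True, dynamic=False, backend="eager")

results = []
for n in (2, 6, 4):
    vec1, vec2 = torch.randn(n), torch.randn(1)  # vec2 has size 1, as in the report
    out = torch.empty(n, 1)                       # preallocated (n, 1) result buffer
    res = static_fn(vec1, vec2, out)
    results.append((res, vec1, vec2))
    print(f"res: {res.shape}")
```

Each new shape costs a recompile, which is the usual price of turning dynamic shapes off.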

Error logs

Error:

res: torch.Size([2, 1])
Traceback (most recent call last):
  File "/tmp/outer.py", line 20, in <module>
    res = compile_fn(
  File "/tmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1514, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/tmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1523, in _call_impl
    return forward_call(*args, **kwargs)
  File "/tmp/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/tmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1514, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/tmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1523, in _call_impl
    return forward_call(*args, **kwargs)
  File "/tmp/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 655, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/tmp/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 383, in _convert_frame_assert
    compiled_product = _compile(
  File "/tmp/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 646, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/tmp/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
  File "/tmp/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 562, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/tmp/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
    transformations(instructions, code_options)
  File "/tmp/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 151, in _fn
    return fn(*args, **kwargs)
  File "/tmp/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 527, in transform
    tracer.run()
  File "/tmp/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2128, in run
    super().run()
  File "/tmp/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
    and self.step()
  File "/tmp/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
    getattr(self, inst.opname)(inst)
  File "/tmp/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
    return inner_fn(self, inst)
  File "/tmp/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1252, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars.items)
  File "/tmp/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 652, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/tmp/lib/python3.10/site-packages/torch/_dynamo/variables/torch.py", line 585, in call_function
    unimplemented("out variants with resizing on graph inputs")
  File "/tmp/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 193, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: out variants with resizing on graph inputs
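The message refers to out= variants that resize a graph input. For reference, the eager-mode sketch below runs without torch.compile and shows what resizing an out= tensor means: when out does not already have the result shape, the op resizes it in place. This log does not confirm that Dynamo's check fires here because it believes such a resize is needed. That reading is an inference from the message text.

```python
import torch

vec1 = torch.randn(3)
vec2 = torch.randn(1)

# An out= tensor whose shape does not match the result: here it is empty.
out = torch.empty(0)
res = torch.outer(vec1, vec2, out=out)

# In eager mode the out= tensor is resized in place to the result shape (3, 1),
# and the returned tensor shares its storage.
print(out.shape)  # torch.Size([3, 1])
```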

Minified repro

```python
import torch
from torch import nn


class CustomModel(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        return torch.outer(**inputs)  # inputs holds the torch.outer kwargs: input, vec2, out

compile_fn = torch.compile(CustomModel(), fullgraph=True)  # fullgraph=True: a graph break raises an error

shapes = [(2,1), (6,1), (4,1)]  # (vec1 size, vec2 size); vec2 always has size 1
for shape in shapes:
    vec1, vec2 = shape
    input_tensor1 = torch.randn(vec1)
    input_tensor2 = torch.randn(vec2)
    out_tensor = torch.empty(shape)  # preallocated result buffer with shape (vec1, vec2)
    res = compile_fn(
        {"input": input_tensor1, "vec2": input_tensor2, "out": out_tensor}
    )
    print(f"res: {res.shape}")
print("Test passed!")
```
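One possible workaround, not from the original report and only a sketch: keep out= out of the compiled region, return the result normally, and copy it into the preallocated buffer afterwards. Since no out= keyword reaches Dynamo, the out-variant shape check that raises the error is not involved, and dynamic shapes stay at their default. outer_no_out is a hypothetical name, and backend="eager" is an assumption used only to skip Inductor code generation.

```python
import torch


def outer_no_out(vec1, vec2):
    # Hypothetical workaround: no out= inside the compiled region.
    return torch.outer(vec1, vec2)


# dynamic is left at its default, so automatic dynamic shapes still apply.
compiled = torch.compile(outer_no_out, fullgraph=True, backend="eager")

for vec1_size, vec2_size in [(2, 1), (6, 1), (4, 1)]:
    vec1 = torch.randn(vec1_size)
    vec2 = torch.randn(vec2_size)
    out_tensor = torch.empty(vec1_size, vec2_size)
    # Write into the preallocated buffer outside the compiled function.
    out_tensor.copy_(compiled(vec1, vec2))
    print(f"res: {out_tensor.shape}")
```

The trade-off is an extra copy into out_tensor outside the graph instead of writing the result in place.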

Versions

[pip3] numpy==1.26.4
[pip3] torch==2.2.0a0
[pip3] torchaudio==2.2.0
[pip3] torchdata==0.7.1
[pip3] torchmetrics==1.2.1
[pip3] torchtext==0.17.0
[pip3] torchvision==0.17.0

cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519
