
frexp operator frame tracing error with dynamic None #120511

@jthakurH

Description


🐛 Describe the bug

The frexp operator with its out variant causes a frame tracing error when run under torch.compile with dynamic=None (the default).
If dynamic shapes are disabled, i.e. dynamic=False is passed to torch.compile, it works fine.
Use the code below to reproduce the error.

Error logs

Traceback (most recent call last):
  File "/tmp/sgn.py", line 25, in <module>
    res = compiled_model(params)
  File "/tmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1514, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/tmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1523, in _call_impl
    return forward_call(*args, **kwargs)
  File "/tmp/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/tmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1514, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/tmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1523, in _call_impl
    return forward_call(*args, **kwargs)
  File "/tmp/sgn.py", line 10, in forward
    result = self.op(**kwarg)
  File "/tmp/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 655, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/tmp/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 727, in _convert_frame
    result = inner_convert(frame, cache_entry, hooks, frame_state)
  File "/tmp/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 353, in _convert_frame_assert
    if not has_tensor_in_frame(frame):
  File "/tmp/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1610, in inner_fn
    return fn(*args, **kwargs)
  File "/tmp/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 235, in has_tensor_in_frame
    if has_tensor(value):
  File "/tmp/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 224, in has_tensor
    seen_ids[obj_id] = any(has_tensor(getattr(obj, v)) for v in obj._fields)
AttributeError: 'torch.return_types.frexp_out' object has no attribute '_fields'. Did you mean: 'n_fields'?
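The last frame shows has_tensor reading obj._fields. That attribute exists on collections.namedtuple classes but not on C-level struct sequences, and the "Did you mean: 'n_fields'?" hint suggests torch.return_types.frexp_out is one of the latter. The sketch below reproduces the same mismatch with the standard library's time.struct_time (also a struct sequence) standing in for the torch type. It also shows a hypothetical traversal helper, has_leaf, that recurses by iterating the tuple instead of reading _fields. This is an illustration under those assumptions only, not Dynamo's code or the upstream fix.

```python
import collections
import time


def fields_via_underscore(obj):
    # The lookup pattern seen in the traceback: works for namedtuples,
    # raises AttributeError for struct sequences such as time.struct_time.
    return obj._fields


def has_leaf(obj, pred):
    """Hypothetical helper: True if pred(x) holds for obj or anything nested in it."""
    if pred(obj):
        return True
    if isinstance(obj, (list, tuple)):
        # Plain tuples, namedtuples and struct sequences are all tuple
        # subclasses, so iterating them needs no `_fields` attribute.
        return any(has_leaf(item, pred) for item in obj)
    if isinstance(obj, dict):
        return any(has_leaf(value, pred) for value in obj.values())
    return False


Point = collections.namedtuple("Point", ["x", "y"])
# Stand-in for torch.return_types.frexp_out (assumed to be a struct sequence).
stamp = time.gmtime(0)

print(fields_via_underscore(Point(1, 2)))  # ('x', 'y')
try:
    fields_via_underscore(stamp)
except AttributeError as exc:
    print("AttributeError:", exc)
print(hasattr(stamp, "n_fields"))  # True
print(has_leaf({"out": [Point(1, [2.5]), stamp]}, lambda x: isinstance(x, float)))  # True
print(has_leaf(stamp, lambda x: x == 1970))  # True
```

Iterating the tuple only visits a struct sequence's positional fields (n_sequence_fields), which is enough for a "does this contain a tensor?" style check.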

Minified repro

import torch


class NnWrapper(torch.nn.Module):
    def __init__(self, op):
        super().__init__()
        self.op = op

    def forward(self, kwarg):
        result = self.op(**kwarg)
        return result


model = NnWrapper(torch.frexp)
compiled_model = torch.compile(model)  # dynamic=None by default; dynamic=False avoids the error

shapes = [[1, 1, 2, 3], [1, 1, 3, 3]]
for shape in shapes:
    ifm = torch.randn(shape, dtype=torch.bfloat16)
    out = [
        torch.randn(shape, dtype=torch.bfloat16),
        torch.randint(low=-50, high=50, size=shape, dtype=torch.int32),
    ]
    params = {"input": ifm, "out": out}
    res = compiled_model(params)
    print(f"res: {res}")

Versions

[pip3] numpy==1.26.4
[pip3] torch==2.2.0a0
[pip3] torchaudio==2.2.0
[pip3] torchdata==0.7.1
[pip3] torchmetrics==1.2.1
[pip3] torchtext==0.17.0
[pip3] torchvision==0.17.0

cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519
