Skip to content

Autocast fails on dynamo. #6511

@ysiraichi

Description

@ysiraichi

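🐛 Bug

Entering `torch_xla.amp.autocast` inside a function compiled with `torch.compile(backend="openxla")` crashes during Dynamo tracing with `InternalTorchDynamoError: 'autocast' object has no attribute 'device'`. Minimal reproducer: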

import torch
import torch_xla.amp
import torch_xla.core.xla_model as xm


# Minimal reproducer: enter torch_xla's autocast context inside a function
# compiled by Dynamo with the "openxla" backend. While Dynamo traces the
# autocast constructor, it raises
# InternalTorchDynamoError: 'autocast' object has no attribute 'device'.
@torch.compile(backend="openxla")
def foo(x):
    """Multiply ``x`` by 5 inside a bfloat16 autocast region on the XLA device."""
    # Keep this body exactly as reported: Dynamo traces its bytecode, so
    # rewriting it could change whether the failure reproduces.
    with torch_xla.amp.autocast(xm.xla_device(), dtype=torch.bfloat16):
        y = x * 5
    return y


x = torch.rand(5, device=xm.xla_device())
foo(x)
Traceback (most recent call last):
  File "examples/autocast.py", line 12, in <module>
    foo(x)
  File "torch/_dynamo/eval_frame.py", line 454, in _fn
    return fn(*args, **kwargs)
  File "examples/autocast.py", line 7, in foo
    with autocast(xm.xla_device(), dtype=torch.bfloat16):
  File "examples/autocast.py", line 7, in torch_dynamo_resume_in_foo_at_7
    with autocast(xm.xla_device(), dtype=torch.bfloat16):
  File "torch/_dynamo/convert_frame.py", line 904, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
  File "torch/_dynamo/convert_frame.py", line 769, in _convert_frame
    result = inner_convert(
  File "torch/_dynamo/convert_frame.py", line 398, in _convert_frame_assert
    return _compile(
  File "/usr/local/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "torch/_dynamo/convert_frame.py", line 696, in _compile
    raise InternalTorchDynamoError(str(e)).with_traceback(
  File "torch/_dynamo/convert_frame.py", line 669, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "torch/_dynamo/utils.py", line 249, in time_wrapper
    r = func(*args, **kwargs)
  File "torch/_dynamo/convert_frame.py", line 542, in compile_inner
    out_code = transform_code_object(code, transform)
  File "torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
    transformations(instructions, code_options)
  File "torch/_dynamo/convert_frame.py", line 163, in _fn
    return fn(*args, **kwargs)
  File "torch/_dynamo/convert_frame.py", line 507, in transform
    tracer.run()
  File "torch/_dynamo/symbolic_convert.py", line 2130, in run
    super().run()
  File "torch/_dynamo/symbolic_convert.py", line 793, in run
    and self.step()
  File "torch/_dynamo/symbolic_convert.py", line 756, in step
    getattr(self, inst.opname)(inst)
  File "torch/_dynamo/symbolic_convert.py", line 1305, in STORE_ATTR
    if isinstance(obj, NNModuleVariable):
  File "torch/_dynamo/variables/base.py", line 135, in __instancecheck__
    instance = instance.realize()
  File "torch/_dynamo/variables/lazy.py", line 58, in realize
    self._cache.realize(self.parents_tracker)
  File "torch/_dynamo/variables/lazy.py", line 24, in realize
    self.vt = VariableBuilder(tx, self.source)(self.value)
  File "torch/_dynamo/variables/builder.py", line 247, in __call__
    vt = self._wrap(value)
  File "torch/_dynamo/variables/builder.py", line 702, in _wrap
    value.device,
torch._dynamo.exc.InternalTorchDynamoError: 'autocast' object has no attribute 'device'

from user code:
   File "xla/torch_xla/amp/autocast_mode.py", line 26, in __init__
    self._enabled = enabled

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: CUDA
  • torch_xla version: 157e06e

cc @miladm @JackCaoG

Metadata

Metadata

Assignees

Labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions