🐛 Describe the bug
Found this bug while integrating Dynamo with torch_xla for a ResNet model. If we move the model and its inputs to the XLA device before running Dynamo, we hit this bug. See the minimal repro below.
cc @jansel @wconstab @JackCaoG
Error logs
File "/pytorch/torch/_dynamo/convert_frame.py", line 118, in _fn
return fn(*args, **kwargs)
File "/pytorch/torch/_dynamo/utils.py", line 92, in time_wrapper
File "/pytorch/torch/_dynamo/convert_frame.py", line 118, in _fn
File "/pytorch/torch/_refs/__init__.py", line 45, in <module>
from torch.fx.experimental.symbolic_shapes import sym_float, sym_int
File "/pytorch/torch/fx/experimental/symbolic_shapes.py", line 17, in <module>
import sympy # type: ignore[import]
File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/sympy/__init__.py", line 51, in <module>
from .core import (sympify, SympifyError, cacheit, Basic, Atom,
File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/sympy/core/__init__.py", line 4, in <module>
from .sympify import sympify, SympifyError
File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/sympy/core/sympify.py", line 9, in <module>
from .compatibility import iterable
File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/sympy/core/compatibility.py", line 11, in <module>
from sympy.external import import_module
File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/sympy/external/__init__.py", line 18, in <module>
from sympy.external.importtools import import_module
File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/sympy/external/importtools.py", line 4, in <module>
from distutils.version import LooseVersion
File "<frozen importlib._bootstrap>", line 983, in _find_and_load
File "<frozen importlib._bootstrap>", line 963, in _find_and_load_unlocked
File "<frozen importlib._bootstrap>", line 906, in _find_spec
return fn(*args, **kwargs)
File "/pytorch/torch/_dynamo/utils.py", line 92, in time_wrapper
r = func(*args, **kwargs)
File "/pytorch/torch/_dynamo/convert_frame.py", line 356, in _convert_frame_assert
frame,
File "/pytorch/torch/_dynamo/convert_frame.py", line 402, in _compile
out_code = transform_code_object(code, transform)
File "/pytorch/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
transformations(instructions, code_options)
File "/pytorch/torch/_dynamo/convert_frame.py", line 390, in transform
tracer.run()
File "/pytorch/torch/_dynamo/symbolic_convert.py", line 1468, in run
super().run()
File "/pytorch/torch/_dynamo/symbolic_convert.py", line 352, in run
and self.step()
File "/pytorch/torch/_dynamo/symbolic_convert.py", line 322, in step
getattr(self, inst.opname)(inst)
File "/pytorch/torch/_dynamo/symbolic_convert.py", line 174, in wrapper
return inner_fn(self, inst)
File "/pytorch/torch/_dynamo/symbolic_convert.py", line 766, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/pytorch/torch/_dynamo/symbolic_convert.py", line 264, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/pytorch/torch/_dynamo/variables/nn_module.py", line 209, in call_function
**options,
File "/pytorch/torch/_dynamo/convert_frame.py", line 118, in _fn
return fn(*args, **kwargs)
File "/pytorch/torch/_dynamo/utils.py", line 92, in time_wrapper
r = func(*args, **kwargs)
File "/pytorch/torch/_dynamo/convert_frame.py", line 356, in _convert_frame_assert
frame,
File "/pytorch/torch/_dynamo/convert_frame.py", line 402, in _compile
out_code = transform_code_object(code, transform)
File "/pytorch/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
transformations(instructions, code_options)
File "/pytorch/torch/_dynamo/convert_frame.py", line 390, in transform
tracer.run()
File "/pytorch/torch/_dynamo/symbolic_convert.py", line 1468, in run
super().run()
File "/pytorch/torch/_dynamo/symbolic_convert.py", line 352, in run
and self.step()
File "/pytorch/torch/_dynamo/symbolic_convert.py", line 322, in step
getattr(self, inst.opname)(inst)
File "/pytorch/torch/_dynamo/symbolic_convert.py", line 174, in wrapper
return inner_fn(self, inst)
File "/pytorch/torch/_dynamo/symbolic_convert.py", line 766, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/pytorch/torch/_dynamo/symbolic_convert.py", line 264, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/pytorch/torch/_dynamo/variables/nn_module.py", line 209, in call_function
**options,
File "/pytorch/torch/_dynamo/variables/tensor.py", line 201, in create
example_value = _get_fake_value(proxy.node, tx)
File "/pytorch/torch/_dynamo/variables/tensor.py", line 145, in _get_fake_value
raise TorchRuntimeError() from e
torch._dynamo.exc.TorchRuntimeError:
from user code:
File "myscripts/repro_maxpool.py", line 14, in forward
out = self.pool(out)
Set torch._dynamo.config.verbose=True for more information
You can suppress this exception and fall back to eager by setting:
torch._dynamo.config.suppress_errors = True
Minified repro
repro_maxpool.py
from torch import nn
import torch
import torch._dynamo as dynamo
import torch_xla.core.xla_model as xm
class MaxPoolModule(nn.Module):
    """Minimal Conv2d -> MaxPool2d model used to reproduce the Dynamo + XLA failure.

    Per the report, Dynamo raises ``TorchRuntimeError`` while tracing
    ``self.pool`` when this module and its inputs were moved to an XLA device
    before running Dynamo.
    """

    def __init__(self):
        super().__init__()
        # 3 input channels -> 6 output channels, 3x3 kernel, stride 2.
        self.conv = nn.Conv2d(3, 6, kernel_size=3, stride=2)
        # The layer whose tracing fails in the reported traceback.
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2)

    def forward(self, x):
        out = self.conv(x)
        out = self.pool(out)
        return out

    def get_random_inputs(self, batch_size=2, spatial_size=10):
        """Return a 1-tuple holding a random input tensor for ``forward``.

        The tensor has shape ``(batch_size, 3, spatial_size, spatial_size)``;
        the channel count is fixed at 3 to match ``self.conv``'s input
        channels. The defaults reproduce the original ``(2, 3, 10, 10)`` input.
        The tensor is created on the default device; the caller moves it to
        the target device.
        """
        return (torch.rand(batch_size, 3, spatial_size, spatial_size),)
# Move the model and its inputs to the XLA device *before* invoking Dynamo;
# per the report, this ordering is what triggers the failure.
xla_dev = xm.xla_device()
model = MaxPoolModule().to(device=xla_dev)
# Materialize the inputs as a tuple: a bare ``map`` object is a one-shot
# iterator, so a second call of the lambda below would unpack nothing and
# invoke forward() without ``x``.
inputs = tuple(x.to(device=xla_dev) for x in model.get_random_inputs())
# Identity backend: returns the captured GraphModule unchanged, so any failure
# comes from Dynamo's tracing rather than from a compiler backend.
dynamo.optimize(lambda gm, _: gm)(lambda: model(*inputs))()
Command:
GPU_NUM_DEVICES=1 python repro_maxpool.py
🐛 Describe the bug
Found this bug while integrating Dynamo with torch_xla for a ResNet model. If we move the model and its inputs to the XLA device before running Dynamo, we hit this bug. See the minimal repro below.
cc @jansel @wconstab @JackCaoG
Error logs
File "/pytorch/torch/_dynamo/convert_frame.py", line 118, in _fn
return fn(*args, **kwargs)
File "/pytorch/torch/_dynamo/utils.py", line 92, in time_wrapper
File "/pytorch/torch/_dynamo/convert_frame.py", line 118, in _fn
File "/pytorch/torch/_refs/__init__.py", line 45, in <module>
from torch.fx.experimental.symbolic_shapes import sym_float, sym_int
File "/pytorch/torch/fx/experimental/symbolic_shapes.py", line 17, in <module>
import sympy # type: ignore[import]
File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/sympy/__init__.py", line 51, in <module>
from .core import (sympify, SympifyError, cacheit, Basic, Atom,
File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/sympy/core/__init__.py", line 4, in <module>
from .sympify import sympify, SympifyError
File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/sympy/core/sympify.py", line 9, in <module>
from .compatibility import iterable
File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/sympy/core/compatibility.py", line 11, in <module>
from sympy.external import import_module
File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/sympy/external/__init__.py", line 18, in <module>
from sympy.external.importtools import import_module
File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/sympy/external/importtools.py", line 4, in <module>
from distutils.version import LooseVersion
File "<frozen importlib._bootstrap>", line 983, in _find_and_load
File "<frozen importlib._bootstrap>", line 963, in _find_and_load_unlocked
File "<frozen importlib._bootstrap>", line 906, in _find_spec
return fn(*args, **kwargs)
File "/pytorch/torch/_dynamo/utils.py", line 92, in time_wrapper
r = func(*args, **kwargs)
File "/pytorch/torch/_dynamo/convert_frame.py", line 356, in _convert_frame_assert
frame,
File "/pytorch/torch/_dynamo/convert_frame.py", line 402, in _compile
out_code = transform_code_object(code, transform)
File "/pytorch/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
transformations(instructions, code_options)
File "/pytorch/torch/_dynamo/convert_frame.py", line 390, in transform
tracer.run()
File "/pytorch/torch/_dynamo/symbolic_convert.py", line 1468, in run
super().run()
File "/pytorch/torch/_dynamo/symbolic_convert.py", line 352, in run
and self.step()
File "/pytorch/torch/_dynamo/symbolic_convert.py", line 322, in step
getattr(self, inst.opname)(inst)
File "/pytorch/torch/_dynamo/symbolic_convert.py", line 174, in wrapper
return inner_fn(self, inst)
File "/pytorch/torch/_dynamo/symbolic_convert.py", line 766, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/pytorch/torch/_dynamo/symbolic_convert.py", line 264, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/pytorch/torch/_dynamo/variables/nn_module.py", line 209, in call_function
**options,
File "/pytorch/torch/_dynamo/convert_frame.py", line 118, in _fn
return fn(*args, **kwargs)
File "/pytorch/torch/_dynamo/utils.py", line 92, in time_wrapper
r = func(*args, **kwargs)
File "/pytorch/torch/_dynamo/convert_frame.py", line 356, in _convert_frame_assert
frame,
File "/pytorch/torch/_dynamo/convert_frame.py", line 402, in _compile
out_code = transform_code_object(code, transform)
File "/pytorch/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
transformations(instructions, code_options)
File "/pytorch/torch/_dynamo/convert_frame.py", line 390, in transform
tracer.run()
File "/pytorch/torch/_dynamo/symbolic_convert.py", line 1468, in run
super().run()
File "/pytorch/torch/_dynamo/symbolic_convert.py", line 352, in run
and self.step()
File "/pytorch/torch/_dynamo/symbolic_convert.py", line 322, in step
getattr(self, inst.opname)(inst)
File "/pytorch/torch/_dynamo/symbolic_convert.py", line 174, in wrapper
return inner_fn(self, inst)
File "/pytorch/torch/_dynamo/symbolic_convert.py", line 766, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/pytorch/torch/_dynamo/symbolic_convert.py", line 264, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/pytorch/torch/_dynamo/variables/nn_module.py", line 209, in call_function
**options,
File "/pytorch/torch/_dynamo/variables/tensor.py", line 201, in create
example_value = _get_fake_value(proxy.node, tx)
File "/pytorch/torch/_dynamo/variables/tensor.py", line 145, in _get_fake_value
raise TorchRuntimeError() from e
torch._dynamo.exc.TorchRuntimeError:
from user code:
File "myscripts/repro_maxpool.py", line 14, in forward
out = self.pool(out)
Set torch._dynamo.config.verbose=True for more information
You can suppress this exception and fall back to eager by setting:
torch._dynamo.config.suppress_errors = True
Minified repro
repro_maxpool.py
Command: