Skip to content
This repository was archived by the owner on Aug 1, 2025. It is now read-only.

Dynamo cannot optimize a model with MaxPool2d on XLA devices #1837

@shunting314

Description

@shunting314

🐛 Describe the bug

I found this bug while integrating Dynamo with torch_xla for a ResNet model. If the model and its inputs are moved to the XLA device before running Dynamo, we hit this bug. See the minimal repro below.

cc @jansel @wconstab @JackCaoG

Error logs

File "/pytorch/torch/_dynamo/convert_frame.py", line 118, in _fn
return fn(*args, **kwargs)
File "/pytorch/torch/_dynamo/utils.py", line 92, in time_wrapper
File "/pytorch/torch/_dynamo/convert_frame.py", line 118, in _fn
File "/pytorch/torch/_refs/__init__.py", line 45, in <module>
from torch.fx.experimental.symbolic_shapes import sym_float, sym_int
File "/pytorch/torch/fx/experimental/symbolic_shapes.py", line 17, in <module>
import sympy # type: ignore[import]
File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/sympy/__init__.py", line 51, in <module>
from .core import (sympify, SympifyError, cacheit, Basic, Atom,
File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/sympy/core/__init__.py", line 4, in <module>
from .sympify import sympify, SympifyError
File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/sympy/core/sympify.py", line 9, in <module>
from .compatibility import iterable
File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/sympy/core/compatibility.py", line 11, in <module>
from sympy.external import import_module
File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/sympy/external/__init__.py", line 18, in <module>
from sympy.external.importtools import import_module
File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/sympy/external/importtools.py", line 4, in <module>
from distutils.version import LooseVersion
File "<frozen importlib._bootstrap>", line 983, in _find_and_load
File "<frozen importlib._bootstrap>", line 963, in _find_and_load_unlocked
File "<frozen importlib._bootstrap>", line 906, in _find_spec
return fn(*args, **kwargs)
File "/pytorch/torch/_dynamo/utils.py", line 92, in time_wrapper
r = func(*args, **kwargs)
File "/pytorch/torch/_dynamo/convert_frame.py", line 356, in _convert_frame_assert
frame,
File "/pytorch/torch/_dynamo/convert_frame.py", line 402, in _compile
out_code = transform_code_object(code, transform)
File "/pytorch/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
transformations(instructions, code_options)
File "/pytorch/torch/_dynamo/convert_frame.py", line 390, in transform
tracer.run()
File "/pytorch/torch/_dynamo/symbolic_convert.py", line 1468, in run
super().run()
File "/pytorch/torch/_dynamo/symbolic_convert.py", line 352, in run
and self.step()
File "/pytorch/torch/_dynamo/symbolic_convert.py", line 322, in step
getattr(self, inst.opname)(inst)
File "/pytorch/torch/_dynamo/symbolic_convert.py", line 174, in wrapper
return inner_fn(self, inst)
File "/pytorch/torch/_dynamo/symbolic_convert.py", line 766, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/pytorch/torch/_dynamo/symbolic_convert.py", line 264, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/pytorch/torch/_dynamo/variables/nn_module.py", line 209, in call_function
**options,
File "/pytorch/torch/_dynamo/convert_frame.py", line 118, in _fn
return fn(*args, **kwargs)
File "/pytorch/torch/_dynamo/utils.py", line 92, in time_wrapper
r = func(*args, **kwargs)
File "/pytorch/torch/_dynamo/convert_frame.py", line 356, in _convert_frame_assert
frame,
File "/pytorch/torch/_dynamo/convert_frame.py", line 402, in _compile
out_code = transform_code_object(code, transform)
File "/pytorch/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
transformations(instructions, code_options)
File "/pytorch/torch/_dynamo/convert_frame.py", line 390, in transform
tracer.run()
File "/pytorch/torch/_dynamo/symbolic_convert.py", line 1468, in run
super().run()
File "/pytorch/torch/_dynamo/symbolic_convert.py", line 352, in run
and self.step()
File "/pytorch/torch/_dynamo/symbolic_convert.py", line 322, in step
getattr(self, inst.opname)(inst)
File "/pytorch/torch/_dynamo/symbolic_convert.py", line 174, in wrapper
return inner_fn(self, inst)
File "/pytorch/torch/_dynamo/symbolic_convert.py", line 766, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/pytorch/torch/_dynamo/symbolic_convert.py", line 264, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/pytorch/torch/_dynamo/variables/nn_module.py", line 209, in call_function
**options,
File "/pytorch/torch/_dynamo/variables/tensor.py", line 201, in create
example_value = _get_fake_value(proxy.node, tx)
File "/pytorch/torch/_dynamo/variables/tensor.py", line 145, in _get_fake_value
raise TorchRuntimeError() from e
torch._dynamo.exc.TorchRuntimeError:

from user code:
File "myscripts/repro_maxpool.py", line 14, in forward
out = self.pool(out)

Set torch._dynamo.config.verbose=True for more information

You can suppress this exception and fall back to eager by setting:
torch._dynamo.config.suppress_errors = True

Minified repro

repro_maxpool.py

from torch import nn
import torch
import torch._dynamo as dynamo
import torch_xla.core.xla_model as xm

class MaxPoolModule(nn.Module):
    """Tiny conv -> max-pool network used to reproduce the Dynamo/XLA failure.

    A 3->6 channel 3x3 stride-2 convolution feeds a 3x3 stride-2 MaxPool2d.
    """

    def __init__(self):
        super().__init__()
        # Registration order (conv first, then pool) is kept deliberately:
        # it fixes parameter initialisation order and state_dict key order.
        self.conv = nn.Conv2d(3, 6, kernel_size=3, stride=2)
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2)

    def forward(self, x):
        """Apply the convolution, then the max-pool, and return the result."""
        return self.pool(self.conv(x))

    def get_random_inputs(self):
        """Return a 1-tuple holding a uniform random (2, 3, 10, 10) CPU tensor."""
        sample = torch.rand(2, 3, 10, 10)
        return (sample,)

# Move the model and its example inputs to the XLA device *before* handing
# them to Dynamo -- per the report, this ordering is what triggers the failure.
xla_dev = xm.xla_device()
model = MaxPoolModule().to(device=xla_dev)
# Materialize the inputs as a tuple: a `map` object is a single-use iterator,
# so a second call of the lambda below (re-run or re-trace) would unpack it
# empty and invoke `model()` with no arguments.
inputs = tuple(x.to(device=xla_dev) for x in model.get_random_inputs())
# Identity backend: the captured graph module is returned unchanged, so no
# real compiler is involved in the reproduction.
dynamo.optimize(lambda gm, _: gm)(lambda: model(*inputs))()

Command:

GPU_NUM_DEVICES=1 python repro_maxpool.py

Metadata

Metadata

Assignees

Labels

bug (Something isn't working), triaged

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions