Running a few torchbench benchmarks using the dynamo+openxla backend ends up in an assertion failure:
Traceback (most recent call last):
File "torch/_dynamo/eval_frame.py", line 489, in _fn
return fn(*args, **kwargs)
File "benchmarks/dynamo/torchbench.py", line 544, in forward_pass
return mod(*inputs)
File "torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "torchbenchmark/models/cm3leon_generate/model.py", line 1113, in forward
def forward(self, src_tokens):
File "torch/_dynamo/eval_frame.py", line 489, in _fn
return fn(*args, **kwargs)
File "torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "torch/_functorch/aot_autograd.py", line 4963, in forward
return compiled_fn(full_args)
File "torch/_functorch/aot_autograd.py", line 2017, in g
return f(*args)
File "torch/_functorch/aot_autograd.py", line 3164, in runtime_wrapper
all_outs = call_func_with_args(
File "torch/_functorch/aot_autograd.py", line 2041, in call_func_with_args
out = normalize_as_list(f(args))
File "torch/_functorch/aot_autograd.py", line 2145, in rng_functionalization_wrapper
return compiled_fw(args)
File "torch/_functorch/aot_autograd.py", line 2017, in g
return f(*args)
File "torch/_dynamo/backends/torchxla.py", line 51, in fwd
return compiled_graph(*args)
File "torch/fx/graph_module.py", line 736, in call_wrapped
return self._wrapped_call(self, *args, **kwargs)
File "torch/fx/graph_module.py", line 315, in __call__
raise e
File "torch/fx/graph_module.py", line 302, in __call__
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
File "torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "<eval_with_key>.5", line 5, in forward
File "xla/torch_xla/core/dynamo_bridge.py", line 387, in optimized_mod
res = torch_xla._XLAC._run_cached_graph(graph_hash, graph_input)
RuntimeError: torch_xla/csrc/xla_graph_executor.cpp:625 : Check failed: cachedComputation
🐛 Bug
Running a few torchbench benchmarks using the dynamo+openxla backend ends up in an assertion failure:
Stack Trace
Affected Benchmarks
Environment