Running a few TorchBench benchmarks with the dynamo+openxla backend ends in an assertion failure:
Traceback (most recent call last):
File "xla/benchmarks/experiment_runner.py", line 601, in <module>
main()
File "xla/benchmarks/experiment_runner.py", line 597, in main
runner.run()
File "xla/benchmarks/experiment_runner.py", line 65, in run
self.run_single_experiment(experiment_config, model_config)
File "xla/benchmarks/experiment_runner.py", line 161, in run_single_experiment
run_metrics, output = self.timed_run(benchmark_experiment,
File "xla/benchmarks/experiment_runner.py", line 328, in timed_run
output = loop()
File "xla/benchmarks/experiment_runner.py", line 310, in loop
output = benchmark_model.model_iter_fn(
File "torch/_dynamo/eval_frame.py", line 410, in _fn
return fn(*args, **kwargs)
File "xla/benchmarks/benchmark_model.py", line 154, in eval
pred = self.module(*inputs)
File "torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "torch/nn/modules/module.py", line 1519, in _call_impl
return forward_call(*args, **kwargs)
File "/lib/python3.10/site-packages/detectron2-0.6-py3.10-linux-x86_64.egg/detectron2/modeling/meta_arch/rcnn.py", line 150, in forward
return self.inference(batched_inputs)
File "/lib/python3.10/site-packages/detectron2-0.6-py3.10-linux-x86_64.egg/detectron2/modeling/meta_arch/rcnn.py", line 203, in inference
images = self.preprocess_image(batched_inputs)
File "/lib/python3.10/site-packages/detectron2-0.6-py3.10-linux-x86_64.egg/detectron2/modeling/meta_arch/rcnn.py", line 229, in preprocess_image
images = ImageList.from_tensors(
File "/lib/python3.10/site-packages/detectron2-0.6-py3.10-linux-x86_64.egg/detectron2/structures/image_list.py", line 58, in from_tensors
@staticmethod
File "torch/_dynamo/eval_frame.py", line 410, in _fn
return fn(*args, **kwargs)
File "torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "torch/_functorch/aot_autograd.py", line 4832, in forward
return compiled_fn(full_args)
File "torch/_functorch/aot_autograd.py", line 1948, in g
return f(*args)
File "torch/_functorch/aot_autograd.py", line 3045, in runtime_wrapper
all_outs = call_func_with_args(
File "torch/_functorch/aot_autograd.py", line 1972, in call_func_with_args
out = normalize_as_list(f(args))
File "torch/_functorch/aot_autograd.py", line 2075, in rng_functionalization_wrapper
return compiled_fw(args)
File "torch/_functorch/aot_autograd.py", line 1948, in g
return f(*args)
File "torch/_dynamo/backends/torchxla.py", line 49, in fwd
compiled_graph = bridge.extract_compiled_graph(model, args)
File "xla/torch_xla/core/dynamo_bridge.py", line 539, in extract_compiled_graph
extract_internal(fused_module), node.args, None)
File "xla/torch_xla/core/dynamo_bridge.py", line 337, in extract_internal
dumb_return_handler, xla_args_need_update) = extract_graph_helper(xla_model)
File "xla/torch_xla/core/dynamo_bridge.py", line 211, in extract_graph_helper
assert all(
AssertionError: All tensors should be on xla
🐛 Bug
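Running a few TorchBench benchmarks with the dynamo+openxla backend ends in an assertion failure (`AssertionError: All tensors should be on xla`, raised from `extract_graph_helper` in `xla/torch_xla/core/dynamo_bridge.py`; see the traceback above).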
To Reproduce
Affected Benchmarks
Environment