[torchbench] dlrm fails to run on training. #6008

@ysiraichi

🐛 Bug

Running the upstreamed benchmarking scripts with the following command results in an unexpected error.

python xla/benchmarks/experiment_runner.py \
       --suite-name torchbench \
       --accelerator cuda \
       --xla PJRT --xla None \
       --dynamo openxla --dynamo None \
       --test train \
       --repeat 30 --iterations-per-run 5 \
       --print-subprocess \
       --no-resume -k dlrm

Traceback (most recent call last):
  File "xla/benchmarks/experiment_runner.py", line 601, in <module>
    main()
  File "xla/benchmarks/experiment_runner.py", line 597, in main
    runner.run()
  File "xla/benchmarks/experiment_runner.py", line 65, in run
    self.run_single_experiment(experiment_config, model_config)
  File "xla/benchmarks/experiment_runner.py", line 161, in run_single_experiment
    run_metrics, output = self.timed_run(benchmark_experiment,
  File "xla/benchmarks/experiment_runner.py", line 328, in timed_run
    output = loop()
  File "xla/benchmarks/experiment_runner.py", line 310, in loop
    output = benchmark_model.model_iter_fn(
  File "torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "xla/benchmarks/torchbench_model.py", line 247, in train
    super().train(inputs, collect_full_output=collect_full_output)
  File "xla/benchmarks/benchmark_model.py", line 142, in train
    self._optimizer_zero_grad()
  File "xla/benchmarks/benchmark_model.py", line 145, in resume_in_train
    loss.backward()
  File "torch/_tensor.py", line 503, in backward
    torch.autograd.backward(
  File "torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "torch/autograd/function.py", line 289, in apply
    return user_fn(self, *args)
  File "torch/_functorch/aot_autograd.py", line 4201, in backward
    out = call_compiled_backward()
  File "torch/_functorch/aot_autograd.py", line 4167, in call_compiled_backward
    out = call_func_with_args(
  File "torch/_functorch/aot_autograd.py", line 2016, in call_func_with_args
    out = normalize_as_list(f(args))
  File "torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "torch/_functorch/aot_autograd.py", line 1992, in g
    return f(*args)
  File "torch/_dynamo/backends/torchxla.py", line 49, in fwd
    compiled_graph = bridge.extract_compiled_graph(model, args)
  File "xla/torch_xla/core/dynamo_bridge.py", line 517, in extract_compiled_graph
    collector.run(*xla_args)
  File "torch/fx/interpreter.py", line 138, in run
    self.env[node] = self.run_node(node)
  File "xla/torch_xla/core/dynamo_bridge.py", line 431, in run_node
    result = super().run_node(n)
  File "torch/fx/interpreter.py", line 195, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "torch/fx/interpreter.py", line 267, in call_function
    return target(*args, **kwargs)
  File "torch/_ops.py", line 509, in __call__
    return self._op(*args, **kwargs or {})
NotImplementedError: Could not run 'aten::_sparse_coo_tensor_with_dims_and_tensors' with arguments from the 'SparseXLA' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::_sparse_coo_tensor_with_dims_and_tensors' is only available for these backends: [XLA, Meta, SparseCPU, SparseCUDA, SparseMeta, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMTIA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradMeta, AutogradNestedTensor, Tracer, AutocastCPU, AutocastXLA, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].

XLA: registered at torch_xla/csrc/aten_cpu_fallback.cpp:51 [backend fallback]
Meta: registered at build/aten/src/ATen/RegisterMeta.cpp:26984 [kernel]
SparseCPU: registered at build/aten/src/ATen/RegisterSparseCPU.cpp:1387 [kernel]
SparseCUDA: registered at build/aten/src/ATen/RegisterSparseCUDA.cpp:1573 [kernel]
SparseMeta: registered at build/aten/src/ATen/RegisterSparseMeta.cpp:249 [kernel]
BackendSelect: registered at build/aten/src/ATen/RegisterBackendSelect.cpp:807 [kernel]
Python: registered at aten/src/ATen/core/PythonFallbackKernel.cpp:153 [backend fallback]
FuncTorchDynamicLayerBackMode: registered at aten/src/ATen/functorch/DynamicLayer.cpp:498 [backend fallback]
Functionalize: registered at aten/src/ATen/FunctionalizeFallbackKernel.cpp:302 [backend fallback]
Named: registered at aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
Conjugate: registered at aten/src/ATen/ConjugateFallback.cpp:17 [backend fallback]
Negative: registered at aten/src/ATen/native/NegateFallback.cpp:19 [backend fallback]
ZeroTensor: registered at aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]
ADInplaceOrView: fallthrough registered at aten/src/ATen/core/VariableFallbackKernel.cpp:86 [backend fallback]
AutogradOther: registered at torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradCPU: registered at torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradCUDA: registered at torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradHIP: registered at torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradXLA: registered at torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradMPS: registered at torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradIPU: registered at torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradXPU: registered at torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradHPU: registered at torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradVE: registered at torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradLazy: registered at torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradMTIA: registered at torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradPrivateUse1: registered at torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradPrivateUse2: registered at torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradPrivateUse3: registered at torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradMeta: registered at torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
AutogradNestedTensor: registered at torch/csrc/autograd/generated/VariableType_2.cpp:19039 [autograd kernel]
Tracer: registered at torch/csrc/autograd/generated/TraceType_2.cpp:17346 [kernel]
AutocastCPU: fallthrough registered at aten/src/ATen/autocast_mode.cpp:378 [backend fallback]
AutocastXLA: fallthrough registered at torch_xla/csrc/autocast_mode.cpp:25 [backend fallback]
AutocastCUDA: fallthrough registered at aten/src/ATen/autocast_mode.cpp:244 [backend fallback]
FuncTorchBatched: registered at aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:720 [backend fallback]
BatchedNestedTensor: registered at aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:746 [backend fallback]
FuncTorchVmapMode: fallthrough registered at aten/src/ATen/functorch/VmapModeRegistrations.cpp:28 [backend fallback]
Batched: registered at aten/src/ATen/LegacyBatchingRegistrations.cpp:1075 [backend fallback]
VmapMode: fallthrough registered at aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
FuncTorchGradWrapper: registered at aten/src/ATen/functorch/TensorWrapper.cpp:203 [backend fallback]
PythonTLSSnapshot: registered at aten/src/ATen/core/PythonFallbackKernel.cpp:161 [backend fallback]
FuncTorchDynamicLayerFrontMode: registered at aten/src/ATen/functorch/DynamicLayer.cpp:494 [backend fallback]
PreDispatch: registered at aten/src/ATen/core/PythonFallbackKernel.cpp:165 [backend fallback]
PythonDispatcher: registered at aten/src/ATen/core/PythonFallbackKernel.cpp:157 [backend fallback]


While executing %_sparse_coo_tensor_with_dims_and_tensors : [num_users=1] = call_function[target=torch.ops.aten._sparse_coo_tensor_with_dims_and_tensors.default](args = (1, 1, [1000000, 64], %view_6, %view_7), kwargs = {dtype: torch.float32, layout: torch.sparse_coo, device: xla:0, pin_memory: None})
Original traceback:
  File "xla/benchmarks/benchmark_model.py", line 143, in resume_in_train
    pred = self.module(*inputs)
  File "torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "torchbenchmark/models/dlrm/dlrm_s_pytorch.py", line 337, in forward
    return self.sequential_forward(dense_x, lS_o, lS_i)
  File "torchbenchmark/models/dlrm/dlrm_s_pytorch.py", line 349, in sequential_forward
    ly = self.apply_emb(lS_o, lS_i, self.emb_l)
  File "torchbenchmark/models/dlrm/dlrm_s_pytorch.py", line 294, in apply_emb
    V = E(sparse_index_group_batch, sparse_offset_group_batch)
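
For triage, the dispatcher message in the log above can be read mechanically: it names the operator, the dispatch key that was requested, and the keys that do have a kernel or fallback registered. The following is a minimal sketch, assuming the message keeps the wording shown in this log; `parse_dispatch_error` is a hypothetical helper written for illustration, not part of PyTorch or torch_xla.

```python
import re

# Matches the dispatcher message as worded in the log above (assumption:
# other PyTorch versions may phrase it differently).
_DISPATCH_ERROR = re.compile(
    r"Could not run '(?P<op>[^']+)' with arguments from the "
    r"'(?P<backend>[^']+)' backend\..*?"
    r"is only available for these backends: \[(?P<available>[^\]]*)\]",
    re.DOTALL,
)


def parse_dispatch_error(message):
    """Return (op, requested_backend, available_backends), or None if no match."""
    match = _DISPATCH_ERROR.search(message)
    if match is None:
        return None
    available = [
        key.strip() for key in match.group("available").split(",") if key.strip()
    ]
    return match.group("op"), match.group("backend"), available


if __name__ == "__main__":
    # Abbreviated copy of the message from the traceback above.
    sample = (
        "NotImplementedError: Could not run "
        "'aten::_sparse_coo_tensor_with_dims_and_tensors' with arguments "
        "from the 'SparseXLA' backend. This could be because the operator "
        "doesn't exist for this backend. "
        "'aten::_sparse_coo_tensor_with_dims_and_tensors' is only available "
        "for these backends: [XLA, Meta, SparseCPU, SparseCUDA, SparseMeta]."
    )
    op, backend, available = parse_dispatch_error(sample)
    # Report whether the requested key appears among the registered ones.
    print(op, backend, backend in available)
```

Applied to the full message in this report, the requested key is `SparseXLA`, which does not appear in the list of available backends, while `SparseCPU`, `SparseCUDA` and `SparseMeta` do.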

Environment
