
[torchbench] llama failing when executed with bfloat16 data-type. #6648

@ysiraichi


🐛 Bug

After #6518, the llama benchmark started failing with the following error when run with the bfloat16 data type:

Traceback (most recent call last):
  File "xla/benchmarks/experiment_runner.py", line 945, in <module>
    main()
  File "xla/benchmarks/experiment_runner.py", line 941, in main
    runner.run()
  File "xla/benchmarks/experiment_runner.py", line 61, in run
    self.run_single_config()
  File "xla/benchmarks/experiment_runner.py", line 256, in run_single_config
    metrics, last_output = self.run_once_and_gather_metrics(
  File "xla/benchmarks/experiment_runner.py", line 345, in run_once_and_gather_metrics
    output, _ = loop(iter_fn=self._default_iter_fn)
  File "xla/benchmarks/experiment_runner.py", line 302, in loop
    output, timing, trace = iter_fn(benchmark_experiment, benchmark_model,
  File "xla/benchmarks/experiment_runner.py", line 218, in _default_iter_fn
    output = benchmark_model.model_iter_fn(
  File "xla/benchmarks/benchmark_model.py", line 170, in eval
    pred = self.module(*inputs)
  File "torch/nn/modules/module.py", line 1529, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "torch/nn/modules/module.py", line 1538, in _call_impl
    return forward_call(*args, **kwargs)
  File "benchmark/torchbenchmark/models/llama/model.py", line 225, in forward
    h = self.tok_embeddings(tokens)
  File "torch/nn/modules/module.py", line 1529, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "torch/nn/modules/module.py", line 1538, in _call_impl
    return forward_call(*args, **kwargs)
  File "torch/nn/modules/sparse.py", line 163, in forward
    return F.embedding(
  File "torch/nn/functional.py", line 2233, in embedding
    return handle_torch_function(
  File "torch/overrides.py", line 1619, in handle_torch_function
    result = mode.__torch_function__(public_api, types, args, kwargs)
  File "torch/utils/_device.py", line 78, in __torch_function__
    return func(*args, **kwargs)
  File "torch/nn/functional.py", line 2264, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: torch_xla/csrc/tensor_ops.cpp:248 : Check failed: indices->dtype() == at::ScalarType::Long (Int vs. Long)
Stack Trace
*** Begin stack trace ***
        tsl::CurrentStackTrace[abi:cxx11]()
        torch_xla::tensor_ops::Embedding(c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > const&, c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > const&)
        torch_xla::tensor_methods::embedding(c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > const&, c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > const&)
        torch_xla::XLANativeFunctions::embedding_symint(at::Tensor const&, at::Tensor const&, c10::SymInt, bool, bool)


        at::_ops::embedding::redispatch(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, c10::SymInt, bool, bool)


        at::_ops::embedding::call(at::Tensor const&, at::Tensor const&, c10::SymInt, bool, bool)


        _PyObject_MakeTpCall
        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        _PyFunction_Vectorcall
        PyVectorcall_Call
        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault


        PyVectorcall_Call
        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        _PyFunction_Vectorcall

        PyVectorcall_Call
        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        _PyFunction_Vectorcall
        _PyObject_FastCallDict
        _PyObject_Call_Prepend

        _PyObject_MakeTpCall
        _PyEval_EvalFrameDefault


        PyVectorcall_Call
        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        _PyFunction_Vectorcall

        PyVectorcall_Call
        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        _PyFunction_Vectorcall
        _PyObject_FastCallDict
        _PyObject_Call_Prepend

        PyObject_Call
        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        _PyFunction_Vectorcall

        _PyEval_EvalFrameDefault


        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault

        _PyEval_EvalFrameDefault

        _PyEval_EvalFrameDefault

        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        PyEval_EvalCodeEx
        PyEval_EvalCode



        PyRun_SimpleFileExFlags
        Py_RunMain
        Py_BytesMain
        __libc_start_main
        _start
*** End stack trace ***
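The failing check (`torch_xla/csrc/tensor_ops.cpp:248`) only accepts `Long` (int64) indices, while the `tokens` tensor reaching `tok_embeddings` is `Int` (int32). A possible stopgap, while the change from #6518 is investigated, is to promote the indices to int64 before the lookup. The snippet below is a minimal sketch under that assumption: `embedding_with_long_indices` is a hypothetical helper (not part of torchbench or torch_xla), and it runs on CPU so no XLA device is needed. On XLA, the same cast should satisfy the `indices->dtype() == at::ScalarType::Long` check, but that is an assumption, not a verified fix.

```python
import torch
import torch.nn.functional as F


def embedding_with_long_indices(weight: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Hypothetical workaround: cast integer indices to int64 before the embedding lookup.

    torch_xla's Embedding lowering checks that indices are at::ScalarType::Long,
    so int32 token tensors are promoted first. The weight dtype (e.g. bfloat16)
    is left untouched.
    """
    if indices.dtype != torch.long:
        indices = indices.to(torch.long)
    # F.embedding takes (input, weight), the reverse of torch.embedding(weight, input).
    return F.embedding(indices, weight)


# bfloat16 weights, as in the failing benchmark configuration; int32 tokens, as in the error.
weight = torch.randn(10, 4, dtype=torch.bfloat16)
tokens = torch.tensor([[1, 2, 3]], dtype=torch.int32)
out = embedding_with_long_indices(weight, tokens)
print(out.shape, out.dtype)  # torch.Size([1, 3, 4]) torch.bfloat16
```

Casting in model code only masks the symptom. The proper fix is likely either keeping the benchmark's integer inputs as int64 or relaxing the torch_xla check to accept int32 indices, which eager PyTorch's `F.embedding` already does.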

Affected Configurations

  • Non-Dynamo Inference
  • Dynamo Inference

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: CUDA
  • torch_xla version: 4327d24

cc @miladm @JackCaoG
