Error: Check Failed: it != inputs_.end() When Using torch and torch_xla Nightly Version (Post-20240527) with SPMD #7266

@huzama

❓ Questions and Help

I am hitting an error in a script that runs an SPMD Llama forward pass. On a Google Cloud TPU v3-8 instance, with torch and torch_xla nightly builds dated after 20240527, the script fails with:

File "/home/huzama/.local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 184, in apply_rotary_pos_emb
    k_embed = (k * cos) + (rotate_half(k) * sin)
RuntimeError: torch_xla/csrc/aten_xla_type.cpp:161 : Check failed: it != inputs_.end() *** Begin stack trace ***
        tsl::CurrentStackTrace()
        torch_xla::XLANativeFunctions::mul(at::Tensor const&, at::Tensor const&)

        at::_ops::mul_Tensor::redispatch(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&)
        at::_ops::mul_Tensor::call(at::Tensor const&, at::Tensor const&)
        PyNumber_Multiply
        _PyEval_EvalFrameDefault

Interestingly, running the exact same script within a Jupyter Notebook in the same environment works perfectly.

Here is the script:

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from transformers import LlamaForCausalLM

# Enable SPMD mode before any XLA tensors are created.
xr.use_spmd()

device = "xla"

model = LlamaForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
model = model.to(device)

# Logical 4x2 mesh over the 8 devices, with axes named "data" and "model".
spmd_mesh = xs.Mesh(np.array(range(8)), (4, 2), ("data", "model"))

# Shard every 2-D parameter: dim 0 over "model", dim 1 over "data".
for name, param in model.named_parameters():
    if len(param.shape) == 2:
        xs.mark_sharding(param, spmd_mesh, ("model", "data"))

# Forward pass on random token ids (batch 1, sequence length 512).
model(torch.randint(0, 50256, (1, 512)).to(device))

xm.mark_step()

To reiterate, this error occurs when running as a Python script, but not within a Jupyter Notebook.
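
For reference, the sharding layout the script above requests can be sketched without a TPU. This is a plain-Python illustration only, not torch_xla code: `build_mesh` and `shard_shape` are hypothetical helpers, the row-major device layout is an assumption, the `(4096, 4096)` weight shape is illustrative, and the sketch only handles dimensions that divide evenly across the mesh axis.

```python
def build_mesh(device_ids, mesh_shape):
    """Arrange a flat list of device ids into a row-major 2-D grid (assumed layout)."""
    rows, cols = mesh_shape
    if rows * cols != len(device_ids):
        raise ValueError("mesh_shape must cover every device exactly once")
    return [device_ids[r * cols:(r + 1) * cols] for r in range(rows)]


def shard_shape(param_shape, partition_spec, mesh_shape, axis_names):
    """Per-device shard shape when tensor dim i is split over mesh axis partition_spec[i].

    A partition_spec entry of None means that dimension is replicated (not split).
    """
    axis_size = dict(zip(axis_names, mesh_shape))
    shape = []
    for dim, axis in zip(param_shape, partition_spec):
        parts = 1 if axis is None else axis_size[axis]
        if dim % parts != 0:
            raise ValueError(f"dimension {dim} does not divide evenly into {parts} parts")
        shape.append(dim // parts)
    return tuple(shape)


# Same mesh parameters as the script: 8 devices, shape (4, 2), axes ("data", "model").
mesh_shape = (4, 2)
axis_names = ("data", "model")
mesh = build_mesh(list(range(8)), mesh_shape)

# Illustrative 2-D weight (assumption), sharded with the script's ("model", "data") spec:
# dim 0 is split 2 ways over "model", dim 1 is split 4 ways over "data".
example_shard = shard_shape((4096, 4096), ("model", "data"), mesh_shape, axis_names)

print(mesh)           # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(example_shard)  # (2048, 1024)
```

Under these assumptions, each 2-D parameter ends up split into 8 tiles, one per device, and 1-D parameters are left unsharded because the script skips them.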

Details:

  • Environment: Google Cloud v3-8 instance
  • Python Version: 3.10
  • Libraries:
    • torch (nightly version after 20240527)
    • torch_xla (nightly version after 20240527)
  • Model: LlamaForCausalLM from the transformers library
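
Since "nightly version after 20240527" does not identify a single build, a small helper like the following could record the exact installed versions for a report. It is a minimal sketch using only the standard library; `installed_versions` is a hypothetical helper, and the distribution names `torch`, `torch_xla`, and `transformers` are assumed to match how the packages were installed.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if not found."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Distribution names are assumptions; adjust them to match the local install.
report = installed_versions(["torch", "torch_xla", "transformers"])
for name, version in report.items():
    print(f"{name}: {version or 'not installed'}")
```

Nightly builds usually embed a date in the version string, so running this in both the script environment and the notebook kernel is one way to check that the two really use the same builds.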
