❓ Questions and Help
I am encountering an issue with an SPMD LLaMA forward-pass script. When running the script on a Google Cloud v3-8 instance with `torch` and `torch_xla` nightly builds (post-20240527), I receive the following error:
```
File "/home/huzama/.local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 184, in apply_rotary_pos_emb
    k_embed = (k * cos) + (rotate_half(k) * sin)
RuntimeError: torch_xla/csrc/aten_xla_type.cpp:161 : Check failed: it != inputs_.end()
*** Begin stack trace ***
        tsl::CurrentStackTrace()
        torch_xla::XLANativeFunctions::mul(at::Tensor const&, at::Tensor const&)
        at::_ops::mul_Tensor::redispatch(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&)
        at::_ops::mul_Tensor::call(at::Tensor const&, at::Tensor const&)
        PyNumber_Multiply
        _PyEval_EvalFrameDefault
```
Interestingly, running the exact same script within a Jupyter Notebook in the same environment works perfectly.
Here is the script:
```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from transformers import LlamaForCausalLM

xr.use_spmd()
device = "xla"

model = LlamaForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
model = model.to(device)

# 8 TPU cores arranged as a (4, 2) mesh with axes ("data", "model")
spmd_mesh = xs.Mesh(np.array(range(8)), (4, 2), ("data", "model"))

# Shard every 2-D parameter across the mesh; 1-D parameters stay replicated
for name, param in model.named_parameters():
    if len(param.shape) == 2:
        xs.mark_sharding(param, spmd_mesh, ("model", "data"))

model(torch.randint(0, 50256, (1, 512)).to(device))
xm.mark_step()
```
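For context, the sharding loop above only touches 2-D parameters (projection and embedding weights), leaving 1-D parameters such as RMSNorm weights replicated. A mesh-free sketch of the same filter, using hypothetical parameter names and shapes (not taken from the actual checkpoint), illustrates which parameters it selects:

```python
# Hypothetical LLaMA-style parameter shapes (illustrative only).
# Only 2-D tensors pass the len(shape) == 2 filter; the 1-D
# layernorm weight is skipped and therefore stays replicated.
param_shapes = {
    "model.embed_tokens.weight": (128256, 4096),
    "model.layers.0.self_attn.q_proj.weight": (4096, 4096),
    "model.layers.0.input_layernorm.weight": (4096,),
}

sharded = [name for name, shape in param_shapes.items() if len(shape) == 2]
print(sharded)
# The layernorm weight is absent from the sharded list.
```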
To reiterate, this error occurs when running as a Python script, but not within a Jupyter Notebook.
Details:
- Environment: Google Cloud v3-8 instance
- Python version: 3.10
- Libraries: `torch` (nightly, post-20240527), `torch_xla` (nightly, post-20240527)
- Model: `LlamaForCausalLM` from the `transformers` library