+ python3 /home/runner/_work/xla/xla/pytorch/xla/test/test_splash_attention.py
/home/runner/.local/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:351: UserWarning: Device capability of jax unspecified, assuming `cpu` and `cuda`. Please specify it via the `devices` argument of `register_backend`.
warnings.warn(
...E.
======================================================================
ERROR: test_splash_attention_segment_id (__main__.SplashAttentionTest)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/home/runner/_work/xla/xla/pytorch/xla/test/test_splash_attention.py", line 33, in wrapper
result = func(*args, **kwargs)
File "/home/runner/_work/xla/xla/pytorch/xla/test/test_splash_attention.py", line 79, in setUp
loss.backward()
File "/home/runner/.local/lib/python3.10/site-packages/torch/_tensor.py", line 648, in backward
torch.autograd.backward(
File "/home/runner/.local/lib/python3.10/site-packages/torch/autograd/__init__.py", line 353, in backward
_engine_run_backward(
File "/home/runner/.local/lib/python3.10/site-packages/torch/autograd/graph.py", line 824, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/home/runner/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 307, in apply
return user_fn(self, *args)
File "/home/runner/.local/lib/python3.10/site-packages/torch_xla/_internal/jax_workarounds.py", line 26, in wrapper
return func(*args, **kwargs)
File "/home/runner/.local/lib/python3.10/site-packages/torch_xla/experimental/custom_kernel.py", line 842, in backward
grad_q, grad_k, grad_v, grad_ab = fa_custom_backward(
File "/home/runner/.local/lib/python3.10/site-packages/torch/_library/custom_ops.py", line 681, in __call__
return self._opoverload(*args, **kwargs)
File "/home/runner/.local/lib/python3.10/site-packages/torch/_ops.py", line 776, in __call__
return self._op(*args, **kwargs)
File "/home/runner/.local/lib/python3.10/site-packages/torch/_library/autograd.py", line 111, in autograd_impl
result = forward_no_grad(*args, Metadata(keyset, keyword_only_args))
File "/home/runner/.local/lib/python3.10/site-packages/torch/_library/autograd.py", line 40, in forward_no_grad
result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
File "/home/runner/.local/lib/python3.10/site-packages/torch/_ops.py", line 781, in redispatch
return self._handle.redispatch_boxed(keyset, *args, **kwargs)
File "/home/runner/.local/lib/python3.10/site-packages/torch/_library/custom_ops.py", line 344, in backend_impl
result = self._backend_fns[device_type](*args, **kwargs)
File "/home/runner/.local/lib/python3.10/site-packages/torch/_compile.py", line 51, in inner
return disable_fn(*args, **kwargs)
File "/home/runner/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 850, in _fn
return fn(*args, **kwargs)
File "/home/runner/.local/lib/python3.10/site-packages/torch/_library/custom_ops.py", line 377, in wrapped_fn
return fn(*args, **kwargs)
File "/home/runner/.local/lib/python3.10/site-packages/torch_xla/experimental/custom_kernel.py", line 695, in fa_custom_backward
res = fa_backward_callable(grad_output, q, k, v, o, l, m, q_segment_ids,
File "/home/runner/.local/lib/python3.10/site-packages/torch_xla/experimental/custom_kernel.py", line 84, in wrapped
new_a = xs.enable_manual_sharding(a, spec, mesh=mesh).global_tensor
File "/home/runner/.local/lib/python3.10/site-packages/torch_xla/distributed/spmd/xla_sharding.py", line 521, in enable_manual_sharding
t = mark_sharding(unwrap_sharded_tensor(t), mesh, partition_spec)
File "/home/runner/.local/lib/python3.10/site-packages/torch_xla/distributed/spmd/xla_sharding.py", line 589, in mark_sharding
annotate_func(unwrap_sharded_tensor(t), op_sharding)
RuntimeError: torch_xla/csrc/xla_sharding_util.cpp:818 : Check failed: (xtensor->CurrentDataHandle() && xtensor->CurrentDataHandle()->HasValue()) || device_data_node != nullptr
*** Begin stack trace ***
tsl::CurrentStackTrace[abi:cxx11]()
torch_xla::ShardingUtil::XlaMarkSharding(at::Tensor const&, xla::OpSharding)
_PyObject_MakeTpCall
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyObject_MakeTpCall
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
torch::jit::invokeOperatorFromPython(std::vector<std::shared_ptr<torch::jit::Operator>, std::allocator<std::shared_ptr<torch::jit::Operator> > > const&, pybind11::args const&, pybind11::kwargs const&, std::optional<c10::DispatchKey>)
torch::jit::_get_operation_for_overload_or_packet(std::vector<std::shared_ptr<torch::jit::Operator>, std::allocator<std::shared_ptr<torch::jit::Operator> > > const&, c10::Symbol, pybind11::args const&, pybind11::kwargs const&, bool, std::optional<c10::DispatchKey>)
_PyObject_Call
_PyEval_EvalFrameDefault
_PyObject_FastCallDictTstate
_PyObject_Call_Prepend
_PyObject_Call
_PyEval_EvalFrameDefault
_PyObject_FastCallDictTstate
_PyObject_Call_Prepend
_PyObject_Call
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::shared_ptr<torch::autograd::ReadyQueue> const&)
torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&)
torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool)
clone
*** End stack trace ***
Cannot shard tensor. Data does not present on any device.
🐛 Bug
The `test_splash_attention_segment_id` failure only happens in GitHub CI, and only intermittently; I am unable to reproduce it locally. Previously we had `test_splash_attention_aot_traceable` fail with the same error and the same pattern. My thinking is that the tests may be affecting each other's state. The detailed error output is shown above.

To Reproduce
Unable to reproduce locally.
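
For context on the error message itself, below is a minimal SPMD sharding sketch. It is not taken from the test; the mesh shape and tensor sizes are made up for illustration. The point is that `mark_sharding` (which `xs.enable_manual_sharding` calls, per the traceback) only succeeds when the XLA tensor still carries device data or a device-data IR node, which is exactly the check that fails above during the backward pass.

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()  # SPMD mode must be enabled before XLA tensors are created

num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('data',))

# A tensor freshly moved to the XLA device holds device data, so the sharding
# annotation succeeds.
t = torch.randn(8, 128).to(xm.xla_device())
xs.mark_sharding(t, mesh, ('data', None))

# The CI failure happens on the same kind of annotation inside the
# flash-attention backward pass, when the tensor being annotated has neither
# device data nor a device-data IR node:
# "Cannot shard tensor. Data does not present on any device."
```

If an earlier test leaves the SPMD/mesh state or saved activations in an unexpected configuration, a later test's backward pass could plausibly hit this path, which would be consistent with the cross-test interference theory, but I have not been able to confirm this locally.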