Splash attention test fails randomly, only in GitHub CI #8971

@zpcore

Description

🐛 Bug

The test_splash_attention_segment_id failure happens only in GitHub CI, and only intermittently; I am unable to reproduce it locally. Previously, test_splash_attention_aot_traceable failed with the same error and the same pattern. My suspicion is that the tests are affecting each other's state. Below is the detailed error:

+ python3 /home/runner/_work/xla/xla/pytorch/xla/test/test_splash_attention.py
/home/runner/.local/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:351: UserWarning: Device capability of jax unspecified, assuming `cpu` and `cuda`. Please specify it via the `devices` argument of `register_backend`.
  warnings.warn(
...E.
======================================================================
ERROR: test_splash_attention_segment_id (__main__.SplashAttentionTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/runner/_work/xla/xla/pytorch/xla/test/test_splash_attention.py", line 33, in wrapper
    result = func(*args, **kwargs)
  File "/home/runner/_work/xla/xla/pytorch/xla/test/test_splash_attention.py", line 79, in setUp
    loss.backward()
  File "/home/runner/.local/lib/python3.10/site-packages/torch/_tensor.py", line 648, in backward
    torch.autograd.backward(
  File "/home/runner/.local/lib/python3.10/site-packages/torch/autograd/__init__.py", line 353, in backward
    _engine_run_backward(
  File "/home/runner/.local/lib/python3.10/site-packages/torch/autograd/graph.py", line 824, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/runner/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 307, in apply
    return user_fn(self, *args)
  File "/home/runner/.local/lib/python3.10/site-packages/torch_xla/_internal/jax_workarounds.py", line 26, in wrapper
    return func(*args, **kwargs)
  File "/home/runner/.local/lib/python3.10/site-packages/torch_xla/experimental/custom_kernel.py", line 842, in backward
    grad_q, grad_k, grad_v, grad_ab = fa_custom_backward(
  File "/home/runner/.local/lib/python3.10/site-packages/torch/_library/custom_ops.py", line 681, in __call__
    return self._opoverload(*args, **kwargs)
  File "/home/runner/.local/lib/python3.10/site-packages/torch/_ops.py", line 776, in __call__
    return self._op(*args, **kwargs)
  File "/home/runner/.local/lib/python3.10/site-packages/torch/_library/autograd.py", line 111, in autograd_impl
    result = forward_no_grad(*args, Metadata(keyset, keyword_only_args))
  File "/home/runner/.local/lib/python3.10/site-packages/torch/_library/autograd.py", line 40, in forward_no_grad
    result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
  File "/home/runner/.local/lib/python3.10/site-packages/torch/_ops.py", line 781, in redispatch
    return self._handle.redispatch_boxed(keyset, *args, **kwargs)
  File "/home/runner/.local/lib/python3.10/site-packages/torch/_library/custom_ops.py", line 344, in backend_impl
    result = self._backend_fns[device_type](*args, **kwargs)
  File "/home/runner/.local/lib/python3.10/site-packages/torch/_compile.py", line 51, in inner
    return disable_fn(*args, **kwargs)
  File "/home/runner/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 850, in _fn
    return fn(*args, **kwargs)
  File "/home/runner/.local/lib/python3.10/site-packages/torch/_library/custom_ops.py", line 377, in wrapped_fn
    return fn(*args, **kwargs)
  File "/home/runner/.local/lib/python3.10/site-packages/torch_xla/experimental/custom_kernel.py", line 695, in fa_custom_backward
    res = fa_backward_callable(grad_output, q, k, v, o, l, m, q_segment_ids,
  File "/home/runner/.local/lib/python3.10/site-packages/torch_xla/experimental/custom_kernel.py", line 84, in wrapped
    new_a = xs.enable_manual_sharding(a, spec, mesh=mesh).global_tensor
  File "/home/runner/.local/lib/python3.10/site-packages/torch_xla/distributed/spmd/xla_sharding.py", line 521, in enable_manual_sharding
    t = mark_sharding(unwrap_sharded_tensor(t), mesh, partition_spec)
  File "/home/runner/.local/lib/python3.10/site-packages/torch_xla/distributed/spmd/xla_sharding.py", line 589, in mark_sharding
    annotate_func(unwrap_sharded_tensor(t), op_sharding)
RuntimeError: torch_xla/csrc/xla_sharding_util.cpp:818 : Check failed: (xtensor->CurrentDataHandle() && xtensor->CurrentDataHandle()->HasValue()) || device_data_node != nullptr 
*** Begin stack trace ***
	tsl::CurrentStackTrace[abi:cxx11]()
	torch_xla::ShardingUtil::XlaMarkSharding(at::Tensor const&, xla::OpSharding)
	
	
	
	
	_PyObject_MakeTpCall
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	
	
	
	
	
	
	
	_PyObject_MakeTpCall
	
	_PyEval_EvalFrameDefault
	
	
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	
	
	
	torch::jit::invokeOperatorFromPython(std::vector<std::shared_ptr<torch::jit::Operator>, std::allocator<std::shared_ptr<torch::jit::Operator> > > const&, pybind11::args const&, pybind11::kwargs const&, std::optional<c10::DispatchKey>)
	torch::jit::_get_operation_for_overload_or_packet(std::vector<std::shared_ptr<torch::jit::Operator>, std::allocator<std::shared_ptr<torch::jit::Operator> > > const&, c10::Symbol, pybind11::args const&, pybind11::kwargs const&, bool, std::optional<c10::DispatchKey>)
	
	
	
	
	_PyObject_Call
	_PyEval_EvalFrameDefault
	
	_PyObject_FastCallDictTstate
	_PyObject_Call_Prepend
	
	_PyObject_Call
	_PyEval_EvalFrameDefault
	
	_PyObject_FastCallDictTstate
	_PyObject_Call_Prepend
	
	_PyObject_Call
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	
	
	
	torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::shared_ptr<torch::autograd::ReadyQueue> const&)
	torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&)
	torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool)
	
	
	
	clone
*** End stack trace ***
Cannot shard tensor. Data does not present on any device.
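
For context: the CHECK that fires lives in XlaMarkSharding, which requires that a tensor being annotated either already holds materialized device data or is backed by a device-data IR node; the final line above is the user-facing form of that check. The traceback shows the splash attention backward reaching it through xs.enable_manual_sharding, which calls mark_sharding internally. Below is a minimal, illustrative sketch of that code path, not the actual test; it assumes an SPMD-enabled XLA environment (e.g. a TPU VM), and the tensor shape and mesh axis name are made up:

```python
# Illustrative sketch only; assumes torch_xla with SPMD support and at
# least one XLA device. Shape and mesh axis name are made up.
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # SPMD mode, as the splash attention kernels require

num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ("data",))

t = torch.randn(8, 128).to(xm.xla_device())

# mark_sharding (which enable_manual_sharding calls internally) requires
# `t` to either hold device data or be an IR value backed by a
# device-data node. If an earlier test leaves the lazy graph in a state
# where neither holds for a tensor reaching the backward pass, the C++
# CHECK in xla_sharding_util.cpp fires exactly as in the log above.
xs.mark_sharding(t, mesh, ("data", None))
```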

To Reproduce

Unable to reproduce locally.
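
One way to probe the cross-test state hypothesis would be to run the failing test on its own versus as part of the full file, in the same environment as CI. A hedged sketch of that, assuming the test file uses the standard unittest runner (so a test id can be passed on the command line) and using the path from the CI log:

```python
# Hedged repro sketch: run the failing test in isolation vs. the whole
# file, to probe whether earlier tests corrupt shared state. The path
# and test id are taken from the CI log above.
import subprocess

test_file = "pytorch/xla/test/test_splash_attention.py"

# Full file, as CI runs it (fails intermittently there):
subprocess.run(["python3", test_file], check=True)

# Single test in isolation (assumes the file calls unittest.main()):
subprocess.run(
    ["python3", test_file,
     "SplashAttentionTest.test_splash_attention_segment_id"],
    check=True,
)
```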

Labels: bug (Something isn't working)
