Checklist
Describe the bug
The concat_mla_absorb_q kernel crashes with an illegal memory access for long inputs (sequence length around 30K), while shorter inputs pass the correctness check and are roughly 2x faster than torch.cat. For example:
============ Test _concat_mla_absorb_q_general on S=16384 ============
Torch cat: median 3220 us, min 3218 us, max 3224 us
concat_mla: median 1370 us, min 1368 us, max 1372 us
============ Test _concat_mla_absorb_q_general on S=29000 ============
Torch cat: median 5678 us, min 5676 us, max 5679 us
concat_mla: median 2396 us, min 2394 us, max 2397 us
============ Test _concat_mla_absorb_q_general on S=30000 ============
Traceback (most recent call last):
  File "/root/sglang-fork/python/test_cat.py", line 18, in <module>
    q = _concat_mla_absorb_q_general(q_nope, q_rope)
  File "/root/sglang-fork/python/sglang/srt/layers/attention/trtllm_mla_backend.py", line 1068, in _concat_mla_absorb_q_general
    return concat_mla_absorb_q(q_nope, q_rope)
  File "/usr/local/lib/python3.10/dist-packages/sgl_kernel/elementwise.py", line 392, in concat_mla_absorb_q
    torch.ops.sgl_kernel.concat_mla_absorb_q(a, b, out)
  File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 1243, in __call__
    return self._op(*args, **kwargs)
RuntimeError: CUDA kernel launch failed: an illegal memory access was encountered
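The failure boundary is suggestive: S=29000 still works while S=30000 crashes. With H=128 and a concatenated head dimension of 512 + 64 = 576, the total output element count crosses 2^31 between these two sizes, so an int32 index overflow inside the kernel is a plausible cause. This is only a guess from the shapes, not a confirmed diagnosis:

# Rough arithmetic behind the overflow guess; the 2**31 threshold is an
# assumption about how the kernel indexes elements, not taken from its source.
S_ok, S_bad, H, D = 29000, 30000, 128, 512 + 64
print(S_ok * H * D)   # 2_138_112_000  < 2**31 - 1 (2_147_483_647)
print(S_bad * H * D)  # 2_211_840_000  > 2**31 - 1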
Reproduction
Script for reproduction:
import torch
import triton

from sglang.srt.layers.attention.trtllm_mla_backend import _concat_mla_absorb_q_general

if __name__ == "__main__":
    H = 128
    D_NOPE = 512
    D_ROPE = 64

    for S in [16384, 29000, 30000]:
        print(f"============ Test _concat_mla_absorb_q_general on {S=} ============")
        q_nope = torch.randn(S, H, D_NOPE, dtype=torch.bfloat16, device="cuda")
        q_rope = torch.randn(S, H, D_ROPE, dtype=torch.bfloat16, device="cuda")

        q_ref = torch.cat([q_nope, q_rope], dim=-1)
        q = _concat_mla_absorb_q_general(q_nope, q_rope)
        assert torch.equal(q_ref, q)

        quantiles = [0.5, 0.2, 0.8]
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: torch.cat([q_nope, q_rope], dim=-1),
            quantiles=quantiles,
        )
        print(f"Torch cat: median {ms*1000:4.0f} us, min {min_ms*1000:4.0f} us, max {max_ms*1000:4.0f} us")

        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: _concat_mla_absorb_q_general(q_nope, q_rope),
            quantiles=quantiles,
        )
        print(f"concat_mla: median {ms*1000:4.0f} us, min {min_ms*1000:4.0f} us, max {max_ms*1000:4.0f} us")
Environment
Python: 3.10.12 (main, Feb 4 2025, 14:57:36) [GCC 11.4.0]
CUDA available: True
GPU 0,1,2,3,4,5,6,7: NVIDIA H20
GPU 0,1,2,3,4,5,6,7 Compute Capability: 9.0
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 12.4, V12.4.131
CUDA Driver Version: 535.183.06
PyTorch: 2.8.0+cu128
sglang: 0.5.4.post1
sgl_kernel: 0.3.16.post4
flashinfer_python: 0.4.1
triton: 3.4.0
transformers: 4.57.1
torchao: 0.9.0
numpy: 1.26.4
aiohttp: 3.11.13
fastapi: 0.115.11
hf_transfer: 0.1.9
huggingface_hub: 0.35.3
interegular: 0.3.3
modelscope: 1.23.2
orjson: 3.10.15
outlines: 0.1.11
packaging: 24.2
psutil: 7.0.0
pydantic: 2.12.0
python-multipart: 0.0.20
pyzmq: 26.2.1
uvicorn: 0.34.0
uvloop: 0.21.0
vllm: 0.7.2
xgrammar: 0.1.25
openai: 1.99.1
tiktoken: 0.9.0
anthropic: 0.49.0
litellm: 1.63.2
decord2: 2.0.0