[Bug] concat_mla_absorb_q kernel fails for long input #12250

@bingps

Description

Checklist

  • 1. I have searched related issues but cannot get the expected help.
  • 2. The bug has not been fixed in the latest version.
  • 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
  • 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/sglang/discussions/new/choose Otherwise, it will be closed.
  • 5. Please use English, otherwise it will be closed.

Describe the bug

The concat_mla_absorb_q kernel fails with an illegal memory access for long inputs (sequence length around 30K), for example:

============ Test _concat_mla_absorb_q_general on S=16384 ============
Torch  cat: median 3220 us, min 3218 us, max 3224 us
concat_mla: median 1370 us, min 1368 us, max 1372 us
============ Test _concat_mla_absorb_q_general on S=29000 ============
Torch  cat: median 5678 us, min 5676 us, max 5679 us
concat_mla: median 2396 us, min 2394 us, max 2397 us
============ Test _concat_mla_absorb_q_general on S=30000 ============
Traceback (most recent call last):
  File "/root/sglang-fork/python/test_cat.py", line 18, in <module>
    q = _concat_mla_absorb_q_general(q_nope, q_rope)
  File "/root/sglang-fork/python/sglang/srt/layers/attention/trtllm_mla_backend.py", line 1068, in _concat_mla_absorb_q_general
    return concat_mla_absorb_q(q_nope, q_rope)
  File "/usr/local/lib/python3.10/dist-packages/sgl_kernel/elementwise.py", line 392, in concat_mla_absorb_q
    torch.ops.sgl_kernel.concat_mla_absorb_q(a, b, out)
  File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 1243, in __call__
    return self._op(*args, **kwargs)
RuntimeError: CUDA kernel launch failed: an illegal memory access was encountered
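
The failing size is suggestive: with H=128 heads and a concatenated head dim of 512 + 64 = 576, the total output element count crosses INT32_MAX between S=29000 (passes) and S=30000 (fails), so 32-bit index arithmetic overflowing inside the kernel looks like a plausible cause. A quick arithmetic check (my own addition, not code from the kernel):

H, D_NOPE, D_ROPE = 128, 512, 64
INT32_MAX = 2**31 - 1  # 2,147,483,647
for S in (29000, 30000):
    n_elems = S * H * (D_NOPE + D_ROPE)
    print(f"{S=}: {n_elems:,} elements, exceeds INT32_MAX: {n_elems > INT32_MAX}")
# S=29000: 2,138,112,000 elements, exceeds INT32_MAX: False
# S=30000: 2,211,840,000 elements, exceeds INT32_MAX: True

If this is indeed the cause, the smallest failing size should be around S = INT32_MAX // (128 * 576) + 1 = 29128, though I have not bisected to confirm.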

Reproduction

Script for reproduction:

import torch
import triton
from sglang.srt.layers.attention.trtllm_mla_backend import _concat_mla_absorb_q_general


if __name__ == "__main__":
    # MLA absorbed-q shapes: 128 heads, 512-dim nope part, 64-dim rope part
    H = 128
    D_NOPE = 512
    D_ROPE = 64

    for S in [16384, 29000, 30000]:
        print(f"============ Test _concat_mla_absorb_q_general on {S=} ============")
        q_nope = torch.randn(S, H, D_NOPE, dtype=torch.bfloat16, device="cuda")
        q_rope = torch.randn(S, H, D_ROPE, dtype=torch.bfloat16, device="cuda")

        # Correctness check: the fused kernel must match torch.cat exactly
        q_ref = torch.cat([q_nope, q_rope], dim=-1)
        q = _concat_mla_absorb_q_general(q_nope, q_rope)
        assert torch.equal(q_ref, q)

        quantiles = [0.5, 0.2, 0.8]
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: torch.cat([q_nope, q_rope], dim=-1),
            quantiles=quantiles,
        )
        print(f"Torch  cat: median {ms*1000:4.0f} us, min {min_ms*1000:4.0f} us, max {max_ms*1000:4.0f} us")

        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: _concat_mla_absorb_q_general(q_nope, q_rope),
            quantiles=quantiles,
        )
        print(f"concat_mla: median {ms*1000:4.0f} us, min {min_ms*1000:4.0f} us, max {max_ms*1000:4.0f} us")

Environment

Python: 3.10.12 (main, Feb 4 2025, 14:57:36) [GCC 11.4.0]
CUDA available: True
GPU 0,1,2,3,4,5,6,7: NVIDIA H20
GPU 0,1,2,3,4,5,6,7 Compute Capability: 9.0
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 12.4, V12.4.131
CUDA Driver Version: 535.183.06
PyTorch: 2.8.0+cu128
sglang: 0.5.4.post1
sgl_kernel: 0.3.16.post4
flashinfer_python: 0.4.1
triton: 3.4.0
transformers: 4.57.1
torchao: 0.9.0
numpy: 1.26.4
aiohttp: 3.11.13
fastapi: 0.115.11
hf_transfer: 0.1.9
huggingface_hub: 0.35.3
interegular: 0.3.3
modelscope: 1.23.2
orjson: 3.10.15
outlines: 0.1.11
packaging: 24.2
psutil: 7.0.0
pydantic: 2.12.0
python-multipart: 0.0.20
pyzmq: 26.2.1
uvicorn: 0.34.0
uvloop: 0.21.0
vllm: 0.7.2
xgrammar: 0.1.25
openai: 1.99.1
tiktoken: 0.9.0
anthropic: 0.49.0
litellm: 1.63.2
decord2: 2.0.0
