
[Bug] Shape mismatch in Indexer with GLM5 pp2 when hicache is enabled #20529

@Zhangmj0621

Checklist

  • I searched related issues but found no solution.
  • The bug persists in the latest version.
  • Issues without environment info and a minimal reproducible demo are hard to resolve and may receive no feedback.
  • If this is not a bug report but a general question, please start a discussion at https://github.com/sgl-project/sglang/discussions. Otherwise, it will be closed.
  • Please use English. Otherwise, it will be closed.

Describe the bug

When I run GLM5 with SGLang + HiCache + Mooncake, using TP8 PP2 for prefill on H100 and DP16 EP16 for decode on H20, I hit a shape mismatch between positions and query/key in the Indexer every time I run the HiCache multi-turn benchmark, at around 700 to 800 requests.
I updated my SGLang codebase to the latest version as of 2026-03-13, but the issue persists. When I set pp-size to 1, the issue disappears, so I believe it is directly related to pipeline parallelism (PP).
I also noticed the following log when starting the SGLang server:
[2026-03-13 06:39:44] Transformers version 5.3.0 is used for model type glm_moe_dsa. If you experience issues related to RoPE parameters, they may be due to incompatibilities between Transformers >=5.0.0 and some models. You can try downgrading to transformers==4.57.1 as a workaround.
However, since GLM5 uses AutoTokenizer and only works with Transformers v5, I can't verify whether this bug could be fixed by downgrading Transformers.
This issue is related to #20341.
The error log is shown below:

[2026-03-12 10:51:20 PP1 TP2] Scheduler hit an exception: Traceback (most recent call last):
  File "/inspire/hdd/global_user/huxiaohe-p-huxiaohe/sglang/python/sglang/srt/managers/scheduler.py", line 3165, in run_scheduler_process
    scheduler.event_loop_pp_disagg_prefill()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/inspire/hdd/global_user/huxiaohe-p-huxiaohe/sglang/python/sglang/srt/managers/scheduler_pp_mixin.py", line 245, in event_loop_pp_disagg_prefill
    result, self.launch_event = self._pp_launch_batch(
                                ^^^^^^^^^^^^^^^^^^^^^^
  File "/inspire/hdd/global_user/huxiaohe-p-huxiaohe/sglang/python/sglang/srt/managers/scheduler_pp_mixin.py", line 1064, in _pp_launch_batch
    result = self.run_batch(self.cur_batch, pp_proxy_tensors)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/inspire/hdd/global_user/huxiaohe-p-huxiaohe/sglang/python/sglang/srt/managers/scheduler.py", line 2381, in run_batch
    batch_result = self.model_worker.forward_batch_generation(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/inspire/hdd/global_user/huxiaohe-p-huxiaohe/sglang/python/sglang/srt/managers/tp_worker.py", line 455, in forward_batch_generation
    out = self.model_runner.forward(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/inspire/hdd/global_user/huxiaohe-p-huxiaohe/sglang/python/sglang/srt/model_executor/model_runner.py", line 2390, in forward
    output = self._forward_raw(
             ^^^^^^^^^^^^^^^^^^
  File "/inspire/hdd/global_user/huxiaohe-p-huxiaohe/sglang/python/sglang/srt/model_executor/model_runner.py", line 2489, in _forward_raw
    ret, can_run_graph = self.forward_extend(
                         ^^^^^^^^^^^^^^^^^^^^
  File "/inspire/hdd/global_user/huxiaohe-p-huxiaohe/sglang/python/sglang/srt/model_executor/model_runner.py", line 2327, in forward_extend
    self.model.forward(
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/inspire/hdd/global_user/huxiaohe-p-huxiaohe/sglang/python/sglang/srt/models/deepseek_v2.py", line 2912, in forward
    hidden_states = self.model(
                    ^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/inspire/hdd/global_user/huxiaohe-p-huxiaohe/sglang/python/sglang/srt/models/deepseek_v2.py", line 2723, in forward
    hidden_states, residual = layer(
                              ^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/inspire/hdd/global_user/huxiaohe-p-huxiaohe/sglang/python/sglang/srt/models/deepseek_v2.py", line 2388, in forward
    hidden_states = self.self_attn(
                    ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/inspire/hdd/global_user/huxiaohe-p-huxiaohe/sglang/python/sglang/srt/models/deepseek_v2.py", line 1365, in forward
    s = self.forward_prepare(
        ^^^^^^^^^^^^^^^^^^^^^
  File "/inspire/hdd/global_user/huxiaohe-p-huxiaohe/sglang/python/sglang/srt/models/deepseek_v2.py", line 1419, in forward_prepare
    inner_state = self.forward_absorb_prepare(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/inspire/hdd/global_user/huxiaohe-p-huxiaohe/sglang/python/sglang/srt/models/deepseek_v2.py", line 1628, in forward_absorb_prepare
    topk_indices = self.indexer(
                   ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/inspire/hdd/global_user/huxiaohe-p-huxiaohe/sglang/python/sglang/srt/layers/utils/multi_platform.py", line 71, in forward
    return self._forward_method(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/inspire/hdd/global_user/huxiaohe-p-huxiaohe/sglang/python/sglang/srt/layers/attention/nsa/nsa_indexer.py", line 991, in forward_cuda
    query, key = self._get_q_k_bf16(
                 ^^^^^^^^^^^^^^^^^^^
  File "/inspire/hdd/global_user/huxiaohe-p-huxiaohe/sglang/python/sglang/srt/layers/attention/nsa/nsa_indexer.py", line 297, in _get_q_k_bf16
    q_rope, k_rope = self.rotary_emb(positions, q_rope, k_rope)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/inspire/hdd/global_user/huxiaohe-p-huxiaohe/sglang/python/sglang/srt/layers/rotary_embedding.py", line 280, in forward_native
    query = query.view(num_tokens, -1, self.head_size)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: shape '[8136, -1, 64]' is invalid for input of size 16777216
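The numbers in the error message can be sanity-checked with plain arithmetic. The sketch below is illustrative only: the helper names (`view_is_valid`, `implied_tokens`, `rope_shape_mismatch`) are hypothetical, not SGLang APIs, and the indexer head counts 32 and 64 are assumptions rather than values read from the GLM5 config. It mirrors the size rule that `query.view(num_tokens, -1, self.head_size)` enforces and adds a small guard that could be used while debugging to report which tensor's token dimension disagrees with `positions`.

```python
# Values copied from the RuntimeError at the end of the traceback.
NUMEL = 16_777_216   # number of elements in the query tensor
NUM_TOKENS = 8136    # num_tokens passed to query.view(...)
HEAD_SIZE = 64       # self.head_size of the rotary embedding


def view_is_valid(numel: int, num_tokens: int, head_size: int) -> bool:
    """Size rule behind view(num_tokens, -1, head_size): the inferred
    -1 dimension must come out as a whole number."""
    return numel % (num_tokens * head_size) == 0


def implied_tokens(numel: int, num_heads: int, head_size: int) -> int:
    """Token count if the tensor were laid out as [tokens, num_heads, head_size]."""
    return numel // (num_heads * head_size)


def rope_shape_mismatch(num_positions, query_shape, key_shape):
    """Hypothetical debug helper (not part of SGLang): return a message if the
    leading (token) dimension of query/key differs from the number of
    positions, otherwise None."""
    for name, shape in (("query", query_shape), ("key", key_shape)):
        if shape[0] != num_positions:
            return (f"{name} has {shape[0]} tokens but positions has "
                    f"{num_positions} (full shape {tuple(shape)})")
    return None


print(view_is_valid(NUMEL, NUM_TOKENS, HEAD_SIZE))  # False
# Assumed indexer head counts; the real value comes from the model config.
for heads in (32, 64):
    print(heads, implied_tokens(NUMEL, heads, HEAD_SIZE))  # 32 8192, then 64 4096
```

Because 16777216 / 64 = 262144 is not a multiple of 8136, no whole number of 64-dim heads yields 8136 rows, so the query tensor reaching the rotary embedding was built for a different token count than positions (for example 8192 under the 32-head assumption). While debugging, one could log `rope_shape_mismatch(positions.numel(), q_rope.shape, k_rope.shape)` together with the PP rank just before `self.rotary_emb(positions, q_rope, k_rope)` in `_get_q_k_bf16`; this assumes the leading dimension of `q_rope` and `k_rope` is the token dimension.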

Reproduction

Docker: lmsysorg/sglang:glm5-hopper
Then install the latest SGLang from source and upgrade the Transformers version.
I use SGLang + HiCache + Mooncake. With TP16 for prefill everything works, but with TP8 PP2 the multi-turn benchmark fails every time.
My prefill script is as follows (the Mooncake configuration is omitted to save space):

nohup python -m sglang.launch_server \
  --model-path /path/to/model/GLM-5-FP8/  \
  --trust-remote-code \
  --disaggregation-mode prefill \
  --disaggregation-ib-device mlx5_0,mlx5_1,mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7 \
  --tp 8 \
  --pp-size 2 \
  --enable-hierarchical-cache \
  --hicache-ratio 1 \
  --hicache-storage-prefetch-policy timeout \
  --hicache-storage-backend mooncake \
  --hicache-write-policy write_through \
  --watchdog-timeout 1000000 \
  --mem-fraction-static 0.80 \
  --dist-init-addr x.x.x.x:20102 \
  --tool-call-parser glm47 \
  --reasoning-parser glm45 \
  --nnodes 2 \
  --node-rank 0 \
  --port 30000 \
  --host 0.0.0.0 > prefill.log &
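The two prefill layouts compared in this report run on the same hardware. A quick sketch of the GPU count, assuming no data parallelism on the prefill side and 8 GPUs per node (as in the environment info), shows that both need 16 GPUs across the two nodes given by --nnodes 2, so switching from TP16 to TP8 PP2 changes only the parallelism strategy. The `world_size` helper is illustrative, not an SGLang function.

```python
GPUS_PER_NODE = 8  # 8x H100 per node, per the environment section


def world_size(tp: int, pp: int = 1) -> int:
    """GPUs used by one prefill server without DP: tensor-parallel x pipeline-parallel."""
    return tp * pp


layouts = {"tp16 (works)": (16, 1), "tp8 pp2 (fails)": (8, 2)}
for name, (tp, pp) in layouts.items():
    gpus = world_size(tp, pp)
    print(f"{name}: {gpus} GPUs on {gpus // GPUS_PER_NODE} nodes")
```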

My decode script is as follows:

export SGLANG_DG_CACHE_DIR=/path/to/deep_gemm/
nohup python -m sglang.launch_server \
  --model-path /path/to/model/GLM-5-FP8/ \
  --trust-remote-code \
  --disaggregation-mode decode \
  --disaggregation-ib-device mlx5_0,mlx5_1,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7 \
  --tp 16 \
  --dp-size 16 \
  --ep-size  16 \
  --enable-dp-attention \
  --enable-dp-lm-head \
  --ep-dispatch-algorithm dynamic \
  --enable-eplb \
  --eplb-algorithm deepseek \
  --watchdog-timeout 1000000 \
  --mem-fraction-static 0.65 \
  --dist-init-addr x.x.x.x:20102 \
  --eplb-rebalance-layers-per-chunk=29 \
  --tool-call-parser glm47 \
  --reasoning-parser glm45 \
  --nnodes 2 \
  --node-rank 0 \
  --port 30008 \
  --host 0.0.0.0 > decode.log &

My test script is as follows:

nohup python3 benchmark/hicache/bench_multiturn.py \
    --model-path /path/to/model/GLM-5-FP8/ \
    --disable-random-sample \
    --output-length 1 \
    --request-length 2048 \
    --num-clients 80 \
    --num-rounds 10 \
    --max-parallel 4 \
    --request-rate 16 \
    --ready-queue-policy random \
    --disable-auto-run \
    --host 0.0.0.0 \
    --port 8000 \
    --enable-round-barrier > bench_glm5.log 2>&1 &
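For reference, a rough count of what this benchmark configuration sends. This sketch assumes (not verified against bench_multiturn.py) that each of the --num-clients clients issues one request per round, that --enable-round-barrier makes every round finish before the next starts, and that each round appends one new --request-length chunk plus the previous --output-length tokens to the conversation. The helper names are hypothetical.

```python
NUM_CLIENTS = 80       # --num-clients
NUM_ROUNDS = 10        # --num-rounds
REQUEST_LENGTH = 2048  # --request-length
OUTPUT_LENGTH = 1      # --output-length

# Assumption: every client sends exactly one request per round.
total_requests = NUM_CLIENTS * NUM_ROUNDS


def round_of_request(n: int) -> int:
    """1-based round of the n-th request, assuming the round barrier makes
    all clients finish a round before the next one starts."""
    return (n - 1) // NUM_CLIENTS + 1


def approx_context_tokens(round_idx: int) -> int:
    """Rough prompt length in a given round, assuming each round appends one new
    request of REQUEST_LENGTH tokens plus the previous round's output."""
    return round_idx * REQUEST_LENGTH + (round_idx - 1) * OUTPUT_LENGTH


print(total_requests)                                # 800
print(round_of_request(700), round_of_request(800))  # 9 10
print(approx_context_tokens(NUM_ROUNDS))             # 20489
```

Under these assumptions, the crash at around request 700 to 800 lands in the last two rounds, where prompts are longest (roughly 18k to 20k tokens) and presumably the most prefix data is reused through HiCache.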

Environment

Python: 3.12.3 (main, Jan 22 2026, 20:57:42) [GCC 13.3.0]
CUDA available: True
GPU 0,1,2,3,4,5,6,7: NVIDIA H100 80GB HBM3
GPU 0,1,2,3,4,5,6,7 Compute Capability: 9.0
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 12.9, V12.9.86
CUDA Driver Version: 560.35.03
PyTorch: 2.9.1+cu129
sglang: 0.0.0.dev1+g93fca0bbc
sgl_kernel: 0.3.21
flashinfer_python: 0.6.2
flashinfer_cubin: 0.6.2
flashinfer_jit_cache: 0.6.2+cu129
triton: 3.5.1
transformers: 5.2.0.dev0
torchao: 0.9.0
numpy: 2.4.2
aiohttp: 3.13.3
fastapi: 0.128.7
hf_transfer: 0.1.9
huggingface_hub: 1.4.1
interegular: 0.3.3
modelscope: 1.34.0
orjson: 3.11.7
outlines: 0.1.11
packaging: 26.0
psutil: 7.2.2
pydantic: 2.12.5
python-multipart: 0.0.22
pyzmq: 27.1.0
uvicorn: 0.40.0
uvloop: 0.22.1
vllm: Module Not Found
xgrammar: 0.1.27
openai: 2.6.1
tiktoken: 0.12.0
anthropic: 0.79.0
litellm: Module Not Found
decord2: 3.0.0
NVIDIA Topology:
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7 NIC0 NIC1 NIC2 NIC3 NIC4 NIC5 NIC6 NIC7 NIC8 NIC9 CPU Affinity NUMA Affinity GPU NUMA ID
GPU0 X NV18 NV18 NV18 NV18 NV18 NV18 NV18 PIX NODE NODE NODE NODE NODE SYS SYS SYS SYS 0-47,96-143 0 N/A
GPU1 NV18 X NV18 NV18 NV18 NV18 NV18 NV18 NODE PIX NODE NODE NODE NODE SYS SYS SYS SYS 0-47,96-143 0 N/A
GPU2 NV18 NV18 X NV18 NV18 NV18 NV18 NV18 NODE NODE PIX NODE NODE NODE SYS SYS SYS SYS 0-47,96-143 0 N/A
GPU3 NV18 NV18 NV18 X NV18 NV18 NV18 NV18 NODE NODE NODE NODE NODE PIX SYS SYS SYS SYS 0-47,96-143 0 N/A
GPU4 NV18 NV18 NV18 NV18 X NV18 NV18 NV18 SYS SYS SYS SYS SYS SYS PIX NODE NODE NODE 48-95,144-191 1 N/A
GPU5 NV18 NV18 NV18 NV18 NV18 X NV18 NV18 SYS SYS SYS SYS SYS SYS NODE PIX NODE NODE 48-95,144-191 1 N/A
GPU6 NV18 NV18 NV18 NV18 NV18 NV18 X NV18 SYS SYS SYS SYS SYS SYS NODE NODE PIX NODE 48-95,144-191 1 N/A
GPU7 NV18 NV18 NV18 NV18 NV18 NV18 NV18 X SYS SYS SYS SYS SYS SYS NODE NODE NODE PIX 48-95,144-191 1 N/A
NIC0 PIX NODE NODE NODE SYS SYS SYS SYS X NODE NODE NODE NODE NODE SYS SYS SYS SYS
NIC1 NODE PIX NODE NODE SYS SYS SYS SYS NODE X NODE NODE NODE NODE SYS SYS SYS SYS
NIC2 NODE NODE PIX NODE SYS SYS SYS SYS NODE NODE X NODE NODE NODE SYS SYS SYS SYS
NIC3 NODE NODE NODE NODE SYS SYS SYS SYS NODE NODE NODE X PIX NODE SYS SYS SYS SYS
NIC4 NODE NODE NODE NODE SYS SYS SYS SYS NODE NODE NODE PIX X NODE SYS SYS SYS SYS
NIC5 NODE NODE NODE PIX SYS SYS SYS SYS NODE NODE NODE NODE NODE X SYS SYS SYS SYS
NIC6 SYS SYS SYS SYS PIX NODE NODE NODE SYS SYS SYS SYS SYS SYS X NODE NODE NODE
NIC7 SYS SYS SYS SYS NODE PIX NODE NODE SYS SYS SYS SYS SYS SYS NODE X NODE NODE
NIC8 SYS SYS SYS SYS NODE NODE PIX NODE SYS SYS SYS SYS SYS SYS NODE NODE X NODE
NIC9 SYS SYS SYS SYS NODE NODE NODE PIX SYS SYS SYS SYS SYS SYS NODE NODE NODE X

Legend:

X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks

NIC Legend:

NIC0: mlx5_0
NIC1: mlx5_1
NIC2: mlx5_2
NIC3: mlx5_3
NIC4: mlx5_4
NIC5: mlx5_5
NIC6: mlx5_6
NIC7: mlx5_7
NIC8: mlx5_8
NIC9: mlx5_9

ulimit soft: 1048576

Metadata


Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
