
[HiCache & HybridModel] mooncake backend support DSA & mamba model #21259

Merged
xiezhq-hermann merged 50 commits into sgl-project:main from antgroup:hicache/hicache-refactor4 on Apr 14, 2026
Conversation

@huangtingwei9988 (Collaborator) commented Mar 24, 2026

Motivation

Added support for the Mooncake storage backend; both Mamba and DSA models are supported.

Roadmap: #21846

Modifications

Accuracy Tests

mamba

export MOONCAKE_MASTER=10.13.3.162:50051
export MOONCAKE_PROTOCOL="rdma"
export MOONCAKE_DEVICE="mlx5_bond_0"
export MOONCAKE_TE_META_DATA_SERVER="P2PHANDSHAKE"
python3 -m sglang.launch_server \
      --model-path /root/.cache/modelscope/hub/models/Qwen/Qwen3.5-9B \
      --trust-remote-code \
      --mamba-scheduler-strategy extra_buffer \
      --port 8188 \
      --max-mamba-cache-size 1000 \
      --host 0.0.0.0 \
      --max-total-tokens 655360 \
      --chunked-prefill-size 65536 \
      --tp-size 4 \
      --reasoning-parser qwen3 \
      --page-size 64 \
      --mem-fraction-static 0.88 \
      --cuda-graph-max-bs 64 \
      --hicache-storage-prefetch-policy wait_complete \
      --enable-hierarchical-cache \
      --hicache-size 20 \
      --hicache-io-backend direct \
      --hicache-mem-layout page_first_direct \
      --hicache-write-policy write_through \
      --hicache-storage-backend mooncake

DSA (DeepSeek-V3.2-Exp)

export MOONCAKE_MASTER=10.13.3.162:50051
export MOONCAKE_PROTOCOL="rdma"
export MOONCAKE_DEVICE="mlx5_bond_0"
export MOONCAKE_TE_META_DATA_SERVER="P2PHANDSHAKE"
python3 -m sglang.launch_server \
      --model-path /data/nas/yongke.zyk/model_hub/DeepSeek-V3.2-Exp \
      --trust-remote-code \
      --port 8188 \
      --host 0.0.0.0 \
      --context-length 65536 \
      --chunked-prefill-size 65536 \
      --tp-size 8 \
      --page-size 64 \
      --mem-fraction-static 0.92 \
      --cuda-graph-max-bs 32 \
      --hicache-storage-prefetch-policy timeout \
      --enable-hierarchical-cache \
      --hicache-size 10 \
      --hicache-io-backend direct \
      --hicache-mem-layout page_first_direct \
      --hicache-write-policy write_through \
      --hicache-storage-backend mooncake

mamba

gsm8k (first round)

#python3 bench_sglang.py --num-questions 1319 --port 8188
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [02:36<00:00,  8.44it/s]
Accuracy: 0.901
Invalid: 0.000
Latency: 156.304 s
Output throughput: 981.804 token/s

second round (with flush_cache)

#python3 bench_sglang.py --num-questions 1319 --port 8188
100%|████████████████████████████████████████████████████████████| 1319/1319 [02:15<00:00,  9.75it/s]
Accuracy: 0.900
Invalid: 0.000
Latency: 135.285 s
Output throughput: 1132.280 token/s

mmlu

#python3 -m sglang.test.run_eval \
>   --port 8188 \
>   --eval-name mmlu \
>   --num-examples 1369 \
>   --num-threads 64 \
>   --chat-template-kwargs '{"enable_thinking": false}'

ChatCompletionSampler initialized with self.system_message=None self.temperature=0.0 self.max_tokens=2048 self.reasoning_effort=None self.extra_body={'chat_template_kwargs': {'enable_thinking': False}}
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1369/1369 [03:54<00:00,  5.84it/s]
Total latency: 234.265 s
Score: 0.847
[root@gpulingjun010013003162.et117 /home/shenghai.htw]
#python3 -m sglang.srt.mem_cache.flush_cache --url http://localhost:8188
[root@gpulingjun010013003162.et117 /home/shenghai.htw]
#python3 -m sglang.test.run_eval   --port 8188   --eval-name mmlu   --num-examples 1369   --num-threads 64   --chat-template-kwargs '{"enable_thinking": false}'
ChatCompletionSampler initialized with self.system_message=None self.temperature=0.0 self.max_tokens=2048 self.reasoning_effort=None self.extra_body={'chat_template_kwargs': {'enable_thinking': False}}
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1369/1369 [03:33<00:00,  6.40it/s]
Total latency: 213.934 s
Score: 0.847

DSA

gsm8k

[root@gpulingjun010013003162.et117 /home/shenghai.htw]
#python3 bench_sglang.py --num-questions 1319 --port 8188 --parallel 8
100%|████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [07:20<00:00,  2.99it/s]
Accuracy: 0.960
Invalid: 0.000
Latency: 440.686 s
Output throughput: 276.485 token/s

[root@gpulingjun010013003162.et117 /home/shenghai.htw]
#python3 -m sglang.srt.mem_cache.flush_cache --url http://localhost:8188

[root@gpulingjun010013003162.et117 /home/shenghai.htw]
#python3 bench_sglang.py --num-questions 1319 --port 8188 --parallel 8
100%|████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [07:16<00:00,  3.02it/s]
Accuracy: 0.960
Invalid: 0.000
Latency: 436.830 s
Output throughput: 279.166 token/s

mmlu
first round

#python3 -m sglang.test.run_eval   --port 8188   --eval-name mmlu   --num-examples 1369   --num-threads 8   --chat-template-kwargs '{"enable_thinking": false}'
████████| 1369/1369 [16:01<00:00,  1.42it/s]
Total latency: 961.278 s
Score: 0.903

flush_cache

[root@gpulingjun010013003162.et117 /home/shenghai.htw]
#python3 -m sglang.srt.mem_cache.flush_cache --url http://localhost:8188

second round

#python3 -m sglang.test.run_eval   --port 8188   --eval-name mmlu   --num-examples 1369   --num-threads 8   --chat-template-kwargs '{"enable_thinking": false}'
███████████████████████████████████████| 1369/1369 [16:06<00:00,  1.42it/s]
Total latency: 966.820 s
Score: 0.901

Performance Tests

export MOONCAKE_MASTER=10.13.3.162:50051
export MOONCAKE_PROTOCOL="rdma"
export MOONCAKE_DEVICE="mlx5_bond_0"
export MOONCAKE_TE_META_DATA_SERVER="P2PHANDSHAKE"
python3 -m sglang.launch_server \
      --model-path /root/.cache/modelscope/hub/models/Qwen/Qwen3.5-9B \
      --trust-remote-code \
      --mamba-scheduler-strategy extra_buffer \
      --max-mamba-cache-size 500 \
      --port 8188 \
      --host 0.0.0.0 \
      --chunked-prefill-size 65536 \
      --tp-size 2 \
      --page-size 64 \
      --mem-fraction-static 0.90 \
      --cuda-graph-max-bs 256 \
      --hicache-storage-prefetch-policy timeout \
      --enable-hierarchical-cache \
      --hicache-ratio 1.01  \
      --enable-metrics \
      --hicache-io-backend direct \
      --hicache-mem-layout page_first_direct  \
      --reasoning-parser qwen3 \
      --hicache-write-policy write_through  \
      --hicache-storage-backend mooncake

bench serving

BATCH_SIZE=48
DATA_SIZE=1024
RANDOM_INPUT=4096
RANDOM_OUTPUT=512
WARM_UP_REQUESTS=2

for REQUEST_RATES in 5
do
python3 -m bench_serving \
      --host 10.13.3.162 --port 8188 \
      --backend sglang-oai-chat \
      --model /root/.cache/modelscope/hub/models/Qwen/Qwen3.5-9B \
      --dataset-path /data/nas/moyun.zty/data/ShareGPT_V3_unfiltered_cleaned_split.json \
      --dataset-name random \
      --num-prompt $DATA_SIZE \
      --random-input $RANDOM_INPUT \
      --random-output $RANDOM_OUTPUT \
      --random-range-ratio 1 \
      --request-rate $REQUEST_RATES \
      --max-concurrency $BATCH_SIZE \
      --warmup-requests $WARM_UP_REQUESTS
done

first round (no cache)

============ Serving Benchmark Result ============
Backend:                                 sglang-oai-chat
Traffic request rate:                    5.0       
Max request concurrency:                 48        
Successful requests:                     1024      
Benchmark duration (s):                  407.48    
Total input tokens:                      4194304   
Total generated tokens:                  524288    
Total generated tokens (retokenized):    10433     
Request throughput (req/s):              2.51      
Input token throughput (tok/s):          10293.20  
Output token throughput (tok/s):         1286.65   
Total token throughput (tok/s):          11579.85  
Concurrency:                             47.27     
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   18808.30  
Median E2E Latency (ms):                 19129.65  
---------------Time to First Token----------------
Mean TTFT (ms):                          714.03    
Median TTFT (ms):                        0.00      
P99 TTFT (ms):                           16909.07  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          35.41     
Median TPOT (ms):                        37.41     
P90 TPOT (ms):                           39.26     
P99 TPOT (ms):                           40.34     
---------------Inter-Token Latency----------------
Mean ITL (ms):                           23.64     
Median ITL (ms):                         11.36     
P90 ITL (ms):                            16.71     
P95 ITL (ms):                            27.29     
P99 ITL (ms):                            333.62    
Max ITL (ms):                            9227.86   
==================================================

flush_cache

[root@gpulingjun010013003162.et117 /home/shenghai.htw]
#python3 -m sglang.srt.mem_cache.flush_cache --url http://localhost:8188

second round

============ Serving Benchmark Result ============
Backend:                                 sglang-oai-chat
Traffic request rate:                    5.0       
Max request concurrency:                 48        
Successful requests:                     1024      
Benchmark duration (s):                  225.65    
Total input tokens:                      4194304   
Total generated tokens:                  524288    
Total generated tokens (retokenized):    9700      
Request throughput (req/s):              4.54      
Input token throughput (tok/s):          18587.75  
Output token throughput (tok/s):         2323.47   
Total token throughput (tok/s):          20911.22  
Concurrency:                             39.41     
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   8685.42   
Median E2E Latency (ms):                 8205.74   
---------------Time to First Token----------------
Mean TTFT (ms):                          218.77    
Median TTFT (ms):                        0.00      
P99 TTFT (ms):                           6102.70   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          16.57     
Median TPOT (ms):                        15.99     
P90 TPOT (ms):                           27.76     
P99 TPOT (ms):                           37.50     
---------------Inter-Token Latency----------------
Mean ITL (ms):                           17.42     
Median ITL (ms):                         10.97     
P90 ITL (ms):                            26.98     
P95 ITL (ms):                            57.64     
P99 ITL (ms):                            154.39    
Max ITL (ms):                            2248.44   
==================================================

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

hzh0425 and others added 30 commits March 12, 2026 21:31
parasol-aser pushed a commit to parasol-aser/sglang that referenced this pull request Apr 11, 2026
Implements the HiCacheStorage v2 interface for the 3FS backend so that
hybrid models (Mamba/linear-attention, and in the future DSA) can offload
both KV pages and auxiliary per-pool state to 3FS via HybridCacheController.

- Introduce _Hf3fsPoolEngine: a per-pool bundle of (file, client list,
  executor, metadata client, rank namespace, is_zero_copy, skip_backup)
  so each registered host pool has its own 3FS file and metadata scope.
- Construct the KV engine in __init__ so v1 callers keep working unchanged.
- Implement register_mem_host_pool_v2 to lazily allocate auxiliary
  (MAMBA/...) engines with their own preallocated files, clients and
  metadata namespaces. Idempotent and order-agnostic.
- Implement batch_exists_v2 / batch_get_v2 / batch_set_v2 mirroring the
  HiCacheFile semantics, including ALL_PAGES and TRAILING_PAGES hit
  policies, min-across-pools final hit, and per-pool result dicts.
- Refactor _batch_get / _batch_set to take an engine argument so both
  v1 and v2 entry points share the same IO core.
- Key namespacing: auxiliary pools prefix the metadata key with the
  pool name, KV keeps the bare key for backwards compatibility. MHA
  zero-copy -k/-v suffixing remains strictly KV-scoped.
- Per-pool skip_backup so MLA rank>0 still skips KV but backs up MAMBA
  on every rank. Fix a pre-existing bug where skip_backup returned a
  scalar True instead of a per-key list.
- close() now iterates all engines; _engines is populated before the
  SIGTERM handler is installed.

Test plan:
- New test/registered/hicache/test_hicache_storage_3fs_hybrid.py uses the
  mock HF3FS client to cover: construction sanity, KV-only v2 fallback,
  ALL_PAGES and TRAILING_PAGES exists semantics, v2 set/get round-trip,
  MHA zero-copy + mamba interplay, MLA skip_backup KV-only scoping,
  partial-pool failure, and a no-pool error contract.
- Extended test_hicache_storage_3fs_backend.py with TestHf3fsBackendHybrid
  end-to-end test for a hybrid model, gated on model availability.

Scope: PoolName.KV + PoolName.MAMBA. DSA is deferred until a caller
exists (see PLAN.md §3 and Appendix B).

Tracking issue: sgl-project#22572
Reference PRs: sgl-project#21259, sgl-project#20457

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
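
The commit message above describes the per-pool v2 batch interface. As a rough illustration of the "min-across-pools final hit" semantics it mentions, here is a minimal self-contained Python sketch: only the method names (register_mem_host_pool_v2, batch_set_v2, batch_exists_v2) come from the message, while the class, argument shapes, and return values are invented stand-ins rather than the real sglang signatures.

# Illustrative stand-in, not the sglang 3FS backend.
class SketchHybridBackend:
    def __init__(self):
        self.pools = {}  # pool name -> {key: payload}

    def register_mem_host_pool_v2(self, pool_name):
        # Idempotent, order-agnostic registration, as described above.
        self.pools.setdefault(pool_name, {})

    def batch_set_v2(self, transfers):
        # transfers: {pool_name: {key: payload}} -> per-pool result dicts.
        results = {}
        for pool, items in transfers.items():
            self.pools[pool].update(items)
            results[pool] = {key: True for key in items}
        return results

    def batch_exists_v2(self, queries):
        # queries: {pool_name: [keys]} -> per-pool hits, plus a final hit
        # count that is the minimum across pools: a page only counts as
        # cached when every registered pool holds it.
        per_pool = {
            pool: [k for k in keys if k in self.pools[pool]]
            for pool, keys in queries.items()
        }
        final_hit = min(len(hits) for hits in per_pool.values())
        return per_pool, final_hit

backend = SketchHybridBackend()
for pool in ("kv", "mamba"):
    backend.register_mem_host_pool_v2(pool)
backend.batch_set_v2({"kv": {"p0": b"kv0"}, "mamba": {"p0": b"m0", "p1": b"m1"}})
per_pool, final_hit = backend.batch_exists_v2({"kv": ["p0", "p1"], "mamba": ["p0", "p1"]})
print(per_pool, final_hit)  # kv hits only p0, so the final hit count is 1
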
@huangtingwei9988 (Collaborator, Author)

/rerun-failed-ci

@hzh0425 (Collaborator) commented Apr 12, 2026

@ykwd (Contributor) commented Apr 13, 2026

Thanks for this work. We’ve run accuracy tests based on this PR, and everything looks good.

@huangtingwei9988 (Collaborator, Author)

Thanks for this work. We’ve run accuracy tests based on this PR, and everything looks good.

Thank you so much!! @ykwd

@huangtingwei9988 (Collaborator, Author)

/rerun-failed-ci

Comment on lines +1736 to +1737
if getattr(entry, "share_indices_with_anchor", False):
    entry.host_pool.free(indices)
Collaborator:

Shouldn't this be done by the anchor pool already?

Collaborator (Author):

Yes, thank you. It is a bug: #22767

Comment on lines 433 to 441
def _page_backup(self, operation):
    # Backup extra pools
    if operation.pool_transfers:
        self._resolve_shared_pool_transfers(operation)
        results = self.storage_backend.batch_set_v2(operation.pool_transfers)
        operation.pool_storage_result.update_extra_pool_hit_pages(results)

    # Backup kv pools
    super()._page_backup(operation)
Collaborator:

This looks atomic, but HiCacheController._page_backup does not seem to be atomic. Could this cause a mismatch bug in the future?
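
To make the concern concrete, here is a toy sketch of the failure mode being asked about, with invented stand-in calls rather than the actual controller code:

def page_backup_sketch(backend, pool_transfers, kv_transfers):
    # Step 1: extra pools are persisted first ...
    backend.batch_set_v2(pool_transfers)
    # Step 2: ... but if the kv backup fails or completes partially, step 1
    # is not rolled back, so extra-pool hit metadata can claim pages whose
    # kv counterparts were never stored.
    backend.batch_set(kv_transfers)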

@ShangmingCai (Collaborator) left a comment

No other comments, looks good.

@huangtingwei9988 (Collaborator, Author)

/rerun-failed-ci

@xiezhq-hermann xiezhq-hermann merged commit e9d6b9e into sgl-project:main Apr 14, 2026
531 of 557 checks passed
pyc96 pushed a commit to pyc96/sglang that referenced this pull request Apr 14, 2026
…gl-project#21259)

Signed-off-by: Vladislav Nosivskoy <vladnosiv@gmail.com>
Co-authored-by: hzh0425 <hzh0425@apache.org>
Co-authored-by: pansicheng <sicheng.pan.chn@gmail.com>
Co-authored-by: ispobock <ispobaoke@gmail.com>
Co-authored-by: Vladislav Nosivskoy <vladnosiv@gmail.com>
@LJL36 commented Apr 15, 2026

I encountered a sporadic CUDA error: invalid argument when attaching L3 storage backend at runtime via the admin API on a TP=8 setup.

Launch command:

python -m sglang.launch_server \
  --model-path GLM-5.1-FP8 \
  --host 0.0.0.0 --port 8000 \
  --enable-metrics --tp 8 \
  --reasoning-parser glm45 --tool-call-parser glm47 \
  --page-size 64 --mem-fraction-static 0.85 \
  --enable-hierarchical-cache --hicache-size 60 \
  --hicache-mem-layout page_first_direct \
  --hicache-io-backend direct \
  --hicache-write-policy write_through \
  --admin-api-key secret_for_hicache \
  --skip-server-warmup

Then I attached L3 storage via the admin API. The scheduler crashed sporadically with:

Traceback (most recent call last):
  File ".../hiradix_cache.py", line 341, in attach_storage_backend
    self.cache_controller.attach_storage_backend(
  File ".../hybrid_cache_controller.py", line 208, in attach_storage_backend
    super().attach_storage_backend(
  File ".../cache_controller.py", line 478, in attach_storage_backend
    self.prefetch_tp_group = create_custom_parallel_group(
  File ".../parallel_state.py", line 2006, in create_custom_parallel_group
    torch.distributed.all_gather_object(gathered_configs, local_config)
  File ".../distributed_c10d.py", line 3076, in _object_to_tensor
    byte_tensor = torch.ByteTensor(byte_storage).to(device)
torch.AcceleratorError: CUDA error: invalid argument

This is intermittent — retrying the same operation usually succeeds. It seems like a race condition where TP ranks reach all_gather_object at different times during attach_storage_backend.
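
For context on the failing step: with a NCCL process group, all_gather_object pickles the object on the host and then copies the resulting byte tensor onto the rank's current CUDA device, which is the `.to(device)` call in the traceback above. The sketch below shows one defensive pattern around such a collective; the helper is hypothetical, and pinning the device plus barriering before the call is only a plausible mitigation, not a confirmed fix for this race.

import torch
import torch.distributed as dist

def gather_configs_safely(local_config, group=None, device_index=None):
    # all_gather_object on a NCCL group serializes local_config and moves
    # the byte tensor to the current CUDA device, so pin the device first.
    if device_index is None:
        device_index = dist.get_rank(group) % torch.cuda.device_count()
    torch.cuda.set_device(device_index)
    # Align all ranks before the host-to-device serialization step.
    dist.barrier(group=group)
    gathered = [None] * dist.get_world_size(group)
    dist.all_gather_object(gathered, local_config, group=group)
    return gathered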

@huangtingwei9988 (Collaborator, Author)

I encountered a sporadic CUDA error: invalid argument when attaching L3 storage backend at runtime via the admin API on a TP=8 setup. […]

Thank you! We have also noticed this issue. Do you happen to have a reliable way to reproduce it? You can reach me on Slack: Tingwei Huang.

@LJL36 commented Apr 16, 2026

Thank you! We have also noticed this issue. Do you happen to have a reliable way to reproduce it? You can reach me on Slack: Tingwei Huang. […]

I cannot reliably reproduce this.

yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026
…gl-project#21259)

Signed-off-by: Vladislav Nosivskoy <vladnosiv@gmail.com>
Co-authored-by: hzh0425 <hzh0425@apache.org>
Co-authored-by: pansicheng <sicheng.pan.chn@gmail.com>
Co-authored-by: ispobock <ispobaoke@gmail.com>
Co-authored-by: Vladislav Nosivskoy <vladnosiv@gmail.com>
if entry is self.anchor_entry:
    continue
if getattr(entry, "share_indices_with_anchor", False):
    entry.host_pool.free(indices)
@whybeyoung (Collaborator) commented Apr 26, 2026

god shangming


Labels: hicache (Hierarchical Caching for SGLang), high priority, run-ci