
Nsa trtllm mla sparse fp8 support with Deepseek v3.2 NVFP4 #18389

Merged
Fridge003 merged 26 commits into sgl-project:main from bytedance-iaas:nsa_trtllm_mla_fp8
Feb 16, 2026

Conversation

@rainj-me
Collaborator

@rainj-me rainj-me commented Feb 7, 2026

Motivation

#17655

  • support Deepseek v3.2 NVFP4 with trtllm mla sparse fp8 attention backend

Modifications

  • update the nsa backend to support the trtllm sparse fp8 attention backend
  • update deepseek v2 so that the cos_sin_cache is passed to the trtllm kernels (a rough sketch of this fused path follows this list)
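
As a rough illustration of what the fused RoPE + FP8 quantization step does, here is a minimal sketch (hypothetical: it assumes 2-D [num_tokens, rotary_dim] tensors and a cos_sin_cache with cos and sin concatenated along the last dim, and it omits the per-tensor scales a real FP8 path would carry; the actual kernel is fused on device and its signature may differ):

import torch

def quantize_and_rope_for_fp8(q_rope, k_rope, cos_sin_cache, positions):
    # Look up cos/sin for each token position; the cache stores them
    # concatenated along the last dimension.
    cos, sin = cos_sin_cache[positions].chunk(2, dim=-1)

    def rope(x):
        # Neox-style rotation of the two halves of the rotary slice.
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)

    # Casting to FP8 right after RoPE avoids a separate quantization pass
    # before the TRT-LLM sparse attention kernel is invoked.
    return rope(q_rope).to(torch.float8_e4m3fn), rope(k_rope).to(torch.float8_e4m3fn)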

Accuracy Tests

GSM8K

python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 200 --parallel 100 --port 30000
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:07<00:00, 28.50it/s]
Accuracy: 0.960
Invalid: 0.000
Latency: 7.268 s
Output throughput: 2611.031 token/s

GPQA

 python3 -m sglang.test.run_eval --port 30000 --eval-name gpqa --num-examples 198 --max-tokens 128000 --repeat 8 --top-p 0.95 --temperature 1.0 --thinking-mode deepseek-v3
ChatCompletionSampler initialized with self.system_message=None self.temperature=1.0 self.max_tokens=128000 self.reasoning_effort=None self.extra_body={'chat_template_kwargs': {'thinking': True}}
100%|███████████████████████████████████████████| 198/198 [07:11<00:00,  2.18s/it]
100%|███████████████████████████████████████████| 198/198 [07:43<00:00,  2.34s/it]
100%|███████████████████████████████████████████| 198/198 [07:53<00:00,  2.39s/it]
100%|███████████████████████████████████████████| 198/198 [08:25<00:00,  2.55s/it]
100%|███████████████████████████████████████████| 198/198 [08:46<00:00,  2.66s/it]
100%|███████████████████████████████████████████| 198/198 [08:47<00:00,  2.66s/it]
100%|███████████████████████████████████████████| 198/198 [08:52<00:00,  2.69s/it]
100%|███████████████████████████████████████████| 198/198 [15:52<00:00,  4.81s/it]
====================
Repeat: 8, mean: 0.818
Scores: ['0.798', '0.823', '0.803', '0.818', '0.823', '0.813', '0.838', '0.828']
====================
[METRIC] gpqa_mean_score=0.8181818181818181 labels={"model": "/data02/models/DeepSeek-V3.2-NVFP4", "eval": "gpqa", "repeat": 8}
Writing report to /tmp/gpqa__data02_models_DeepSeek-V3.2-NVFP4.html
{'chars': np.float64(20040.570707070707), 'chars:std': np.float64(19597.30393311847), 'score:std': np.float64(0.37713443843625194), 'scores': ['0.798', '0.823', '0.803', '0.818', '0.823', '0.813', '0.838', '0.828'], 'mean_score': np.float64(0.8181818181818181)}
Writing results to /tmp/gpqa__data02_models_DeepSeek-V3.2-NVFP4.json


cat /tmp/gpqa__data02_models_DeepSeek-V3.2-NVFP4.json
{
  "chars": 20040.570707070707,
  "chars:std": 19597.30393311847,
  "score:std": 0.37713443843625194,
  "scores": [
    "0.798",
    "0.823",
    "0.803",
    "0.818",
    "0.823",
    "0.813",
    "0.838",
    "0.828"
  ],
  "mean_score": 0.8181818181818181

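As a quick sanity check, the rounded per-repeat scores average back to the reported mean (plain Python):

scores = [0.798, 0.823, 0.803, 0.818, 0.823, 0.813, 0.838, 0.828]
print(sum(scores) / len(scores))  # 0.818, matching gpqa_mean_score ≈ 0.8182
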
AIME25

nsa trtllm sparse attn backend, fp8 kv cache and MTP

nemo-run_1/0 ---------------------------------------- aime25 ----------------------------------------
nemo-run_1/0 evaluation_mode  | num_entries | avg_tokens | gen_seconds | symbolic_correct | no_answer
nemo-run_1/0 pass@1[avg-of-4] | 30          | 14424      | 751         | 91.67% ± 1.92%   | 0.83%
nemo-run_1/0 majority@4       | 30          | 14424      | 751         | 93.33%           | 0.00%
nemo-run_1/0 pass@4           | 30          | 14424      | 751         | 93.33%           | 0.00%

nsa flashmla_auto/flashmla_kv attn backend, fp8 kv cache and MTP

nemo-run_1/0 ---------------------------------------- aime25 ----------------------------------------
nemo-run_1/0 evaluation_mode  | num_entries | avg_tokens | gen_seconds | symbolic_correct | no_answer
nemo-run_1/0 pass@1[avg-of-4] | 30          | 14720      | 818         | 87.50% ± 3.19%   | 0.83%
nemo-run_1/0 majority@4       | 30          | 14720      | 818         | 90.00%           | 0.00%
nemo-run_1/0 pass@4           | 30          | 14720      | 818         | 90.00%           | 0.00%
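For reference, the three evaluation modes in these tables can be computed from the four generations per problem roughly as follows (an illustrative sketch, not the nemo-run evaluator itself; answer extraction is assumed to have already happened):

from collections import Counter

def aime_metrics(attempts_per_problem, references):
    # attempts_per_problem: one list of 4 extracted answers per problem.
    pass1 = maj4 = pass4 = 0.0
    for attempts, ref in zip(attempts_per_problem, references):
        correct = [a == ref for a in attempts]
        pass1 += sum(correct) / len(correct)              # pass@1[avg-of-4]
        majority = Counter(attempts).most_common(1)[0][0]
        maj4 += float(majority == ref)                    # majority@4
        pass4 += float(any(correct))                      # pass@4
    n = len(references)
    return pass1 / n, maj4 / n, pass4 / n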

Benchmarking and Profiling

# Prefill
SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN=1 SGLANG_PER_TOKEN_GROUP_QUANT_8BIT_V2=1 UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_4:1,mlx5_5:1 UCX_CUDA_IPC_ENABLE_MNNVL=1 TORCH_CUDA_ARCH_LIST=10.0 NVSHMEM_IB_ENABLE_IBGDA=0 NVSHMEM_ENABLE_NIC_PE_MAPPING=1 NVSHMEM_DISABLE_LOCAL_ONLY_PROXY=1 NVSHMEM_IB_GID_INDEX=0 NVSHMEM_HCA_LIST=mlx5_0:1,mlx5_1:1,mlx5_4:1,mlx5_5:1  NCCL_IB_GID_INDEX=0 NCCL_IB_HCA=mlx5_0:1,mlx5_1:1,mlx5_4:1,mlx5_5:1 NCCL_IB_DISABLE=0 NCCL_MNNVL_ENABLE=1 NCCL_CUMEM_ENABLE=1 NCCL_SOCKET_IFNAME=eth1 GLOO_SOCKET_IFNAME=eth1 SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK=1 SGLANG_ENABLE_JIT_DEEPGEMM=0 python3 -m sglang.launch_server --model-path  /data02/models/DeepSeek-V3.2-NVFP4 --trust-remote-code --disaggregation-mode prefill --tp-size 4 --ep-size 4 --dp-size 4 --enable-dp-attention  --enable-dp-lm-head  --device cuda  --host 0.0.0.0 --port 28000 --disaggregation-bootstrap-port 8991 --mem-fraction-static 0.8 --moe-runner-backend flashinfer_trtllm --quantization modelopt_fp4 --chunked-prefill-size 131072  --disable-radix-cache  --watchdog-timeout 1200 --page-size 64  --load-balance-method round_robin --disaggregation-transfer-backend nixl --disaggregation-ib-device mlx5_0,mlx5_1,mlx5_4,mlx5_5 --kv-cache-dtype fp8_e4m3 --attention-backend nsa --nsa-prefill-backend trtllm --nsa-decode-backend trtllm --page-size 64

# Decode Rank 0
SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN=1 SGLANG_NVFP4_CKPT_FP8_NEXTN_MOE=1 SGLANG_ENABLE_JIT_DEEPGEMM=1 SGLANG_ENABLE_SPEC_V2=1 UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_4:1,mlx5_5:1  UCX_CUDA_IPC_ENABLE_MNNVL=0 TORCH_CUDA_ARCH_LIST=10.0 NVSHMEM_IB_ENABLE_IBGDA=0 NVSHMEM_ENABLE_NIC_PE_MAPPING=1 NVSHMEM_DISABLE_LOCAL_ONLY_PROXY=1 NVSHMEM_IB_GID_INDEX=0 NVSHMEM_HCA_LIST=mlx5_0:1,mlx5_1:1,mlx5_4:1,mlx5_5:1 NVSHMEM_BOOTSTRAP_UID_SOCK_IFNAME=eth1 NCCL_IB_GID_INDEX=0 NCCL_IB_HCA=mlx5_0:1,mlx5_1:1,mlx5_4:1,mlx5_5:1 NCCL_MNNVL_ENABLE=1 NCCL_CUMEM_ENABLE=1 NCCL_SOCKET_IFNAME=eth1 NCCL_SOCKET_FAMILY=AF_INET GLOO_SOCKET_IFNAME=eth1 SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK=1 SGLANG_MOE_NVFP4_DISPATCH=1 SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=1024 python3 -m sglang.launch_server --model-path /data02/models/DeepSeek-V3.2-NVFP4 --trust-remote-code --dist-init-addr 192.168.0.1:21000 --nnodes 2 --node-rank 0 --disaggregation-mode decode --tp-size 8 --ep-size 8 --dp-size 8 --enable-dp-attention --enable-dp-lm-head --device cuda --host 0.0.0.0 --port 28000 --mem-fraction-static 0.7 --moe-runner-backend flashinfer_cutedsl --moe-a2a-backend deepep --deepep-mode low_latency --quantization modelopt_fp4 --chunked-prefill-size 131072 --page-size 64 --disable-radix-cache --max-running-requests 2048 --disaggregation-transfer-backend nixl --watchdog-timeout 1200 --disaggregation-ib-device mlx5_0,mlx5_1,mlx5_4,mlx5_5 --cuda-graph-max-bs 256 --kv-cache-dtype fp8_e4m3 --attention-backend nsa --nsa-prefill-backend trtllm --nsa-decode-backend trtllm  --speculative-algorithm EAGLE --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 --speculative-moe-runner-backend deep_gemm --speculative-moe-a2a-backend deepep --speculative-attention-mode decode


# Decode Rank 1
SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN=1 SGLANG_NVFP4_CKPT_FP8_NEXTN_MOE=1 SGLANG_ENABLE_JIT_DEEPGEMM=1 SGLANG_ENABLE_SPEC_V2=1 UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_4:1,mlx5_5:1  UCX_CUDA_IPC_ENABLE_MNNVL=0 TORCH_CUDA_ARCH_LIST=10.0 NVSHMEM_IB_ENABLE_IBGDA=0 NVSHMEM_ENABLE_NIC_PE_MAPPING=1 NVSHMEM_DISABLE_LOCAL_ONLY_PROXY=1 NVSHMEM_IB_GID_INDEX=0 NVSHMEM_HCA_LIST=mlx5_0:1,mlx5_1:1,mlx5_4:1,mlx5_5:1 NVSHMEM_BOOTSTRAP_UID_SOCK_IFNAME=eth1 NCCL_IB_GID_INDEX=0 NCCL_IB_HCA=mlx5_0:1,mlx5_1:1,mlx5_4:1,mlx5_5:1 NCCL_MNNVL_ENABLE=1 NCCL_CUMEM_ENABLE=1 NCCL_SOCKET_IFNAME=eth1 NCCL_SOCKET_FAMILY=AF_INET GLOO_SOCKET_IFNAME=eth1 SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK=1 SGLANG_MOE_NVFP4_DISPATCH=1 SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=1024 python3 -m sglang.launch_server --model-path /data02/models/DeepSeek-V3.2-NVFP4 --trust-remote-code --dist-init-addr 192.168.0.1:21000 --nnodes 2 --node-rank 1 --disaggregation-mode decode --tp-size 8 --ep-size 8 --dp-size 8 --enable-dp-attention --enable-dp-lm-head --device cuda --host 0.0.0.0 --port 28000 --mem-fraction-static 0.7 --moe-runner-backend flashinfer_cutedsl --moe-a2a-backend deepep --deepep-mode low_latency --quantization modelopt_fp4 --chunked-prefill-size 131072 --page-size 64 --disable-radix-cache --max-running-requests 2048 --disaggregation-transfer-backend nixl --watchdog-timeout 1200 --disaggregation-ib-device mlx5_0,mlx5_1,mlx5_4,mlx5_5 --cuda-graph-max-bs 256 --kv-cache-dtype fp8_e4m3 --attention-backend nsa --nsa-prefill-backend trtllm --nsa-decode-backend trtllm --speculative-algorithm EAGLE --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 --speculative-moe-runner-backend deep_gemm --speculative-moe-a2a-backend deepep --speculative-attention-mode decode

# mini-lb
python -m sglang_router.launch_router \
    --pd-disaggregation \
    --policy round_robin \
    --prefill http://192.168.0.3:28000 8991 \
    --decode http://192.168.0.1:28000 \
    --host 127.0.0.1 \
    --mini-lb \
    --port 30000

# Benchmark
python3 -m sglang.bench_serving --backend sglang-oai-chat --base-url http://127.0.0.1:30000 --model /data02/models/DeepSeek-V3.2-NVFP4  --dataset-name random --seed 5 --random-input-len 3500 --random-output-len 1500 --random-range-ratio 1.0 --num-prompts 4000 --max-concurrency 1920 --request-rate 15
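
The token totals in the results below follow directly from the request shape, since --random-range-ratio 1.0 fixes every request at exactly the given lengths:

num_prompts, input_len, output_len = 4000, 3500, 1500
print(num_prompts * input_len)   # 14,000,000 total input tokens
print(num_prompts * output_len)  # 6,000,000 total generated tokens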

============ Serving Benchmark Result ============
Backend:                                 sglang-oai-chat
Traffic request rate:                    15.0
Max request concurrency:                 1920
Successful requests:                     4000
Benchmark duration (s):                  290.56
Total input tokens:                      14000000
Total input text tokens:                 14000000
Total generated tokens:                  6000000
Total generated tokens (retokenized):    5916362
Request throughput (req/s):              13.77
Input token throughput (tok/s):          48183.62
Output token throughput (tok/s):         20650.12
Peak output token throughput (tok/s):    10218.00
Peak concurrent requests:                428
Total token throughput (tok/s):          68833.74
Concurrency:                             337.74
Accept length:                           2.92
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   24532.73
Median E2E Latency (ms):                 24520.04
P90 E2E Latency (ms):                    29282.55
P99 E2E Latency (ms):                    33215.50
---------------Time to First Token----------------
Mean TTFT (ms):                          4230.75
Median TTFT (ms):                        4125.91
P99 TTFT (ms):                           6867.95
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          13.54
Median TPOT (ms):                        13.64
P99 TPOT (ms):                           19.15
---------------Inter-Token Latency----------------
Mean ITL (ms):                           13.70
Median ITL (ms):                         11.19
P95 ITL (ms):                            28.08
P99 ITL (ms):                            45.33
Max ITL (ms):                            23377.74
==================================================
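
The headline throughput figures are just the totals above divided by the benchmark duration (small differences come from the duration being printed to two decimals):

duration_s = 290.56
print(4000 / duration_s)                      # ~13.77 req/s
print(14_000_000 / duration_s)                # ~48,183 input tok/s
print(6_000_000 / duration_s)                 # ~20,650 output tok/s
print((14_000_000 + 6_000_000) / duration_s)  # ~68,833 total tok/s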

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist
Contributor

Summary of Changes

Hello @rainj-me, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the model serving capabilities by integrating the Deepseek v3.2 NVFP4 model with a TensorRT-LLM sparse FP8 attention backend. The changes streamline the attention mechanism's performance and memory efficiency by introducing dedicated FP8 quantization and RoPE application logic, and by optimizing KV cache storage for this specific configuration. This allows for more efficient inference with Deepseek v3.2 NVFP4 models.

Highlights

  • Deepseek v3.2 NVFP4 Support: Introduced comprehensive support for Deepseek v3.2 NVFP4 models, leveraging the TensorRT-LLM (TRT-LLM) Multi-Layer Attention (MLA) sparse FP8 attention backend.
  • NSA Backend Refactoring for TRT-LLM FP8: The Native Sparse Attention (NSA) backend has been significantly refactored to integrate TRT-LLM's FP8 quantization and RoPE application, specifically for decode operations. This includes a new _quantize_and_rope_for_fp8 function for fused RoPE and quantization.
  • Optimized KV Cache Management: Adjustments were made to the KV cache management within the memory pool to optimize for the FP8 path, including conditional separate quantization for K-cache and overriding KV cache dimensions when using TRT-LLM for both prefill and decode.
  • Deepseek v2 cos_sin_cache Handling: Ensured that the cos_sin_cache is correctly passed to TRT-LLM kernels for Deepseek v2 models, particularly for fused RoPE and quantization paths during decode and speculative operations.


Changelog
  • python/sglang/srt/layers/attention/nsa_backend.py
    • Added data_type attribute to NSAAttentionBackend to store the KV cache data type.
    • Refactored forward_extend and forward_decode to delegate TRT-LLM specific logic to a new _forward_trtllm method.
    • Introduced _quantize_and_rope_for_fp8 function for fused RoPE application and FP8 quantization of query and key components.
    • Updated _forward_trtllm to handle FP8 quantization, KV cache saving, and dynamic query merging based on the FP8 path.
    • Removed redundant TRT-LLM branches from forward_extend and forward_decode.
    • Commented out a max_kv_len check in set_nsa_prefill_impl.
    • Moved _concat_mla_absorb_q_general to a global scope.
  • python/sglang/srt/mem_cache/memory_pool.py
    • Added should_quantize_k_cache_separate parameter to MLAKVCache and NSAKVCache constructors.
    • Modified set_mla_kv_buffer to conditionally use separate quantization for K-cache based on should_quantize_k_cache_separate.
    • Configured NSAKVCache to disable separate K-cache quantization and override KV cache dimension when both prefill and decode backends are TRT-LLM.
  • python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py
    • Passed prefill and decode attention backend configurations to the NSAKVCache initialization.
  • python/sglang/srt/models/deepseek_v2.py
    • Added nsa_attention_backend_impl attribute to DeepSeekV2Model.
    • Updated dispatch_attn_forward_method to set nsa_attention_backend_impl based on server arguments.
    • Extended _fuse_rope_for_trtllm_mla to cover NSA backend with TRT-LLM implementation and is_draft_extend mode for fused RoPE and quantization.
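
Putting the changelog together, the dispatch refactor presumably has roughly this shape (a hypothetical sketch reusing the names above; the real methods take SGLang's layer and forward-batch objects and do considerably more):

class NSAAttentionBackend:
    def forward_extend(self, q, k, v, layer, forward_batch):
        # TRT-LLM handling is hoisted into one helper instead of being
        # duplicated across forward_extend and forward_decode.
        if self.prefill_backend == "trtllm":
            return self._forward_trtllm(q, k, v, layer, forward_batch, is_decode=False)
        ...  # existing non-TRT-LLM prefill paths

    def forward_decode(self, q, k, v, layer, forward_batch):
        if self.decode_backend == "trtllm":
            return self._forward_trtllm(q, k, v, layer, forward_batch, is_decode=True)
        ...  # existing non-TRT-LLM decode paths (e.g. flashmla_auto / flashmla_kv)

    def _forward_trtllm(self, q, k, v, layer, forward_batch, is_decode):
        # FP8 path: fused RoPE + quantization, save the KV cache, merge the
        # query parts, then call the TRT-LLM sparse attention kernel.
        ...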
Activity
  • The author rainj-me has provided a clear motivation for supporting Deepseek v3.2 NVFP4 with TRT-LLM sparse FP8 attention.
  • Modifications include updating the NSA backend and ensuring cos_sin_cache is passed to TRT-LLM kernels for Deepseek v2.
  • Accuracy tests using benchmark/gsm8k/bench_sglang.py show an accuracy of 0.975 with a latency of 16.654s and output throughput of 1110.934 token/s.
  • Detailed benchmarking and profiling commands for both prefill and decode ranks have been provided, demonstrating the setup for the new backend.
  • The checklist indicates that code formatting, unit tests, documentation, accuracy/speed benchmarks, and code style guidelines have been followed.
  • A TODO item notes that accuracy issues with speculative decoding enabled still need to be fixed.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request adds support for Deepseek v3.2 NVFP4 with the trtllm mla sparse fp8 attention backend. The changes involve updating the NSA backend to support this new configuration and ensuring cos_sin_cache is correctly passed to the trtllm kernels. The refactoring in nsa_backend.py to handle the trtllm backend at the beginning of forward_extend and forward_decode is a good improvement. I have a few suggestions to improve code clarity and fix a potential bug in deepseek_v2.py.

Two review comment threads on python/sglang/srt/models/deepseek_v2.py (outdated)
@github-actions github-actions Bot added the blackwell SM100/SM120 label Feb 11, 2026
@Fridge003
Collaborator

@github-actions github-actions Bot added the documentation Improvements or additions to documentation label Feb 13, 2026
@rainj-me
Collaborator Author

If trtllm backend is faster on fp8 kv cache, then we might modify the _set_default_nsa_backends function in server_args.py https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/server_args.py#L1149

Let's bake it a little bit and do it in a different PR.
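
If that follow-up lands, the change would presumably have roughly this shape (hypothetical sketch: _set_default_nsa_backends, the backend names, and the fp8_e4m3 dtype come from this thread; the attribute names and selection logic are assumptions):

def _set_default_nsa_backends(server_args):
    # Assumption: prefer the trtllm NSA backends whenever the KV cache is FP8,
    # since the benchmarks above favor trtllm there; keep the current
    # flashmla defaults otherwise.
    use_trtllm = server_args.kv_cache_dtype == "fp8_e4m3"
    if server_args.nsa_prefill_backend is None:
        server_args.nsa_prefill_backend = "trtllm" if use_trtllm else "flashmla_auto"
    if server_args.nsa_decode_backend is None:
        server_args.nsa_decode_backend = "trtllm" if use_trtllm else "flashmla_kv"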

@rainj-me
Collaborator Author

/tag-run-ci-label

@rainj-me
Collaborator Author

/tag-and-rerun-ci

@rainj-me
Collaborator Author

/rerun-failed-ci

@rainj-me
Collaborator Author

/rerun-failed-ci

@rainj-me
Collaborator Author

/rerun-failed-ci

1 similar comment
@rainj-me
Collaborator Author

/rerun-failed-ci

@Fridge003 Fridge003 merged commit 0ffd0a3 into sgl-project:main Feb 16, 2026
224 of 237 checks passed
magicYang1573 pushed a commit to magicYang1573/sglang that referenced this pull request Mar 9, 2026
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026

Labels

blackwell (SM100/SM120), deepseek, documentation, high priority, run-ci
