[Refactor] Deduplicate NSA utils.py into cp_utils.py for context parallel#22914

Merged: Fridge003 merged 7 commits into main from dedup-nsa-utils-cp-utils on Apr 20, 2026
Conversation

@Fridge003 (Collaborator) commented Apr 16, 2026

Summary

  • Removed ~270 lines of duplicated context-parallel utility functions from layers/attention/nsa/utils.py, consolidating them into layers/utils/cp_utils.py
  • Unified NSAContextParallelMetadata into ContextParallelMetadata (identical fields)
  • Merged nsa_cp_metadata field into attn_cp_metadata on ForwardBatch (both fields were never set simultaneously)
  • Extended cp_utils.py functions with round-robin split support and symmetric memory allocation
  • Renamed NSA-specific can_cp_split to can_nsa_cp_split to avoid name collision with the generic version
  • Replaced prepare_input_dp_with_cp_dsa with prepare_context_parallel_metadata (which has better prefix_len handling)
  • Updated all callers: deepseek_v2.py, deepseek_nextn.py, nsa_indexer.py, forward_batch_info.py, schedule_batch.py, ascend_backend.py
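The metadata unification in the first two bullets can be pictured with a minimal sketch. The field names below are the ones mentioned elsewhere in this PR (kv_len_prev/next and their 1-D tensor variants); everything else is illustrative, not the real class in layers/utils/cp_utils.py.

```python
# Hypothetical sketch of the consolidation: NSAContextParallelMetadata and
# ContextParallelMetadata had identical fields, so they collapse into one
# shared dataclass. Field names follow this PR; types are simplified.
from dataclasses import dataclass
from typing import Optional


@dataclass
class ContextParallelMetadata:
    # Scalar KV lengths for the neighbouring CP ranks.
    kv_len_prev: int = 0
    kv_len_next: int = 0
    # 1-D (1,)-shaped tensor variants consumed by the non-NSA FA path as
    # cache_seqlens; typed loosely here to keep the sketch torch-free.
    kv_len_prev_tensor: Optional[object] = None
    kv_len_next_tensor: Optional[object] = None


# Callers that previously built NSAContextParallelMetadata now construct
# the single shared class, and ForwardBatch carries one attn_cp_metadata
# field instead of two mutually exclusive ones.
meta = ContextParallelMetadata(kv_len_prev=128, kv_len_next=256)
```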

Test plan

  • Run test/registered/cp/test_deepseek_v32_cp_single_node.py on H200
  • Verify round-robin-split CP mode works end-to-end with accuracy eval
  • CI should validate in-seq-split mode (the timeout was due to uncompiled DeepGEMM in the test env, not code changes)
  • Pre-commit lint checks pass

🤖 Generated with Claude Code

…llel

Remove duplicated context-parallel utility functions from
`layers/attention/nsa/utils.py` and consolidate them into
`layers/utils/cp_utils.py`. This eliminates ~270 lines of duplicate
code while preserving all functionality.

Key changes:
- Unify `NSAContextParallelMetadata` into `ContextParallelMetadata`
- Merge `nsa_cp_metadata` field into `attn_cp_metadata` on ForwardBatch
- Extend cp_utils functions with round-robin split and symmetric memory
- Rename NSA-specific `can_cp_split` to `can_nsa_cp_split`
- Replace `prepare_input_dp_with_cp_dsa` with `prepare_context_parallel_metadata`
- Update all callers (deepseek_v2, deepseek_nextn, nsa_indexer, etc.)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@gemini-code-assist (Contributor)
Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@Fridge003 (Collaborator, Author)

/rerun-test test_deepseek_v32_cp_single_node.py

@github-actions (Contributor)

8-gpu-h200 (1 test): View workflow run

cd test/ && python3 registered/cp/test_deepseek_v32_cp_single_node.py

Comment thread python/sglang/srt/models/deepseek_v2.py Outdated
Comment thread python/sglang/srt/models/deepseek_nextn.py Outdated
Comment thread python/sglang/srt/layers/attention/nsa/utils.py
@Fridge003 (Collaborator, Author)

/rerun-test test_deepseek_v32_cp_single_node.py

@github-actions (Contributor)

8-gpu-h200 (1 test): View workflow run

cd test/ && python3 registered/cp/test_deepseek_v32_cp_single_node.py

prepare_context_parallel_metadata now sums prefix_len into kv_len_prev/next,
so _get_topk_ragged_with_cp must not add (seq_lens_cpu - extend_seq_lens_cpu)
again; otherwise get_index_k_continuous reads past the block table and
triggers cudaErrorIllegalAddress once a request hits the radix prefix cache.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
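The double-add described in this commit can be shown with toy numbers. This is a hypothetical arithmetic sketch of the bug, not the real indexer code; all values and the halving step are illustrative.

```python
# Toy illustration of the double-add bug: once the metadata already folds
# the cached prefix into kv_len_prev, re-adding it in the consumer walks
# past the block table. Numbers are hypothetical.
seq_lens_cpu = [1024]        # total tokens per request (prefix + extend)
extend_seq_lens_cpu = [256]  # tokens actually being extended this step
prefix_len = seq_lens_cpu[0] - extend_seq_lens_cpu[0]  # 768 cached tokens

# After the refactor, prepare_context_parallel_metadata sums prefix_len
# into kv_len_prev/next (the // 2 stands in for the CP split):
kv_len_prev = extend_seq_lens_cpu[0] // 2 + prefix_len  # 896

# Buggy consumer: re-adds the prefix a second time, overshooting the
# block table by 768 entries on every radix-prefix-cache hit.
kv_len_buggy = kv_len_prev + (seq_lens_cpu[0] - extend_seq_lens_cpu[0])

# Fixed consumer: uses the metadata value as-is.
kv_len_fixed = kv_len_prev

print(kv_len_buggy, kv_len_fixed)  # 1664 896
```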
Comment thread python/sglang/srt/layers/attention/nsa/nsa_indexer.py Outdated
prepare_context_parallel_metadata was folding prefix_len into the int
fields, guarded by `len(seqs_len) == 1`. When the scheduler packs
multiple requests into one CP-extend (which happens under
max_running_requests=32 + speculative_attention_mode='prefill'), that
guard falls back to prefix_len=0 and the cached-prefix offset is
silently dropped. Under the companion 93ccfc7 fix that removed the
in-function prefix-add, _get_topk_ragged_with_cp then indexes into an
extend-only K range and the indexer's ke_offset gets truncated on every
prefix-cache hit, tanking GSM8K to ~0.52.

Restore the pre-refactor contract: metadata stores extend-only offsets,
and _get_topk_ragged_with_cp re-adds the batch-0 cached prefix via
(seq_lens_cpu - extend_seq_lens_cpu) — matching the batch-0 scope of
block_tables[0]. Measured TestDeepseekV32CPInSeqSplit 0.970 and
TestDeepseekV32CPRoundRobinSplit 0.975 on gsm8k/200 (H200, dp=2 cp=4 /
dp=1 cp=8, EAGLE spec).

kv_len_prev_tensor / kv_len_next_tensor (1-D (1,) shape) stay as-is for
the non-NSA qwen FA cache_seqlens path.
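The restored contract above can be sketched in miniature: the metadata side stores extend-only offsets, and the consumer re-adds the batch-0 cached prefix itself, which makes the single-request guard's fallback harmless. Function names and numbers below are illustrative, not the real signatures.

```python
# Hypothetical sketch of the pre-refactor contract this commit restores:
# metadata carries extend-only offsets; the NSA consumer re-adds the
# batch-0 cached prefix from (seq_lens_cpu - extend_seq_lens_cpu).

def prepare_metadata(extend_lens, prefix_len):
    # Guard from the real code: prefix_len is only trusted for a
    # single-request CP-extend; packed batches fall back to 0. Under the
    # extend-only contract this fallback does not lose the prefix,
    # because the prefix is never baked into the metadata here anyway.
    _ = prefix_len if len(extend_lens) == 1 else 0
    return {"kv_len_prev": extend_lens[0]}  # extend-only, no prefix


def topk_ragged_kv_len(meta, seq_lens_cpu, extend_seq_lens_cpu):
    # Consumer re-adds the batch-0 cached prefix exactly once, matching
    # the batch-0 scope of block_tables[0].
    return meta["kv_len_prev"] + (seq_lens_cpu[0] - extend_seq_lens_cpu[0])


# Multi-request CP-extend packing: the guard falls back to prefix_len=0,
# yet the consumer still recovers the full batch-0 KV range.
meta = prepare_metadata(extend_lens=[256, 256], prefix_len=768)
kv_len = topk_ragged_kv_len(meta, seq_lens_cpu=[1024],
                            extend_seq_lens_cpu=[256])
print(kv_len)  # 1024: prefix counted exactly once
```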
Comment thread python/sglang/srt/layers/utils/cp_utils.py Outdated
…non-NSA CP

Per review: `_get_topk_ragged_with_cp` only runs on the NSA model path, so
the previous commit's unconditional revert to extend-only int fields broke
non-NSA CP (e.g. qwen3-moe), where FlashAttention consumes kv_len_prev
directly as cache_seqlens and needs the prefix baked in.

Split the two contracts:
  - NSA CP (is_nsa_enable_prefill_cp()): keep kv_len_prev/next as extend-only;
    `_get_topk_ragged_with_cp` still re-adds the cached-prefix offset from
    (seq_lens_cpu - extend_seq_lens_cpu), which also handles the multi-request
    CP-extend packing case where the `len(seqs_len) == 1` guard falls back to
    prefix_len=0.
  - Non-NSA CP: restore the pre-refactor behavior that bakes prefix_len into
    kv_len_prev/next, so flash_attn_with_kvcache sees the full cache_seqlens.

Measured on H200 (gsm8k, 200 examples):
  - TestDeepseekV32CPInSeqSplit      0.970  (NSA path)
  - TestQwen330B                     0.970  (non-NSA FA path, threshold 0.85)
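The two contracts described in this commit amount to a single branch on the NSA-prefill-CP predicate. The predicate name is taken from this PR; the helper below is a hypothetical sketch, not the actual cp_utils.py code.

```python
# Hypothetical sketch of the split contracts: NSA CP keeps kv_len_prev/
# next extend-only (the consumer re-adds the cached prefix), while
# non-NSA CP bakes the prefix in so flash_attn_with_kvcache sees the
# full cache_seqlens directly.

def finalize_kv_len(extend_kv_len, prefix_len, nsa_prefill_cp):
    if nsa_prefill_cp:
        # NSA path: extend-only; _get_topk_ragged_with_cp re-adds the
        # cached-prefix offset on the consumer side.
        return extend_kv_len
    # Non-NSA path (e.g. qwen3-moe): prefix baked into cache_seqlens.
    return extend_kv_len + prefix_len


print(finalize_kv_len(256, 768, nsa_prefill_cp=True))   # 256
print(finalize_kv_len(256, 768, nsa_prefill_cp=False))  # 1024
```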
@Fridge003 (Collaborator, Author)

/rerun-stage stage-c-test-8-gpu-h200

@Fridge003 (Collaborator, Author)

/rerun-stage stage-c-test-4-gpu-h100

@github-actions (Contributor)

✅ Triggered stage-c-test-8-gpu-h200 to run independently (skipping dependencies). View workflow run

@github-actions (Contributor)

✅ Triggered stage-c-test-4-gpu-h100 to run independently (skipping dependencies). View workflow run

@Fridge003 (Collaborator, Author)

/rerun-stage stage-c-test-deepep-8-gpu-h200

@github-actions (Contributor)

✅ Triggered stage-c-test-deepep-8-gpu-h200 to run independently (skipping dependencies). View workflow run

@Fridge003 Fridge003 merged commit c304d0d into main Apr 20, 2026
79 of 87 checks passed
@Fridge003 Fridge003 deleted the dedup-nsa-utils-cp-utils branch April 20, 2026 04:35
zhangying098 pushed a commit to zhangying098/sglang that referenced this pull request Apr 23, 2026
…llel (sgl-project#22914)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
kyx1999 pushed a commit to KMSorSMS/sglang that referenced this pull request Apr 27, 2026
…llel (sgl-project#22914)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@hnyls2002 hnyls2002 mentioned this pull request Apr 29, 2026