
Support batch size > 1 when enable CP #23269

Open

Shunkangz wants to merge 5 commits into sgl-project:main from Shunkangz:cp_multi_batch

Conversation

@Shunkangz (Contributor) commented Apr 20, 2026

Motivation

Enable batch size > 1 with context parallelism (CP).

Modifications

The main modification is to the context_parallel_metadata for attention, which now has to describe every request in the batch rather than a single sequence.
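
A minimal, hypothetical sketch of the shape this per-batch metadata might take (the class, field, and function names below are illustrative, not the actual sglang API): with bs > 1, each request's extend tokens are sharded across CP ranks, so the metadata has to track per-request local lengths rather than a single scalar.

    from dataclasses import dataclass
    from typing import List

    # Hypothetical structure for illustration; not the actual sglang code.
    @dataclass
    class ContextParallelMetadata:
        cp_rank: int
        cp_size: int
        local_extend_lens: List[int]  # tokens this CP rank owns, per request

    def build_cp_metadata(
        extend_lens: List[int], cp_rank: int, cp_size: int
    ) -> ContextParallelMetadata:
        # Even split per request; a real implementation must also handle the
        # zigzag layout discussed later in this thread.
        local = [
            l // cp_size + (1 if cp_rank < l % cp_size else 0)
            for l in extend_lens
        ]
        return ContextParallelMetadata(cp_rank, cp_size, local)

    # Example: a batch of two requests split across 4 CP ranks.
    meta = build_cp_metadata(extend_lens=[10, 7], cp_rank=1, cp_size=4)
    print(meta.local_extend_lens)  # [3, 2]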

Accuracy Tests

Speed Tests and Profiling

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.


@Shunkangz (Contributor, Author) commented:

/tag-run-ci-label

@Shunkangz (Contributor, Author) commented:

/tag-and-rerun-ci

@Shunkangz (Contributor, Author) commented:

/rerun-failed-ci

@kpham-sgl (Collaborator) left a review comment:

@Shunkangz Thank you for the contribution; I left some reviews. There will be another round of more thorough review. We can discuss in the Slack channel as well.
A couple of things I want to call out:

  • During development of MLA CP, I found two notable bugs in the current MHA CP implementation. The first is

        # MLA/MHA CP: prepare_mlp_sync_batch pads extend tokens up to
        # lcm(attn_tp_size, attn_cp_size), so cache_seqlens_cp can exceed
        # seq_lens_cpu.max(). Widen page_table by the pad delta to keep
        # FA3's causal reads in-bounds; widened columns index KV slot 0
        # (req_to_token is zero-init) and outputs for padding queries are
        # discarded downstream.
        if (
            self.attn_cp_size > 1
            and forward_batch.global_num_tokens_cpu is not None
            and forward_batch.extend_num_tokens is not None
            and forward_batch.extend_seq_lens_cpu is not None
        ):
            padded_extend = int(forward_batch.extend_num_tokens)
            real_extend = int(sum(forward_batch.extend_seq_lens_cpu))
            pad_delta = padded_extend - real_extend
            if pad_delta > 0:
                metadata.max_seq_len_k += pad_delta

    and the second is

        # Derive prefix offset from unpadded CPU tensors. Both `seqs_len` and
        # `extend_lens` are unpadded by the caller. Using the padded `kv_len`
        # here would undercount `prefix_len` by the padding amount and shift
        # the FA causal horizon.

    (a worked numeric sketch of the first issue follows this list)

  • We need an MHA CP zigzag-mode test for attn_cp_size == 4 and bs > 1.
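
A worked numeric sketch of the pad-delta issue above, using made-up sizes (all values are illustrative, and the sketch assumes the batch extend count is rounded up to the next multiple of the lcm, per the quoted comment):

    import math

    # Illustrative values only; the real sizes come from the server config.
    attn_tp_size, attn_cp_size = 2, 4
    extend_seq_lens_cpu = [5, 3, 3]  # unpadded per-request extend lengths

    # prepare_mlp_sync_batch pads the total extend-token count up to a
    # multiple of lcm(attn_tp_size, attn_cp_size) (assumed rounding rule).
    granularity = math.lcm(attn_tp_size, attn_cp_size)                   # 4
    real_extend = sum(extend_seq_lens_cpu)                               # 11
    padded_extend = math.ceil(real_extend / granularity) * granularity   # 12
    pad_delta = padded_extend - real_extend                              # 1

    # The quoted fix widens the attention metadata by this delta so FA3's
    # causal reads for the padding queries stay in-bounds:
    #     metadata.max_seq_len_k += pad_delta

    # The second issue, sketched: prefix offsets must come from the
    # *unpadded* per-request tensors, never from the padded batch total.
    seqs_len = [9, 6, 7]  # unpadded total KV length per request
    prefix_lens = [s - e for s, e in zip(seqs_len, extend_seq_lens_cpu)]  # [4, 3, 4]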

@Shunkangz (Contributor, Author) commented:

Thank you for pointing this out. In this PR, I only want to support bs > 1 with GQA models. I left the MLA CP part as the original implementation. I believe MLA CP should be refactored and aligned with our existing logic, such as the args, the layer communicator, and so on.

@kpham-sgl (Collaborator) commented:

Ah, sorry, I should have been clearer here:

  1. is an issue about the padding in prepare_mlp_sync_batch causing cache_seqlens_cp to go over seq_lens_cpu.max(). This bug also occurs for MHA CP and will likely impact this PR.
  2. is an issue about the padded kv_len messing up the metadata computation. I think we already agree on this: Support batch size > 1 when enable CP #23269 (comment)

@Shunkangz (Contributor, Author) commented:

For 1, let's discuss the details over Slack. For 2, I think the existing TestContextParallelMetadata already covers this. Can you confirm?

@kpham-sgl (Collaborator) commented May 6, 2026:

Yes, let's discuss 1 further in Slack tomorrow. For 2, sorry, what test is this?

@Shunkangz (Contributor, Author) commented:

  • We need an MHA CP zigzag-mode test for attn_cp_size == 4 and bs > 1.

I mean that TestContextParallelMetadata might already cover this.
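
For reference, a purely hypothetical sketch of the kind of assertion such a metadata test could make for the padded-kv_len issue (the actual contents of TestContextParallelMetadata are not shown in this thread; all names and values below are illustrative):

    import math
    import unittest

    class TestPrefixFromUnpaddedLens(unittest.TestCase):
        # Illustrative only, NOT the real TestContextParallelMetadata: it just
        # demonstrates the invariant under discussion, i.e. prefix offsets are
        # derived from unpadded per-request lengths even when the batch-level
        # extend count is padded up to a multiple of the lcm.
        def test_prefix_ignores_batch_padding(self):
            seqs_len = [9, 6, 7]     # unpadded total KV length per request
            extend_lens = [5, 3, 3]  # unpadded extend length per request

            # Batch-level padding with hypothetical parallel sizes.
            granularity = math.lcm(2, 4)
            real_extend = sum(extend_lens)
            padded_extend = math.ceil(real_extend / granularity) * granularity
            self.assertGreater(padded_extend, real_extend)  # padding occurred

            # Prefix offsets must come from the unpadded tensors.
            prefix_lens = [s - e for s, e in zip(seqs_len, extend_lens)]
            self.assertEqual(prefix_lens, [4, 3, 4])

    if __name__ == "__main__":
        unittest.main()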

@Shunkangz (Contributor, Author) commented:

/rerun-failed-ci

@Shunkangz (Contributor, Author) commented:

/rerun-failed-ci
