
[BugFix] Fix assert batch_descriptor.num_tokens == num_tokens_padded#30173

Merged
tlrmchlsmth merged 4 commits into vllm-project:main from neuralmagic:lwilkinson/fix-dp-assert
Dec 9, 2025

Conversation

@LucasWilkinson
Collaborator

@LucasWilkinson LucasWilkinson commented Dec 6, 2025

There's a bug with FULL_DECODE_ONLY (FULL_AND_PIECEWISE works fine) with DP. There is an edge case where one rank runs eager while all the other ranks want to run with cudagraphs, so we now synchronize the cudagraph mode each rank wants to run across all ranks. Since PIECEWISE can currently be treated as eager (it is valid in all the same situations), it is sufficient to disable full cudagraphs unless all ranks want them; we may want to pass an explicit list of valid modes in the future.

FIXES: #28579 (comment)
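The core idea above can be sketched as a min-reduce over the modes the ranks request. This is a hedged illustration, not vLLM's actual code: the `CUDAGraphMode` values and `sync_cudagraph_mode` helper are assumptions, chosen so that a lower value is the more conservative mode (the real synchronization happens inside an all-reduce across DP ranks).

```python
from enum import IntEnum

# Assumed mode ordering: lower value = more conservative execution.
class CUDAGraphMode(IntEnum):
    NONE = 0       # eager
    PIECEWISE = 1
    FULL = 2

def sync_cudagraph_mode(per_rank_modes: list[CUDAGraphMode]) -> CUDAGraphMode:
    """Stand-in for the cross-rank all-reduce: every rank ends up with the
    minimum (most conservative) mode that any rank requested."""
    return min(per_rank_modes)

# The edge case from the PR: one rank must run eager while the others
# want full cudagraphs.
modes = [CUDAGraphMode.NONE, CUDAGraphMode.FULL, CUDAGraphMode.FULL]
synced = sync_cudagraph_mode(modes)

# Because PIECEWISE is valid wherever eager is, it is enough to disable
# FULL whenever the synced mode is at or below PIECEWISE.
disable_full = synced <= CUDAGraphMode.PIECEWISE
```

With this rule, a single eager rank pulls every rank down to a mode that is safe everywhere, which is exactly why the assert no longer fires.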

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


-    return (should_ubatch, num_tokens_after_padding)
+    return (should_ubatch, num_tokens_after_padding, synced_cudagraph_mode)

P1: Align coordinate_batch_across_dp unpacking with new return

coordinate_batch_across_dp now returns three values including the synchronized cudagraph mode (return at line 240), but callers such as set_forward_context in forward_context.py (around lines 295–300) and eagle._pad_batch_across_dp in v1/spec_decode/eagle.py (around lines 1261–1269) still unpack only two items. In multi-DP runs where these paths invoke coordinate_batch_across_dp, Python will raise ValueError: too many values to unpack before padding or execution begins, breaking DP execution for forward contexts and EAGLE. Callers need to accept the third element or the function must preserve the previous 2-tuple interface.

Useful? React with 👍 / 👎.
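The failure mode Codex describes can be shown with a minimal stand-in. The function below is a hypothetical sketch that only mirrors the shapes involved (the names follow the PR, but the bodies and values are illustrative, not vLLM's):

```python
# Illustrative stand-in for the function whose return grew from 2 to 3 items.
def coordinate_batch_across_dp():
    should_ubatch = False
    num_tokens_after_padding = 128          # placeholder value
    synced_cudagraph_mode = 1               # placeholder value
    # After the PR, the synced cudagraph mode is appended to the return.
    return (should_ubatch, num_tokens_after_padding, synced_cudagraph_mode)

# A stale call site that still unpacks two values fails immediately:
try:
    should_ubatch, num_tokens = coordinate_batch_across_dp()
except ValueError:
    pass  # "too many values to unpack (expected 2)"

# The fix is for each caller to accept the third element:
should_ubatch, num_tokens, synced_mode = coordinate_batch_across_dp()
```

Alternatively, the previous 2-tuple interface could be preserved and the mode exposed separately, but updating the unpack at each call site is the smaller change.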

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses a bug related to CUDA graph mode synchronization in a data parallel setup. The changes ensure that all ranks agree on a common CUDA graph mode by communicating their preferred mode and selecting the minimum one. This prevents assertion failures when different ranks attempt to use incompatible modes (e.g., one eager, others with cudagraphs). The implementation correctly passes the cudagraph mode during the all-reduce synchronization and uses the synchronized mode to make dispatching decisions. The logic appears sound and effectively fixes the described issue. I have one suggestion to improve code maintainability by replacing magic numbers with named constants.

Comment on lines +49 to +54
tensor = torch.zeros(5, dp_size, device=device, dtype=torch.int32)
tensor[0][dp_rank] = orig_num_tokens_per_ubatch
tensor[1][dp_rank] = padded_num_tokens_per_ubatch
tensor[2][dp_rank] = 1 if should_ubatch else 0
tensor[3][dp_rank] = 1 if should_dp_pad else 0
tensor[4][dp_rank] = cudagraph_mode
Contributor


Severity: high

The use of magic numbers 0, 1, 2, 3, 4 for indexing into the tensor makes the code hard to read and maintain. It's not immediately clear what each index represents without looking at the surrounding code or comments. This pattern is also present in _post_process_cudagraph_mode with tensor[4, :]. This can lead to bugs if the order or size of the tensor changes.

I recommend defining these indices as constants at the module level, for example, using an Enum. This would make the code self-documenting and less error-prone across all functions that use this tensor (_run_ar, _post_process_ubatch, _post_process_dp_padding, _post_process_cudagraph_mode).

For example:

from enum import IntEnum

class DPSync(IntEnum):
    ORIG_NUM_TOKENS_PER_UBATCH = 0
    PADDED_NUM_TOKENS_PER_UBATCH = 1
    SHOULD_UBATCH = 2
    SHOULD_DP_PAD = 3
    CUDAGRAPH_MODE = 4
    TENSOR_SIZE = 5

Then you could use tensor[DPSync.CUDAGRAPH_MODE] instead of tensor[4].
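A dependency-free version of that suggestion can be run directly; here nested lists stand in for the torch tensor, and the `dp_size`, `dp_rank`, and mode value are made-up example numbers:

```python
from enum import IntEnum

# Named indices for the DP sync tensor rows, per the review suggestion.
class DPSync(IntEnum):
    ORIG_NUM_TOKENS_PER_UBATCH = 0
    PADDED_NUM_TOKENS_PER_UBATCH = 1
    SHOULD_UBATCH = 2
    SHOULD_DP_PAD = 3
    CUDAGRAPH_MODE = 4
    TENSOR_SIZE = 5

dp_size, dp_rank = 4, 2                     # example values
tensor = [[0] * dp_size for _ in range(DPSync.TENSOR_SIZE)]

# IntEnum members are ints, so they index like the magic numbers did,
# but the row's meaning is now visible at the call site.
tensor[DPSync.CUDAGRAPH_MODE][dp_rank] = 2  # e.g. FULL
```

Since `IntEnum` members compare and index as plain ints, this is a drop-in replacement for `tensor[4]` with no behavioral change.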

@LucasWilkinson LucasWilkinson added the ready ONLY add when PR is ready to merge/full CI is needed label Dec 6, 2025
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
@hjjq
Contributor

hjjq commented Dec 8, 2025

Thanks @LucasWilkinson ! The error is gone for me.

Contributor

@SageMoore SageMoore left a comment


Thanks for the fix, @LucasWilkinson!

Member

@tlrmchlsmth tlrmchlsmth left a comment


makes sense, LGTM

@github-project-automation github-project-automation bot moved this to In review in NVIDIA Dec 9, 2025
@tlrmchlsmth tlrmchlsmth merged commit 56037df into vllm-project:main Dec 9, 2025
47 checks passed
@github-project-automation github-project-automation bot moved this from In review to Done in NVIDIA Dec 9, 2025
yiz-liu pushed a commit to vllm-project/vllm-ascend that referenced this pull request Jan 20, 2026
…es in dp ranks (#6011)

### What this PR does / why we need it?
This PR aims to fix the issue that using A2 + AIV will hang due to the
fact that HCCL does not support eager/graph mode communication. To
handle it, following vllm-project/vllm#30173, we
introduce `synced_cudagraph_mode` to enable all ranks to know the
minimum mode across ranks. Main changes are described below:
1. `execute_model` now performs "dispatch -> sync -> re-dispatch" just
as `_dummy_run`
2. `_sync_metadata_across_dp` now receives `cudagraph_mode` from all
ranks and returns `synced_cudagraph_mode` to all ranks
3. Re-dispatch steps in both `execute_model` and `_dummy_run` include
`disable_full=synced_cudagraph_mode <= CUDAGraphMode.PIECEWISE.value` so
that when it is true, no FULL will be dispatched

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
by ci

---------

Signed-off-by: Zetong Li <slippersss@126.com>
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
vllm-project#30173)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
starmountain1997 pushed a commit to starmountain1997/vllm-ascend that referenced this pull request Jan 31, 2026
…es in dp ranks (vllm-project#6011)


Labels

nvidia ready ONLY add when PR is ready to merge/full CI is needed speculative-decoding v1

Projects

Status: Done


4 participants