[V1][Spec Decode] Fix greedy temperature detection after sampler refactor#27077
Conversation
There was a problem hiding this comment.
Code Review
This pull request addresses a critical bug in speculative decoding that could lead to division-by-zero errors. The issue stemmed from an incomplete migration of the greedy sampling temperature representation from -1.0 to 0.0 in a previous refactor. Your changes correctly propagate this update across all affected components:
- In
rejection_sampler.py,GREEDY_TEMPERATUREis correctly updated to0. - In
eagle.py, the check for greedy sampling is updated to use an epsilon comparison, which is more robust and consistent with the rest of the codebase. - In
tpu_input_batch.py, the greedy temperature is set to0.0, aligning it with the GPU implementation.
These changes are well-targeted, correct, and crucial for fixing the reported crashes and incorrect behavior. The pull request is well-documented and the solution is sound.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR.
c5669f8 to
bacd60b
Compare
|
Thank you for the detailed review! You're absolutely correct - I've now fixed the TPU sampler division by zero issue. Changes MadeUpdated the commit to include:
This follows the same pattern as The fix ensures that:
All ruff checks pass. Let me know if you'd like any other changes! |
fb38090 to
3afb34a
Compare
njhill
left a comment
There was a problem hiding this comment.
Thanks @Pradyun92 for catching and fixing this. The fixes look great.
Could you rebase on latest main? That should hopefully fix the doc build error.
…ctor ## Issue Commit a676e66 changed greedy sampling temperature from -1.0 to 0.0, but several components checking for temperature == -1 were not updated: - TPU input batch (tpu_input_batch.py) still using -1.0 - Eagle's unused function (eagle.py) with hardcoded temperature == -1 - Rejection sampler (rejection_sampler.py) GREEDY_TEMPERATURE constant = -1 - TPU sampler (tpu/sampler.py) missing epsilon guard causing division by zero ## Impact Affects ALL models using speculative decoding with rejection sampling (Eagle/Eagle3, deepseek_mtp, ernie_mtp, qwen3_next_mtp, longcat_flash_mtp): - Division by zero in rejection sampler and TPU sampler - NaN/Inf logits - Empty generated_token_ids - Negative acceptance values (len([]) - 1 = -1) - Prometheus metrics crash ## Solution Complete the temperature migration consistently: 1. tpu_input_batch.py: Change temperature -1.0 → 0.0 2. eagle.py: Use epsilon comparison (temperature < _SAMPLING_EPS) 3. rejection_sampler.py: Change GREEDY_TEMPERATURE constant -1 → 0 4. tpu/sampler.py: Add epsilon guard before division 5. tpu/metadata.py: Add all_random property for sampler Signed-off-by: Pradyun Ramadorai <pradyunr@amazon.com>
Issue: TPU sampler and Eagle code had two separate but related issues: 1. TPU sampler divides by zero for greedy requests (temperature=0.0) 2. Eagle code triggers mypy type errors due to missing None check Root Cause: - TPU sampler's apply_temperature() method lacks epsilon guard to prevent division by zero when temperature=0.0 (greedy sampling) - Eagle's compute_probs_and_sample_next_token() uses temperature without asserting it's not None, causing mypy type errors Impact: - TPU: Division by zero produces NaN/Inf logits, breaking speculative decoding on TPU platforms for all models using Eagle/rejection sampling - Eagle: mypy type checking failures prevent pre-commit hooks from passing Fix: 1. TPU Sampler (vllm/v1/sample/tpu/sampler.py): - Add all_random parameter to apply_temperature() method - Add epsilon guard: if not all_random: temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp) - Update call site to pass sampling_metadata.all_random 2. TPU Metadata (vllm/v1/sample/tpu/metadata.py): - Add all_random property to TPUSupportedSamplingMetadata - Populate all_random from input_batch in from_input_batch() 3. Eagle (vllm/v1/spec_decode/eagle.py): - Add assert sampling_metadata.temperature is not None after all_greedy early return - Matches sampler.py pattern (line 162) for type safety Files Modified: - vllm/v1/sample/tpu/sampler.py: Epsilon guard in apply_temperature() - vllm/v1/sample/tpu/metadata.py: Added all_random property - vllm/v1/spec_decode/eagle.py: Added temperature None assertion - CLAUDE.md: Updated modification vllm-project#11 to document fixes This addresses PR vllm-project#27077 reviewer feedback and resolves mypy type errors. Signed-off-by: Pradyun Ramadorai <pradyunr@amazon.com>
3afb34a to
fbf572f
Compare
|
Side note is that we clearly have some CI gaps here. |
…ctor (vllm-project#27077) Signed-off-by: Pradyun Ramadorai <pradyunr@amazon.com> Co-authored-by: Pradyun Ramadorai <pradyunr@amazon.com>
…ctor (vllm-project#27077) Signed-off-by: Pradyun Ramadorai <pradyunr@amazon.com> Co-authored-by: Pradyun Ramadorai <pradyunr@amazon.com>
…ctor (vllm-project#27077) Signed-off-by: Pradyun Ramadorai <pradyunr@amazon.com> Co-authored-by: Pradyun Ramadorai <pradyunr@amazon.com>
…ctor (vllm-project#27077) Signed-off-by: Pradyun Ramadorai <pradyunr@amazon.com> Co-authored-by: Pradyun Ramadorai <pradyunr@amazon.com> Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>
…ctor (vllm-project#27077) Signed-off-by: Pradyun Ramadorai <pradyunr@amazon.com> Co-authored-by: Pradyun Ramadorai <pradyunr@amazon.com> Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>
…ctor (vllm-project#27077) Signed-off-by: Pradyun Ramadorai <pradyunr@amazon.com> Co-authored-by: Pradyun Ramadorai <pradyunr@amazon.com>
…ctor (vllm-project#27077) Signed-off-by: Pradyun Ramadorai <pradyunr@amazon.com> Co-authored-by: Pradyun Ramadorai <pradyunr@amazon.com>
…ctor (vllm-project#27077) Signed-off-by: Pradyun Ramadorai <pradyunr@amazon.com> Co-authored-by: Pradyun Ramadorai <pradyunr@amazon.com>
### What this PR does / why we need it? fix greedy temperature detection from vllm-project/vllm#27077 - vLLM version: release/v0.13.0 - vLLM main: vllm-project/vllm@81786c8 --------- Signed-off-by: realliujiaxu <realliujiaxu@163.com>
### What this PR does / why we need it? fix greedy temperature detection from vllm-project/vllm#27077 - vLLM version: release/v0.13.0 - vLLM main: vllm-project/vllm@81786c8 --------- Signed-off-by: realliujiaxu <realliujiaxu@163.com> Signed-off-by: zrj026 <zhangrunjiang026@gmail.com>
### What this PR does / why we need it? fix greedy temperature detection from vllm-project/vllm#27077 - vLLM version: release/v0.13.0 - vLLM main: vllm-project/vllm@81786c8 --------- Signed-off-by: realliujiaxu <realliujiaxu@163.com>
### What this PR does / why we need it? fix greedy temperature detection from vllm-project/vllm#27077 - vLLM version: release/v0.13.0 - vLLM main: vllm-project/vllm@81786c8 --------- Signed-off-by: realliujiaxu <realliujiaxu@163.com> Signed-off-by: zrj026 <zhangrunjiang026@gmail.com>
Description
Fixes division by zero and negative acceptance values in speculative decoding caused by incomplete temperature migration in commit a676e66.
Issue
Commit a676e66 changed greedy sampling temperature from
-1.0to0.0ingpu_input_batch.py, but several components still checking fortemperature == -1were not updated, causing:generated_token_idslen([]) - 1 = -1)Root Cause
Multiple components were not updated after the temperature refactor:
tpu_input_batch.py): Still using-1.0sentinel valueeagle.py): Hardcodedtemperature == -1check in unused functionrejection_sampler.py):GREEDY_TEMPERATUREconstant set to-1The rejection sampler's
expand_batch_to_tokens()function replacesGREEDY_TEMPERATURE(-1) with 1.0 before division, but when actual temperature is 0.0, the replacement doesn't happen, causing division by zero.Impact
Affects ALL models using speculative decoding with rejection sampling:
Solution
Complete the temperature migration consistently across all components:
temperature = -1.0→0.0for consistency with GPU_SAMPLING_EPSfromsampler.pytemperature < _SAMPLING_EPSinstead of== -1if not all_random:check before protection (consistent withsampler.py)GREEDY_TEMPERATUREconstant from-1→0Testing
Related