[Model Runner V2] Multiple prompt logprobs support #39937
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Code Review
This pull request enhances the prompt logprob computation logic to handle varying logprob request counts within a single batch. It introduces a mechanism to track the number of logprobs per request and utilizes the batch's maximum requested count during the chunked computation process. Feedback indicates that in mixed batches, requests currently receive the batch-wide maximum number of logprobs rather than their specific requested amount, which may cause assertion failures. A code suggestion was provided to slice the resulting tensors to match each request's individual requirements.
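The slicing suggested in the review can be illustrated with a small standalone sketch. It is not the code from this PR: it uses plain Python lists in place of tensors, collapses each request to a single vocabulary row, and the names `topk_logprobs`, `batched_prompt_logprobs`, `rows`, and `requested` are hypothetical.

```python
import heapq


def topk_logprobs(row, k):
    """Return the k largest (token_id, logprob) pairs of one vocabulary row."""
    return heapq.nlargest(k, enumerate(row), key=lambda pair: pair[1])


def batched_prompt_logprobs(rows, requested):
    """Compute top-k once with the batch-wide maximum, then trim per request.

    rows[i] is a hypothetical list of logprobs over the vocabulary for
    request i; requested[i] is how many prompt logprobs request i asked for.
    """
    max_k = max(requested)
    # One pass sized for the most demanding request in the batch.
    wide = [topk_logprobs(row, max_k) for row in rows]
    # Slice each result back down to what that request actually asked for.
    return [entries[:k] for entries, k in zip(wide, requested)]


rows = [
    [-0.1, -2.3, -1.2, -4.0],
    [-3.0, -0.5, -0.7, -2.2],
]
result = batched_prompt_logprobs(rows, requested=[1, 3])
```

Without the final slice, the first request would receive three entries even though it asked for one, which is the mismatch the review describes.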
njhill
left a comment
Thanks @yewentao256 this looks good to me, just minor simplifications
    @@ -17,13 +17,18 @@ def __init__(self, max_num_reqs: int):
             self.max_num_reqs = max_num_reqs

             self.uses_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
Do we still need this if we are introducing the counts?
We still need uses_prompt_logprobs because prompt_logprobs=0 is a valid enabled case.
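A minimal sketch of why the boolean flag cannot be derived from the counts alone. Plain lists stand in for the NumPy arrays, and the class name `PromptLogprobsState`, the field `num_prompt_logprobs`, and the method `add_request` are assumptions made for illustration, not names confirmed by the PR.

```python
class PromptLogprobsState:
    """Hypothetical per-request bookkeeping for prompt logprobs."""

    def __init__(self, max_num_reqs):
        self.max_num_reqs = max_num_reqs
        # Whether each request slot asked for prompt logprobs at all.
        self.uses_prompt_logprobs = [False] * max_num_reqs
        # How many prompt logprobs each slot asked for (assumed field name).
        self.num_prompt_logprobs = [0] * max_num_reqs

    def add_request(self, req_idx, prompt_logprobs):
        # None means disabled; 0 is a valid, enabled request.
        self.uses_prompt_logprobs[req_idx] = prompt_logprobs is not None
        self.num_prompt_logprobs[req_idx] = prompt_logprobs or 0


state = PromptLogprobsState(max_num_reqs=3)
state.add_request(0, None)  # prompt logprobs disabled
state.add_request(1, 0)     # enabled, zero requested
state.add_request(2, 2)     # enabled, two requested
```

Slots 0 and 1 both store a count of 0, so only the flag distinguishes a disabled request from one that set `prompt_logprobs=0`.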
Co-authored-by: Nick Hill <nickhill123@gmail.com> Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
njhill
left a comment
Thanks @yewentao256 just thought of one more small simplification
Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: Nick Hill <nickhill123@gmail.com> Co-authored-by: hongbolv <33214277+hongbolv@users.noreply.github.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: Nick Hill <nickhill123@gmail.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: Nick Hill <nickhill123@gmail.com> Signed-off-by: Yifan <yzong@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: Nick Hill <nickhill123@gmail.com> Signed-off-by: Avinash Singh <avinashsingh.rcoem@gmail.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: Nick Hill <nickhill123@gmail.com> Signed-off-by: Adrian <info@zzit.ch>
Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: Nick Hill <nickhill123@gmail.com> Signed-off-by: Matt Van Horn <455140+mvanhorn@users.noreply.github.com>
Purpose
Part of #39337
Multiple prompt logprobs support
Test
VLLM_USE_V2_MODEL_RUNNER=1 pytest tests/v1/sample/test_logprobs.py -k prompt_logprobs_with_chunking_and_preemption

Originally:

    __________________________________ test_prompt_logprobs_with_chunking_and_preemption ___________________________________

        def test_prompt_logprobs_with_chunking_and_preemption():
            """Test that prompt logprobs are correctly returned when using
            both chunked prefill and preemption.

            This test ensures that the num_prompt_logprobs tracking persists
            across preemptions and prefill chunks.
            """
            # Create prompts that will trigger chunking and preemption
            prompts = [
                "The following numbers of the sequence "
                + ", ".join(str(i) for i in range(10))
                + " are:",
                "In one word, the capital of France is ",
            ] + [f"Tell me about the number {i}: " for i in range(32)]

            sampling_params = SamplingParams(
                temperature=0.0,
                max_tokens=40,
                min_tokens=20,
                prompt_logprobs=2,  # Request prompt logprobs
            )

            with VllmRunner(
                "Qwen/Qwen3-0.6B",
                max_model_len=512,
                enable_chunked_prefill=True,
                max_num_batched_tokens=48,  # Force prefill chunking
                num_gpu_blocks_override=32,  # Force preemptions
                disable_log_stats=False,
                gpu_memory_utilization=0.25,
            ) as vllm_model:
                metrics_before = vllm_model.llm.get_metrics()

                # Generate with prompt logprobs using generate_w_logprobs which
                # returns (output_ids, output_str, output_logprobs, prompt_logprobs)
                outputs = vllm_model.generate_w_logprobs(
                    prompts, sampling_params=sampling_params, include_prompt_token_ids=True
                )

                # Verify that all outputs have prompt logprobs
                for i, output in enumerate(outputs):
                    _, _, _, prompt_token_ids, prompt_logprobs = output
                    assert prompt_logprobs is not None and len(prompt_logprobs) > 0, (
                        f"Output {i} missing prompt logprobs"
                    )
                    assert len(prompt_logprobs) == len(prompt_token_ids), (
                        "Unexpected number of prompt logprob positions"
                    )

                    # Each position should have the requested number of logprobs
                    for pos, logprobs_dict in enumerate(prompt_logprobs):
                        if logprobs_dict is not None:  # First token may be None
    >                       assert (
                                sampling_params.prompt_logprobs
                                <= len(logprobs_dict)
                                <= sampling_params.prompt_logprobs + 1
                            ), (
                                f"Output {i} position {pos} has {len(logprobs_dict)} "
                                f"logprobs, expected {sampling_params.prompt_logprobs}"
                            )
    E                       AssertionError: Output 0 position 1 has 1 logprobs, expected 2
    E                       assert 2 <= 1
    E                        +  where 2 = SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.0, top_p=1.0, t...mpt_logprobs=2, skip_special_tokens=True, spaces_between_special_tokens=True, structured_outputs=None, extra_args=None).prompt_logprobs
    E                        +  and   1 = len({2701: Logprob(logprob=-10.656400680541992, rank=5307, decoded_token=' following')})

    tests/v1/sample/test_logprobs.py:1216: AssertionError

Now:

    ======================= 1 passed, 52 deselected, 17 warnings in 14.63s =======================

CC @WoosukKwon