Skip to content

[ModelRunnerV2] Support prompt embeds#42963

Open
gcanlin wants to merge 6 commits into
vllm-project:mainfrom
gcanlin:prompt-embeds
Open

[ModelRunnerV2] Support prompt embeds#42963
gcanlin wants to merge 6 commits into
vllm-project:mainfrom
gcanlin:prompt-embeds

Conversation

@gcanlin

@gcanlin gcanlin commented May 18, 2026

Copy link
Copy Markdown
Contributor

Purpose

Support prompt embeds for ModelRunnerV2.

Test Plan

 VLLM_USE_V2_MODEL_RUNNER=1 pytest -sv   tests/basic_correctness/test_basic_correctness.py::test_models   -k "True-uni or True-mp"

Before

E       pydantic_core._pydantic_core.ValidationError: 1 validation error for VllmConfig
E         Value error, VLLM_USE_V2_MODEL_RUNNER does not yet support: prompt embeds [type=value_error, input_value=ArgsKwargs((), {'model_co... 'shutdown_timeout': 0}), input_type=ArgsKwargs]
E           For further information visit https://errors.pydantic.dev/2.13/v/value_error

vllm/engine/arg_utils.py:2171: ValidationError
============================================================ warnings summary ============================================================
<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

../.venv/lib/python3.12/site-packages/torch/jit/_script.py:365: 14 warnings
  /root/vllm-workspace/.venv/lib/python3.12/site-packages/torch/jit/_script.py:365: DeprecationWarning: `torch.jit.script_method` is deprecated. Please switch to `torch.compile` or `torch.export`.
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================================================== short test summary info =========================================================
FAILED tests/basic_correctness/test_basic_correctness.py::test_models[True-uni-True-False-5-FLASH_ATTN-hmellor/tiny-random-Gemma2ForCausalLM] - pydantic_core._pydantic_core.ValidationError: 1 validation error for VllmConfig
FAILED tests/basic_correctness/test_basic_correctness.py::test_models[True-uni-True-False-5-FLASH_ATTN-meta-llama/Llama-3.2-1B-Instruct] - pydantic_core._pydantic_core.ValidationError: 1 validation error for VllmConfig
FAILED tests/basic_correctness/test_basic_correctness.py::test_models[True-uni-False-False-5-FLASH_ATTN-hmellor/tiny-random-Gemma2ForCausalLM] - pydantic_core._pydantic_core.ValidationError: 1 validation error for VllmConfig
FAILED tests/basic_correctness/test_basic_correctness.py::test_models[True-uni-False-False-5-FLASH_ATTN-meta-llama/Llama-3.2-1B-Instruct] - pydantic_core._pydantic_core.ValidationError: 1 validation error for VllmConfig
FAILED tests/basic_correctness/test_basic_correctness.py::test_models[True-mp-True-False-5-FLASH_ATTN-hmellor/tiny-random-Gemma2ForCausalLM] - pydantic_core._pydantic_core.ValidationError: 1 validation error for VllmConfig
FAILED tests/basic_correctness/test_basic_correctness.py::test_models[True-mp-True-False-5-FLASH_ATTN-meta-llama/Llama-3.2-1B-Instruct] - pydantic_core._pydantic_core.ValidationError: 1 validation error for VllmConfig
FAILED tests/basic_correctness/test_basic_correctness.py::test_models[True-mp-False-False-5-FLASH_ATTN-hmellor/tiny-random-Gemma2ForCausalLM] - pydantic_core._pydantic_core.ValidationError: 1 validation error for VllmConfig
FAILED tests/basic_correctness/test_basic_correctness.py::test_models[True-mp-False-False-5-FLASH_ATTN-meta-llama/Llama-3.2-1B-Instruct] - pydantic_core._pydantic_core.ValidationError: 1 validation error for VllmConfig
======================================== 8 failed, 8 deselected, 16 warnings in 237.83s (0:03:57) ========================================
sys:1: DeprecationWarning: builtin type swigvarlink has no __module__ attribute

After

======================================== 8 passed, 8 deselected, 16 warnings in 443.73s (0:07:23) ========================================
sys:1: DeprecationWarning: builtin type swigvarlink has no __module__ attribute
vllm serve --enable-prompt-embeds
"""Smoke test for prompt_embeds over the OpenAI-compatible HTTP server.

Usage:
    # Terminal 1 (server): see vllm serve command in the chat.
    # Terminal 2:
    .venv/bin/python test_embeds_serve.py
"""

import io

import openai
import pybase64 as base64
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_PATH = "/root/.cache/modelscope/hub/models/Qwen/Qwen3-0___6B"
SERVED_NAME = "Qwen/Qwen3-0.6B"  # must match --served-model-name
BASE_URL = "http://localhost:8000/v1"

PROMPT = "The capital of France is"


def to_b64_embed(tensor: torch.Tensor) -> str:
    buf = io.BytesIO()
    # torch.save is what the server expects (it calls torch.load on the bytes).
    torch.save(tensor, buf)
    return base64.b64encode(buf.getvalue()).decode("utf-8")


def main() -> None:
    tok = AutoTokenizer.from_pretrained(MODEL_PATH)
    hf = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16)
    with torch.no_grad():
        ids = tok(PROMPT, return_tensors="pt").input_ids
        embeds = (
            hf.get_input_embeddings()(ids)
            .squeeze(0)
            .to(torch.bfloat16)
            .cpu()
            .contiguous()
        )
    del hf

    encoded = to_b64_embed(embeds)
    client = openai.OpenAI(base_url=BASE_URL, api_key="EMPTY")

    # Case 1: prompt_embeds only.
    out = client.completions.create(
        model=SERVED_NAME,
        prompt=None,  # leave empty so the server falls through to prompt_embeds
        max_tokens=16,
        temperature=0.0,
        extra_body={"prompt_embeds": encoded},
    )
    print(f"[prompt_embeds] {out.choices[0].text!r}")

    # Case 2: same prompt via text path, for sanity comparison.
    out_text = client.completions.create(
        model=SERVED_NAME,
        prompt=PROMPT,
        max_tokens=16,
        temperature=0.0,
    )
    print(f"[text       ] {out_text.choices[0].text!r}")

    if out.choices[0].text == out_text.choices[0].text:
        print("MATCH: prompt_embeds output equals text output")
    else:
        print("DIVERGE: outputs differ (expected only if tokenizer/embedding "
              "scaling differs from raw lookup)")


if __name__ == "__main__":
    main()

Test Result

 python test_embeds_serve.py
[transformers] `torch_dtype` is deprecated! Use `dtype` instead!
Loading weights: 100%|████████████████████████████████████████████████████████████████████████████████| 311/311 [00:00<00:00, 6652.97it/s]
[prompt_embeds] ' Paris. The capital of France is also the capital of the Republic of France.'
[text       ] ' Paris. The capital of France is also the capital of the Republic of France.'
MATCH: prompt_embeds output equals text output

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Signed-off-by: gcanlin <canlinguosdu@gmail.com>
@gcanlin

gcanlin commented May 18, 2026

Copy link
Copy Markdown
Contributor Author

@yewentao256 @njhill Could you please take a look? Thx!

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements support for prompt embeddings in the V1 GPU worker. Key changes include updating the model runner to handle prompt_embeds during request addition and prompt length calculation, and modifying the model state to store and apply these embeddings during the model execution phase. Additionally, the logic for preparing input embeddings was refactored to accommodate both multi-modal inputs and prompt embeddings. A high-severity issue was identified in the _remove_request method, where the current order of operations could lead to a race condition if a request index is reused by a concurrent process before its associated model state is fully cleared.

Comment thread vllm/v1/worker/gpu/model_runner.py Outdated
Signed-off-by: gcanlin <canlinguosdu@gmail.com>

@yewentao256 yewentao256 left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the work! Could you add description of what specific issue solved for this PR?

E.g

VLLM_USE_V2_MODEL_RUNNER=1 pytest tests/basic_correctness/test_cumem.py::test_deep_sleep

Originally

(EngineCore pid=2663266) ERROR 05-14 19:13:43 [core.py:1360]   File "/home/yewentao256/vllm-source/vllm/v1/executor/uniproc_executor.py", line 93, in collective_rpc
(EngineCore pid=2663266) ERROR 05-14 19:13:43 [core.py:1360]     result = run_method(self.driver_worker, method, args, kwargs)
(EngineCore pid=2663266) ERROR 05-14 19:13:43 [core.py:1360]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore pid=2663266) ERROR 05-14 19:13:43 [core.py:1360]   File "/home/yewentao256/vllm-source/vllm/v1/serial_utils.py", line 510, in run_method
(EngineCore pid=2663266) ERROR 05-14 19:13:43 [core.py:1360]     return func(*args, **kwargs)
(EngineCore pid=2663266) ERROR 05-14 19:13:43 [core.py:1360]            ^^^^^^^^^^^^^^^^^^^^^
(EngineCore pid=2663266) ERROR 05-14 19:13:43 [core.py:1360]   File "/home/yewentao256/vllm-source/vllm/v1/worker/gpu_worker.py", line 351, in reload_weights
(EngineCore pid=2663266) ERROR 05-14 19:13:43 [core.py:1360]     self.model_runner.reload_weights(*args, **kwargs)
(EngineCore pid=2663266) ERROR 05-14 19:13:43 [core.py:1360]     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore pid=2663266) ERROR 05-14 19:13:43 [core.py:1360] AttributeError: 'GPUModelRunner' object has no attribute 'reload_weights'
Now

======================================== 1 passed, 17 warnings in 45.19s =======================================

@gcanlin

gcanlin commented May 18, 2026

Copy link
Copy Markdown
Contributor Author

Thanks for the work! Could you add description of what specific issue solved for this PR?

E.g

VLLM_USE_V2_MODEL_RUNNER=1 pytest tests/basic_correctness/test_cumem.py::test_deep_sleep

Originally

(EngineCore pid=2663266) ERROR 05-14 19:13:43 [core.py:1360]   File "/home/yewentao256/vllm-source/vllm/v1/executor/uniproc_executor.py", line 93, in collective_rpc
(EngineCore pid=2663266) ERROR 05-14 19:13:43 [core.py:1360]     result = run_method(self.driver_worker, method, args, kwargs)
(EngineCore pid=2663266) ERROR 05-14 19:13:43 [core.py:1360]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore pid=2663266) ERROR 05-14 19:13:43 [core.py:1360]   File "/home/yewentao256/vllm-source/vllm/v1/serial_utils.py", line 510, in run_method
(EngineCore pid=2663266) ERROR 05-14 19:13:43 [core.py:1360]     return func(*args, **kwargs)
(EngineCore pid=2663266) ERROR 05-14 19:13:43 [core.py:1360]            ^^^^^^^^^^^^^^^^^^^^^
(EngineCore pid=2663266) ERROR 05-14 19:13:43 [core.py:1360]   File "/home/yewentao256/vllm-source/vllm/v1/worker/gpu_worker.py", line 351, in reload_weights
(EngineCore pid=2663266) ERROR 05-14 19:13:43 [core.py:1360]     self.model_runner.reload_weights(*args, **kwargs)
(EngineCore pid=2663266) ERROR 05-14 19:13:43 [core.py:1360]     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore pid=2663266) ERROR 05-14 19:13:43 [core.py:1360] AttributeError: 'GPUModelRunner' object has no attribute 'reload_weights'
Now

======================================== 1 passed, 17 warnings in 45.19s =======================================

Sure. Done now.

@yewentao256 yewentao256 left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the work! Could you also simplify the PR diff?

Comment thread vllm/v1/worker/gpu/model_states/default.py
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Signed-off-by: Canlin Guo <961750412@qq.com>
@gcanlin

gcanlin commented May 19, 2026

Copy link
Copy Markdown
Contributor Author

Thanks for the work! Could you also simplify the PR diff?

Sure. Thanks for cleaning!

Signed-off-by: gcanlin <canlinguosdu@gmail.com>

@gcanlin gcanlin May 19, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would renaming this method to prepare_inputs_embeds be better so that we can have less intrusive code for model runner backbone? Then this method will include mm embed and prompt embed. cc @WoosukKwon

Signed-off-by: gcanlin <canlinguosdu@gmail.com>

@yewentao256 yewentao256 left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

            request = self.requests[req_id]
            if request.prompt_token_ids is None:
                # Prompt logprobs is incompatible with prompt embeddings
                continue

We have this in v1, should we add it as well?

@njhill

njhill commented May 20, 2026

Copy link
Copy Markdown
Member

Thanks @gcanlin

But are we sure we want to support this in MRV2? We should not just port over everything blindly, we would like to deprecate as much as possible.

@njhill njhill added the v2 label May 20, 2026
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
@gcanlin

gcanlin commented May 21, 2026

Copy link
Copy Markdown
Contributor Author
            request = self.requests[req_id]
            if request.prompt_token_ids is None:
                # Prompt logprobs is incompatible with prompt embeddings
                continue

We have this in v1, should we add it as well?

Yes. I add the guard before add_request.

@gcanlin

gcanlin commented May 21, 2026

Copy link
Copy Markdown
Contributor Author

Thanks @gcanlin

But are we sure we want to support this in MRV2? We should not just port over everything blindly, we would like to deprecate as much as possible.

Agree. But I'm not sure how to decide whether any feature is needed. Do we have any target or plan for MRV2?

@mergify

mergify Bot commented Jun 3, 2026

Copy link
Copy Markdown
Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @gcanlin.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Jun 3, 2026
@qthequartermasterman

Copy link
Copy Markdown
Contributor

Thanks @gcanlin

But are we sure we want to support this in MRV2? We should not just port over everything blindly, we would like to deprecate as much as possible.

@njhill I for one would be greatly disappointed if prompt_embeds support was deprecated. My business is built on it, and if it were removed, I'd have to fork vLLM to continue doing business.

It has a variety of use cases. The most common is to train custom MM encoders for modalities not supported natively in vLLM or for models that don't have natively trained MM encoders. A classic example is training a vision encoder that outputs in the token embedding space of a pure-text model, like Nemotron 3 Super. A user can encode their images directly to prompt embeddings, and send those embeddings either alongside text embeddings or as a prompt_embeds part in the /v1/chat/completions endpoint. I've seen various people use this feature to create things like audio encoders for Llama (in lieu of a separate speech-to-text step that would lose semantic details like tone), graph encoders for Mistral, etc... I don't have links unfortunately for these examples, because they've come from offline discussions with several people.

Another use case I've seen is using prompt embeds to compress the number of tokens in a request. With a clever encoder you can compress dozens of text tokens into a single prompt embed.

My particular use case creates a "privacy encoder" (in some sense). We train an encoder that takes in text and outputs a sequence of prompt embeddings that the language model still natively understands, but do not correspond to text in any meaningfully reversible sense without that target model. https://protopia.ai/stained-glass-transform/

People do use this feature, even if it's not super common, evidenced by the issues and PRs that open around it every so often.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants