[Data/LLM] vLLM model files are downloaded to disk even when "load_format": "runai_streamer" is specified in engine_kwargs #55574
Description
What happened + What you expected to happen
According to the documentation, runai_streamer should be supported as a load format so that model weights are streamed from S3 directly into memory. However, when we launch vLLM engines following Ray's documentation (https://docs.ray.io/en/latest/data/working-with-llms.html) with runai_streamer, the model is still downloaded to disk, which adds significant startup overhead.
The code that forces the model download, without checking whether a streaming load format was specified, is here:
ray/python/ray/llm/_internal/batch/stages/vllm_engine_stage.py
Lines 465 to 470 in 5f62485

```python
# Download the model if needed.
model_source = download_model_files(
    model_id=self.model,
    mirror_config=None,
    download_model=NodeModelDownloadable.MODEL_AND_TOKENIZER,
    download_extra_files=False,
```
I propose checking engine_kwargs for a load_format key: if it is set to "runai_streamer", we should skip downloading the safetensor files (we probably still need to download config.json and the other metadata files).
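A minimal sketch of what that check could look like. The helper name `should_download_weights` and the `STREAMING_LOAD_FORMATS` set are hypothetical, not part of Ray; the actual fix would also need a way to fetch only the tokenizer/config files (the existing `NodeModelDownloadable` enum only shows `MODEL_AND_TOKENIZER` in the snippet above, so a weights-skipping mode is an assumption here):

```python
# Hypothetical helper: decide whether to download weights to disk based on
# the user-supplied engine_kwargs. With a streaming load format such as
# "runai_streamer", vLLM reads the safetensors directly from object storage,
# so only metadata files (config.json, tokenizer files, ...) are needed locally.
STREAMING_LOAD_FORMATS = {"runai_streamer"}


def should_download_weights(engine_kwargs: dict) -> bool:
    """Return False when the engine will stream weights itself."""
    return engine_kwargs.get("load_format") not in STREAMING_LOAD_FORMATS


# Inside vllm_engine_stage.py, the download call could then branch on this,
# e.g. (pseudocode, exact enum members are an assumption):
#   download_model = (NodeModelDownloadable.MODEL_AND_TOKENIZER
#                     if should_download_weights(self.engine_kwargs)
#                     else <tokenizer/config-only mode>)
```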
Versions / Dependencies
ray version 2.48.0
python 3.11
Using the official docker image from dockerhub.
Reproduction script
```python
import ray
from ray.data.llm import vLLMEngineProcessorConfig, build_llm_processor
import numpy as np

config = vLLMEngineProcessorConfig(
    model_source="s3://path/to/model/",
    engine_kwargs={
        "enable_chunked_prefill": True,
        "max_num_batched_tokens": 4096,
        "max_model_len": 16384,
        "load_format": "runai_streamer",
    },
    concurrency=1,
    batch_size=64,
)
processor = build_llm_processor(
    config,
    preprocess=lambda row: dict(
        messages=[
            {"role": "system", "content": "You are a bot that responds with haikus."},
            {"role": "user", "content": row["item"]},
        ],
        sampling_params=dict(
            temperature=0.3,
            max_tokens=250,
        ),
    ),
    postprocess=lambda row: dict(
        answer=row["generated_text"],
        **row,  # This will return all the original columns in the dataset.
    ),
)

ds = ray.data.from_items(["Start of the haiku is: Complete this for me..."])
ds = processor(ds)
ds.show(limit=1)
```
Issue Severity
Medium: It is a significant difficulty but I can work around it.