[VLM] Chunk-aware ViT encoding with per-image cache and lazy device transfer #22038
yhyang201 merged 1 commit into sgl-project:main from
Conversation
/tag-and-rerun-ci
Code Review
This pull request refactors the multimodal embedding logic to improve efficiency through per-image chunk-aware encoding and centralized device management. Key changes include the introduction of _get_chunked_embedding_by_item for optimized caching and the removal of redundant device transfer logic across various model implementations like Qwen3-VL and DeepSeek-VL2. Feedback highlights that _move_items_to_device should be updated to handle numpy.ndarray features to prevent downstream failures, and a safety check is needed for item.offsets to avoid potential TypeError exceptions.
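The per-image caching idea summarized above can be sketched with a toy cache keyed per item. `FakeItem`, `PerImageEmbeddingCache`, and `get_embedding` below are hypothetical stand-ins for sglang's per-item hashing and `_get_chunked_embedding_by_item`, not the actual implementation:

```python
# Minimal sketch, assuming each multimodal item carries a stable hash:
# an item encoded in an earlier chunk (or an earlier turn) is looked up
# by its own hash instead of being re-encoded with the whole image set.
from dataclasses import dataclass, field


@dataclass
class FakeItem:
    hash: int
    pixels: list = field(default_factory=list)


class PerImageEmbeddingCache:
    def __init__(self):
        self._cache = {}
        self.encode_calls = 0

    def _encode(self, item):
        # Stand-in for the ViT forward pass.
        self.encode_calls += 1
        return [p * 2 for p in item.pixels]

    def get_embedding(self, item):
        # Key by the per-item hash, not by the combined image set,
        # so adding a new image never invalidates earlier entries.
        if item.hash not in self._cache:
            self._cache[item.hash] = self._encode(item)
        return self._cache[item.hash]


cache = PerImageEmbeddingCache()
a, b = FakeItem(1, [1, 2]), FakeItem(2, [3])
cache.get_embedding(a)
cache.get_embedding(a)     # cache hit: no second ViT call
cache.get_embedding(b)
print(cache.encode_calls)  # 2
```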
if isinstance(item.feature, torch.Tensor) and item.feature.device != device:
    item.feature = item.feature.to(device, non_blocking=True)
The _move_items_to_device function only handles torch.Tensor features. However, MultimodalDataItem.feature can also be a numpy.ndarray (as defined in schedule_batch.py). If a feature is a numpy array, it won't be moved to the device, which will cause subsequent model operations (like torch.cat in qwen3_vl.py) to fail. It's safer to convert numpy arrays to tensors before moving them to the device.
for item in items:
    if item.feature is not None:
        if not isinstance(item.feature, torch.Tensor):
            item.feature = torch.from_numpy(item.feature)
        if item.feature.device != device:
            item.feature = item.feature.to(device, non_blocking=True)

# Use per-image path when all items have exactly one offset (already
# split per-image) — this avoids encoding images not in this chunk.
# Fall back to combined path for non-split items or EVS.
is_per_image = all(len(item.offsets) == 1 for item in embedding_items_per_req)
The check len(item.offsets) == 1 will raise a TypeError if item.offsets is None. While most processors set offsets, MultimodalDataItem defines offsets as Optional[list]. A safer check should ensure offsets is not None before accessing its length.
- is_per_image = all(len(item.offsets) == 1 for item in embedding_items_per_req)
+ is_per_image = all(item.offsets is not None and len(item.offsets) == 1 for item in embedding_items_per_req)
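The suggested guard can be exercised with a minimal stand-in. `Item` and `is_per_image` below are illustrative, not sglang's actual `MultimodalDataItem`, but they show why the `None` check must come before `len()`:

```python
# Minimal sketch: offsets is Optional, so the per-image predicate must
# short-circuit on None before taking len(), mirroring the suggested fix.
from dataclasses import dataclass
from typing import Optional


@dataclass
class Item:
    offsets: Optional[list] = None


def is_per_image(items):
    # Safe: the None check runs first, so len() never sees None.
    return all(i.offsets is not None and len(i.offsets) == 1 for i in items)


print(is_per_image([Item([0]), Item([5])]))   # True
print(is_per_image([Item([0]), Item(None)]))  # False, no TypeError
```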
Per-Image ViT Cache Benchmark Results

Model: Qwen/Qwen3-VL-8B-Instruct (tp=1)

ViT Encoding Time per Chunk Prefill

720p (1280x720)
1080p (1920x1080)
2K (2560x1440)

TTFT Comparison (multiturn_image)

720p
1080p
2K

Summary
Root cause: On

Remaining TTFT gap: Even with ViT savings, TTFT still grows with prompt length due to radix tree lookup, scheduling overhead, and prefilling uncached tokens. These are independent of ViT and not addressed by this PR.
Image Limit Probe Benchmark Results

Model: Qwen/Qwen3.5-27B (tp=1)

Max Image Limit

TTFT per Image Count

720p (1280x720)
1080p (1920x1080)
2K (2560x1440)

Summary
Root cause of OOM on main: The combined embedding cache key changes whenever the image set changes. In probe, every step adds one more image, so the ViT re-encodes all N images from scratch each time. At high resolutions (1080p, 2K), the intermediate ViT activations for large batches exhaust GPU memory. This PR caches per-image embeddings individually, so only the new image is encoded, keeping peak memory constant regardless of total image count.
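The difference between the two cache-key strategies can be sketched with a toy encode counter. `encodes_combined` and `encodes_per_image` are illustrative names, assuming a probe workload that adds exactly one new image per step:

```python
# Toy comparison: a combined key over the whole image set changes every
# step, so all n images are re-encoded each time; per-image keys encode
# each image exactly once. Counts model ViT forward passes, not memory.
def encodes_combined(steps):
    cache, total = set(), 0
    for n in range(1, steps + 1):
        key = tuple(range(n))      # key over the full image set: new every step
        if key not in cache:
            total += n             # all n images re-encoded from scratch
            cache.add(key)
    return total


def encodes_per_image(steps):
    cache, total = set(), 0
    for n in range(1, steps + 1):
        for img in range(n):
            if img not in cache:   # only the newly added image is encoded
                total += 1
                cache.add(img)
    return total


print(encodes_combined(10), encodes_per_image(10))  # 55 10
```

The quadratic vs. linear encode count is what drives both the TTFT growth and the peak-memory blowup observed on main at high resolutions.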
OCRBench Accuracy Results

Model: Qwen/Qwen3.5-27B (tp=1, enable_thinking=False)

Configuration

[server]
model_path = "Qwen/Qwen3.5-27B"
extra_args = "--port 7893 --tp-size 1 --enable-multimodal"

[accuracy]
tasks = ["ocrbench"]
extra_args = "--max-tokens 8192 --stream --think-mode qwen3 --max-connections 10"
# think-mode qwen3 passes: extra_body = {"chat_template_kwargs": {"enable_thinking": False}}

Results
Analysis
Motivation
Modifications
Accuracy Tests
Speed Tests and Profiling
Checklist
Review and Merge Process
/tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci