Checklist
Motivation
- We conducted systematic optimizations around multimodal data preprocessing and the ViT module, ultimately achieving significant end-to-end performance improvements.
- We believe that some of this work will benefit the inference performance of the Qwen3-VL series models, while other parts will benefit inference performance for a broader range of multimodal models.
- Therefore, we have summarized the optimization techniques we used and hope to discuss them with the community here. Afterwards, we hope to break these optimizations down and contribute them as separate PRs.
Modifications
- We focus on the following scenario (NVIDIA Hopper GPU + Qwen3-VL-235B + TP8 + SGLang 0.5.8 or main branch): one request with 20 high-resolution images (960×1280) and a short text prompt (approximately ISL=96000 for the ViT and ISL=24000 for the LLM). All performance data below is based on this configuration.
- Related performance comparisons before and after optimization are listed in the table below.
| Category | Operation | Latency (ms), before | TTFT (%), before | Latency (ms), after | TTFT (%), after | Speedup |
|---|---|---|---|---|---|---|
| Preprocess | Image Decoding | 4.0 | 0.5% | 25.2 | 4.9% | 0.2 |
| | Convert to torch CPU tensor | 82.6 | 10.8% | Eliminated | 0.0% | inf |
| | Convert to torch GPU tensor | 11.8 | 1.5% | Eliminated | 0.0% | inf |
| | Total (including other ops) | 101.0 | 13.2% | 26.5 | 5.2% | 3.8 |
| VisionModel | PatchEmbed | 27.8 | 3.6% | 2.4 | 0.5% | 11.6 |
| | fast_pos_embed_interpolate | 24.0 | 3.1% | 8.3 | 1.6% | 2.9 |
| | rot_pos_emb | 6.3 | 0.8% | 2.9 | 0.6% | 2.2 |
| | ViT | 383.5 | 50.1% | 244.5 | 47.9% | 1.6 |
| | Total (including other ops) | 446.6 | 58.3% | 263.5 | 51.7% | 1.7 |
| LLM Prefill | | 156.1 | 20.4% | 156.1 | 30.6% | 1.0 |
| TTFT | Overall | 766.2 | 100.0% | 510.0 | 100.0% | 1.5 |
📗 Multimodal Data Preprocessing Pipeline
- Images (base64-encoded) in the request require: base64 string → bytes → PIL.Image → CPU tensor → GPU tensor → normalize and split patches → broadcast to GPUs → ViT.
- For the JPEG format, we switched to `torch.ops.image.decode_jpegs_cuda`, which converts CPU bytes directly into GPU tensors via the nvJPEG hardware decoder. This eliminates intermediate representations such as PIL Images and CPU tensors, and minimizes CPU-GPU data transfers.
- Overall, we achieved a 3.8× speedup.
- We hope to incorporate this optimization as a separate PR into sglang to accelerate the preprocessing of JPEG images. This workflow can also be used for JPEG2000 and TIFF images now, and for PNG images soon. Related information.
- The code modification involves two files: `srt/utils/common.py` and `srt/multimodal/processors/base_processor.py`.
📕 Operator Execution Efficiency
📖 Replace Conv3d with Linear in Patch Embedding (Qwen3VLVisionPatchEmbed)
- The shapes of the weight, bias, and input tensor for the Conv3d operator are $[1152, 3, 2, 16, 16]$, $[1152]$, and $[N, 3, 2, 16, 16]$ respectively (in our example, $N=(1280/16) \times (960/16) \times 20 = 96000$).
- Evidently, this convolution kernel does not slide over the input tensor but has only a single position, which is equivalent to reshaping the weight to $w' \in \mathbb{R}^{1152 \times 1536}$, reshaping the input tensor to $x' \in \mathbb{R}^{N \times 1536}$, and then performing a linear transformation $y = x' \cdot (w')^T + b$. Thus, it can be replaced with a Linear layer.
- After this optimization, we achieved an 11.6× speedup.
- Note: the Linear implementation remains significantly faster than the original Conv3d even after upgrading to cuDNN 9.16+ with PyTorch 2.9.1.
- We hope to incorporate this optimization as a separate PR into sglang to improve the performance of the VLM part. The code modification involves two files: `srt/models/qwen3_vl.py` and `srt/model_loader/loader.py`.
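The equivalence above can be checked numerically; this is a minimal sketch using the shapes from the PR (the batch size `N` and random data are illustrative):

```python
import torch

# The Conv3d kernel covers the entire [2, 16, 16] patch, so each patch
# produces exactly one output position -- the convolution never slides.
N = 8
conv = torch.nn.Conv3d(3, 1152, kernel_size=(2, 16, 16))
x = torch.randn(N, 3, 2, 16, 16)
y_conv = conv(x).reshape(N, 1152)  # conv output is [N, 1152, 1, 1, 1]

# Equivalent Linear: y = x' (w')^T + b with w' = weight.reshape(1152, 1536).
# Flattening (c, t, h, w) in the same order for weight and input preserves
# the exact sum the convolution computes.
linear = torch.nn.Linear(3 * 2 * 16 * 16, 1152)
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(1152, -1))
    linear.bias.copy_(conv.bias)
y_lin = linear(x.reshape(N, -1))

print(torch.allclose(y_conv, y_lin, atol=1e-5))  # True
```

Because the two layers are numerically identical, the Conv3d weights can simply be reshaped at load time and served through a GEMM, which is what makes the speedup free of any accuracy cost.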
📖 Optimize position embedding (fast_pos_embed_interpolate) and rotary position embedding (rot_pos_emb) operators
- The original code for these operators mixed Python lists, NumPy arrays, and PyTorch tensors, resulting in numerous memory copies, fragmented elementwise kernels, and frequent device synchronization.
- We optimized these operators by performing most computations on the CPU, with only the final gather operation and elementwise operations executed on the GPU.
- Ultimately, these two functions achieved 2.9× and 2.2× speedups respectively.
- We hope to incorporate this optimization as a separate PR into sglang to improve the performance of the VLM part. The code modification involves one file: `srt/models/qwen3_vl.py`.
- However, in a wider range of tests, the performance of the new functions `fast_pos_embed_interpolate_v2`, `fast_pos_embed_interpolate_v3`, and `rot_pos_emb_v2` fluctuates and is worse than the original implementation under some testing conditions. More testing and validation are needed here.
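The pattern behind the rewrite can be sketched as follows (our illustration of the general idea, not the actual SGLang code; the table size and patch grids are hypothetical):

```python
import torch

# Instead of per-element appends and tiny host-to-device copies that each
# force a sync, build the full index tensor on the CPU in one pass, then
# do a single transfer and one batched gather on the GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
pos_table = torch.randn(1024, 64, device=device)  # hypothetical embedding table

grid_hw = [(4, 6), (3, 5)]  # per-image patch grids (illustrative)
# CPU side: index arithmetic all at once -- no GPU involvement yet
idx = torch.cat([torch.arange(h * w) for h, w in grid_hw])

# GPU side: one host-to-device copy plus one gather kernel
embeds = pos_table[idx.to(device)]
print(embeds.shape)  # torch.Size([39, 64])
```

Keeping the cheap scalar bookkeeping on the CPU and batching the expensive part into a single gather is what removes the fragmented kernels and synchronization points described above.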
📘 Cross-Device Communication Overhead
- SGLang's native TP strategy spent 30.7% of the time on communication (the `ncclDevKernel_AllReduce_Sum_bf16_RING_LL` kernel) in our tests.
- We have validated that the Ulysses strategy effectively reduces inter-GPU communication in this scenario, achieving a 1.6× speedup with FlashAttention-3 as the backend.
- Note:
  - This PR demonstrates the integration interface for Ulysses; the underlying implementation is not attached to this PR yet. It can be found here; it is named "vfly" in this PR but "visual_gen" in the linked repo.
  - Our ultimate goal is to integrate these optimizations into FlashInfer, so that when Sequence Parallelism (also called Context Parallelism) is enabled in sglang, it is used automatically for computation.
  - At present, our call stack is still `sglang -> our code -> FlashInfer`, which provides a temporary way to use SP (CP) on the ViT in sglang.
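The core layout transform behind Ulysses can be illustrated in a single process (our sketch of the general idea, not SGLang's or FlashInfer's implementation; all shapes are toy values):

```python
import torch

# With SP degree P, each rank holds a sequence shard [S/P, H, D] of the
# activations. One all-to-all regroups them so each rank sees the FULL
# sequence for H/P heads, letting attention run entirely locally -- no
# per-layer all-reduce as in plain TP. Here the all-to-all is emulated
# with chunk/cat on one device to show the data movement.
P, S, H, D = 4, 8, 8, 16
x = torch.randn(S, H, D)            # full activations, kept only for checking
shards = list(x.chunk(P, dim=0))    # what each rank holds before SP attention

# Emulated all-to-all: rank r collects its head slice from every sequence shard
per_rank = [torch.cat([s.chunk(P, dim=1)[r] for s in shards], dim=0)
            for r in range(P)]

for r in range(P):
    assert per_rank[r].shape == (S, H // P, D)  # full sequence, H/P heads
    assert torch.equal(per_rank[r], x[:, r * (H // P):(r + 1) * (H // P), :])
print("ok")
```

In a real deployment the chunk/cat pair is a single `torch.distributed.all_to_all` per attention layer (and its inverse afterwards), whose volume is much smaller than the all-reduce traffic it replaces in this ViT workload.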
We would appreciate feedback from the SGLang maintainers on the merit of this optimization direction. If there is positive reception, we are committed to upstreaming the core implementation into SGLang for further discussion and review.
Related PR: #18559