[Kernels] Improve Triton fp8 block scaled kernel #29438

Closed

lgeiger wants to merge 3 commits into vllm-project:main from lgeiger:fp8-block-scaled-mm

Conversation

lgeiger (Contributor) commented on Nov 25, 2025

Purpose

This PR aims to improve performance of the Triton fp8 block scaled w8a8 kernel.

It's probably best to review it commit by commit.
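For context, the kernel computes a matmul over fp8-quantized activations and weights whose scales are stored per block rather than per tensor. A minimal NumPy reference of that computation (illustrative only: the function name, the [128, 128] group sizes, and the exact scale layouts are assumptions, not taken from the kernel) looks roughly like this:

import numpy as np

def block_scaled_matmul_ref(A_q, As, B_q, Bs, group_n=128, group_k=128):
    # A_q: (M, K) quantized activations (held as floats here for simplicity)
    # As:  (M, K // group_k) per-token, per-K-group activation scales
    # B_q: (N, K) quantized weights
    # Bs:  (N // group_n, K // group_k) per-block weight scales
    # Assumes K and N are multiples of the group sizes.
    M, K = A_q.shape
    N = B_q.shape[0]
    C = np.zeros((M, N), dtype=np.float32)
    for k0 in range(0, K, group_k):
        a = A_q[:, k0:k0 + group_k].astype(np.float32)
        a_s = As[:, k0 // group_k][:, None]
        for n0 in range(0, N, group_n):
            b = B_q[n0:n0 + group_n, k0:k0 + group_k].astype(np.float32)
            b_s = Bs[n0 // group_n, k0 // group_k]
            # dequantize each tile with its two scales and accumulate
            C[:, n0:n0 + group_n] += (a * a_s) @ b.T * b_s
    return C

The Triton kernel fuses this per-tile rescaling into its K loop, which is where the pointer arithmetic touched by this PR lives.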

Test Plan

Correctness should be covered by tests/kernels/quantization/test_block_fp8.py and I also verified it with lm_eval for Qwen3-VL-2B-Instruct-FP8.
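The exact lm_eval invocation isn't included in the PR; a run along these lines (task and flags are illustrative, not the author's exact command) is the kind of check meant here:

lm_eval --model vllm \
    --model_args pretrained=Qwen/Qwen3-VL-2B-Instruct-FP8 \
    --tasks gsm8k \
    --batch_size auto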

I tested performance on an L40S with Qwen3-VL-32B-Instruct-FP8:

vllm serve Qwen/Qwen3-VL-32B-Instruct-FP8 --limit-mm-per-prompt.video 0 --limit-mm-per-prompt.image 0 --max-model-len 24000 --no-enable-prefix-caching

vllm bench serve --backend vllm --model Qwen/Qwen3-VL-32B-Instruct-FP8 --endpoint /v1/completions --dataset-name sharegpt --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 1000

Test Results

Before:

============ Serving Benchmark Result ============
Successful requests:                     1000
Failed requests:                         0
Benchmark duration (s):                  282.88
Total input tokens:                      217393
Total generated tokens:                  189963
Request throughput (req/s):              3.54
Output token throughput (tok/s):         671.54
Peak output token throughput (tok/s):    1477.00
Peak concurrent requests:                1000.00
Total Token throughput (tok/s):          1440.06
---------------Time to First Token----------------
Mean TTFT (ms):                          115438.24
Median TTFT (ms):                        119080.39
P99 TTFT (ms):                           239007.78
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          129.52
Median TPOT (ms):                        104.82
P99 TPOT (ms):                           540.94
---------------Inter-token Latency----------------
Mean ITL (ms):                           104.46
Median ITL (ms):                         66.16
P99 ITL (ms):                            438.74
==================================================

After code changes:

============ Serving Benchmark Result ============
Successful requests:                     1000
Failed requests:                         0
Benchmark duration (s):                  279.72
Total input tokens:                      217393
Total generated tokens:                  189963
Request throughput (req/s):              3.58
Output token throughput (tok/s):         679.12
Peak output token throughput (tok/s):    1707.00
Peak concurrent requests:                1000.00
Total Token throughput (tok/s):          1456.30
---------------Time to First Token----------------
Mean TTFT (ms):                          117722.64
Median TTFT (ms):                        122529.00
P99 TTFT (ms):                           247590.06
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          133.14
Median TPOT (ms):                        106.13
P99 TPOT (ms):                           503.34
---------------Inter-token Latency----------------
Mean ITL (ms):                           106.04
Median ITL (ms):                         66.58
P99 ITL (ms):                            491.64
==================================================

After code changes and re-tuned:

============ Serving Benchmark Result ============
Successful requests:                     1000
Failed requests:                         0
Benchmark duration (s):                  273.16
Total input tokens:                      217393
Total generated tokens:                  189963
Request throughput (req/s):              3.66
Output token throughput (tok/s):         695.43
Peak output token throughput (tok/s):    1644.00
Peak concurrent requests:                1000.00
Total Token throughput (tok/s):          1491.28
---------------Time to First Token----------------
Mean TTFT (ms):                          117644.34
Median TTFT (ms):                        123531.08
P99 TTFT (ms):                           240034.56
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          122.76
Median TPOT (ms):                        102.63
P99 TPOT (ms):                           511.67
---------------Inter-token Latency----------------
Mean ITL (ms):                           102.85
Median ITL (ms):                         71.09
P99 ITL (ms):                            518.81
==================================================

Overall this improves total token throughput of Qwen3-VL-32B-Instruct-FP8 on a single L40S by about 3.5%, which is decent.

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
mergify bot added the performance (Performance-related issues) label on Nov 25, 2025
gemini-code-assist bot left a comment:

Code Review

This pull request introduces performance improvements to the Triton fp8 block scaled kernel by simplifying pointer arithmetic and removing unnecessary masked loads. The changes are well-reasoned and backed by benchmark results. However, I've identified a pre-existing critical bug that could lead to out-of-bounds memory access when loading scale tensors. This occurs when the K dimension is not a multiple of BLOCK_SIZE_K. I've provided a comment with a suggested fix to address this correctness issue.

Comment on lines +728 to +729
a_s = tl.load(As_ptrs)
b_s = tl.load(Bs_ptrs)

critical

There is a potential out-of-bounds memory access when loading the scale tensors a_s and b_s. This can happen when K is not perfectly divisible by BLOCK_SIZE_K, causing the loop to have a final iteration that accesses beyond the bounds of the scale tensors. The scale tensors As and Bs have a size of K // group_k along the K dimension. However, the access offset, which is effectively k * scale_step_k, can exceed this limit in the final loop iteration. This was also an issue in the previous implementation. To prevent this, we should add a mask to the scale loads.

Suggested change
a_s = tl.load(As_ptrs)
b_s = tl.load(Bs_ptrs)
scale_mask = k * scale_step_k < (K // group_k)
a_s = tl.load(As_ptrs, mask=scale_mask, other=0.0)
b_s = tl.load(Bs_ptrs, mask=scale_mask, other=0.0)
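To make the reviewer's premise concrete (numbers are illustrative, not from the PR): if the scale tensors really hold K // group_k entries along K, then with K = 192, group_k = 128 and BLOCK_SIZE_K = 128 they hold a single entry, while the loop runs ceil(192 / 128) = 2 iterations; the second iteration's offset k * scale_step_k = 1 lands one element past the end, which is exactly the access the suggested mask guards against.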


chatgpt-codex-connector bot left a comment:


💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 732 to +735
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
As_ptrs += scale_step_k * stride_As_k
Bs_ptrs += scale_step_k * stride_Bs_k

P1: Scale pointers stuck when BLOCK_SIZE_K < group_k

The new pointer-stepping logic multiplies scale_step_k = BLOCK_SIZE_K // group_k into the scale strides. When the K tile is smaller than the quantization block (e.g., block_shape [128,128] with tuned configs that set BLOCK_SIZE_K to 64), this integer division is zero, so As_ptrs/Bs_ptrs never advance and every K tile reuses the first scale block. That produces wrong scaling for K offsets beyond the first 128 elements, whereas the previous k_start // group_k offset handled the larger block correctly.
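
One way to address this (a sketch based only on the snippet quoted above, not code from this PR) would be to recompute the scale offset from the running K position on each iteration, as the previous k_start // group_k indexing did, instead of accumulating a fixed scale_step_k increment that can round to zero:

# Fragment of the K-loop body; k is the loop induction variable, As_ptrs / Bs_ptrs
# stay pointed at the first scale block, and offs_ks (a name introduced here for
# illustration) supplies the per-iteration scale offset.
offs_ks = (k * BLOCK_SIZE_K) // group_k   # still advances when BLOCK_SIZE_K < group_k
a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
...
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk

This trades a couple of extra integer ops per iteration for correct scaling whenever a tuned BLOCK_SIZE_K drops below the quantization group size.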


lgeiger (Contributor, Author) replied:

This sounds sensible. I'll have a look later.

github-actions bot commented:

This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!

github-actions bot added the stale (Over 90 days of inactivity) label on Feb 24, 2026
lgeiger closed this on Mar 18, 2026