[Kernels] Improve Triton fp8 block scaled kernel #29438

lgeiger wants to merge 3 commits into vllm-project:main
Conversation
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
Code Review
This pull request introduces performance improvements to the Triton fp8 block scaled kernel by simplifying pointer arithmetic and removing unnecessary masked loads. The changes are well-reasoned and backed by benchmark results. However, I've identified a pre-existing critical bug that could lead to out-of-bounds memory access when loading scale tensors. This occurs when the K dimension is not a multiple of BLOCK_SIZE_K. I've provided a comment with a suggested fix to address this correctness issue.
    a_s = tl.load(As_ptrs)
    b_s = tl.load(Bs_ptrs)
There is a potential out-of-bounds memory access when loading the scale tensors a_s and b_s. This can happen when K is not perfectly divisible by BLOCK_SIZE_K, causing the loop to have a final iteration that accesses beyond the bounds of the scale tensors. The scale tensors As and Bs have a size of K // group_k along the K dimension. However, the access offset, which is effectively k * scale_step_k, can exceed this limit in the final loop iteration. This was also an issue in the previous implementation. To prevent this, we should add a mask to the scale loads.
Suggested change:
-    a_s = tl.load(As_ptrs)
-    b_s = tl.load(Bs_ptrs)
+    scale_mask = k * scale_step_k < (K // group_k)
+    a_s = tl.load(As_ptrs, mask=scale_mask, other=0.0)
+    b_s = tl.load(Bs_ptrs, mask=scale_mask, other=0.0)
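To make the suggestion concrete, below is a minimal sketch of how the masked scale loads would sit inside the K loop. The variable names (a_ptrs, b_ptrs, As_ptrs, Bs_ptrs, offs_k, scale_step_k, group_k, the strides, and accumulator) are assumed from the diff context above; this is an illustration of the idea, not the exact vLLM kernel source.

```python
# Sketch of the K loop with the proposed scale-load guard (names assumed from
# the diff above; not the exact vLLM kernel).
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
    # A/B tile loads, masked here because this scenario assumes K is not a
    # multiple of BLOCK_SIZE_K.
    a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
    b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)

    # As/Bs only hold K // group_k scale blocks along K, so guard the final
    # iteration against stepping past that bound; other=0.0 keeps the load
    # well-defined when the mask is false.
    scale_mask = k * scale_step_k < (K // group_k)
    a_s = tl.load(As_ptrs, mask=scale_mask, other=0.0)
    b_s = tl.load(Bs_ptrs, mask=scale_mask, other=0.0)

    accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]

    a_ptrs += BLOCK_SIZE_K * stride_ak
    b_ptrs += BLOCK_SIZE_K * stride_bk
    As_ptrs += scale_step_k * stride_As_k
    Bs_ptrs += scale_step_k * stride_Bs_k
```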
💡 Codex Review
Here are some automated review suggestions for this pull request.
    a_ptrs += BLOCK_SIZE_K * stride_ak
    b_ptrs += BLOCK_SIZE_K * stride_bk
    As_ptrs += scale_step_k * stride_As_k
    Bs_ptrs += scale_step_k * stride_Bs_k
Scale pointers stuck when BLOCK_SIZE_K < group_k
The new pointer-stepping logic multiplies scale_step_k = BLOCK_SIZE_K // group_k into the scale strides. When the K tile is smaller than the quantization block (e.g., block_shape [128,128] with tuned configs that set BLOCK_SIZE_K to 64), this integer division is zero, so As_ptrs/Bs_ptrs never advance and every K tile reuses the first scale block. That produces wrong scaling for K offsets beyond the first 128 elements, whereas the previous k_start // group_k offset handled the larger block correctly.
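One way to keep the scales correct in that case, sketched below with the same assumed names as the diff above, is to recompute the scale-block index from the tile's absolute K offset each iteration (as the previous k_start // group_k logic did) instead of accumulating a per-iteration step that truncates to zero. This is only an illustration of the idea, not a tested patch against the vLLM kernel.

```python
# Sketch: derive the scale-block index from the absolute K position so that it
# still advances when BLOCK_SIZE_K < group_k (pointer/stride names assumed).
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
    a = tl.load(a_ptrs)
    b = tl.load(b_ptrs)

    # k_start // group_k moves forward once per group_k elements of K, even if
    # several consecutive K tiles share the same scale block.
    k_start = k * BLOCK_SIZE_K
    offs_ks = k_start // group_k
    a_s = tl.load(As_ptr + offs_am * stride_As_m + offs_ks * stride_As_k)
    b_s = tl.load(Bs_ptr + offs_bsn * stride_Bs_n + offs_ks * stride_Bs_k)

    accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]

    a_ptrs += BLOCK_SIZE_K * stride_ak
    b_ptrs += BLOCK_SIZE_K * stride_bk
```

Alternatively, the cheaper incremental stepping could be kept and used only when BLOCK_SIZE_K is a multiple of group_k, falling back to the offset-based form otherwise.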
This sounds sensible. I'll have a look later.
This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!
Purpose
This PR aims to improve performance of the Triton fp8 block scaled w8a8 kernel.
It's probably best to review it commit by commit:
K % BLOCK_SIZE_K == 0. To me this seems to always be the case, but I might miss some edge cases, so I kept a fallback to the previous behaviour when this condition isn't met (a rough sketch of this dispatch is at the end of this description).

Test Plan
Correctness should be covered by tests/kernels/quantization/test_block_fp8.py, and I also verified it with lm_eval for Qwen3-VL-2B-Instruct-FP8. I tested performance on an L40s with Qwen3-VL-32B-Instruct-FP8:

Test Results
Before:
After code changes:
After code changes and re-tuned:
Overall this improves throughput of Qwen3-VL-32B-Instruct-FP8 on a single L40s by 3.5%, which is decent.
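For reference, the fallback mentioned under Purpose amounts to a host-side dispatch along the lines of the sketch below. The function and parameter names are purely illustrative (not taken from the vLLM source), and only the K % BLOCK_SIZE_K == 0 condition comes from this description; the group_k condition reflects the BLOCK_SIZE_K < group_k concern raised in the review above and may not be part of the final check.

```python
def use_unmasked_fast_path(K: int, block_size_k: int, group_k: int) -> bool:
    """Illustrative check for choosing the simplified (unmasked) kernel path.

    Sketch only: the divisibility of K is the condition stated in the PR
    description; requiring block_size_k to be a multiple of group_k mirrors
    the scale-pointer-stepping concern from the review discussion.
    """
    return K % block_size_k == 0 and block_size_k % group_k == 0


# e.g. use_unmasked_fast_path(K=4096, block_size_k=128, group_k=128) -> True,
# while an uneven K (or a tuned BLOCK_SIZE_K smaller than group_k) would take
# the previous masked-load path instead.
```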