Skip expanding scales for rowwise fp8 quantize #2950
Merged
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2950
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures
As of commit 4cf5c90 with merge base 4872c4f.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
drisspg
reviewed
Sep 6, 2025
        return scale

    # For rowwise quantization, just return the scale as is
    if scale.shape[:-1] == target_shape[:-1] and scale.shape[-1] == 1:
Contributor
you could probably do something fun, like
    def is_trivial_expandable(scale, target_shape):
        return all(a == b or a == 1 for a, b in zip(scale.shape, target_shape))

**Summary:** #2253 added a step in `quantize_affine_float8` to expand the scales for blockwise quantization. The purpose of this step is to make the scales always broadcastable with the input tensor. However, this is unnecessary for rowwise quantization, which already has broadcastable shapes, e.g.

```
scale = [32, 1]
input = [32, 16]
```

Today, we will `repeat_interleave` the above scales to pad the scale tensor until it reaches `[32, 16]`, which adds non-trivial memory and latency overhead. This commit adds a fast path to skip this expanding step if we detect rowwise quantization.

**Test Plan:**

```
python test/quantization/test_quant_primitives.py -k test_maybe_expand_scale_to_tensor_shape
```

Also compared fine-tuning Qwen3-1.7B with fp8-fp8 QAT using batch size 32 on a single H100 GPU:

- Before: 25.34 GB peak memory, 3047.25 tok/s
- After: 22.53 GB peak memory, 3358.49 tok/s
- This PR uses 11.1% less memory and is 10.2% faster
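To make the shape reasoning concrete, here is a minimal sketch that compares the rowwise check from the diff with the broader check suggested in review. It works on plain shape tuples rather than tensors so it runs without PyTorch. The function names `is_rowwise_scale`, `is_trivially_expandable`, and `expansion_factor`, the explicit rank check, and the `(4, 2)` blockwise shape are illustrative assumptions, not code from this PR.

```python
from math import prod


def is_rowwise_scale(scale_shape, target_shape):
    # Same condition as the fast path in the diff: all leading dims match
    # and the last scale dim is 1, e.g. (32, 1) against (32, 16).
    return scale_shape[:-1] == target_shape[:-1] and scale_shape[-1] == 1


def is_trivially_expandable(scale_shape, target_shape):
    # Broader check from the review comment: every scale dim either equals
    # the target dim or is 1. The rank check is an added assumption here,
    # because zip() silently truncates to the shorter shape.
    return len(scale_shape) == len(target_shape) and all(
        a == b or a == 1 for a, b in zip(scale_shape, target_shape)
    )


def expansion_factor(scale_shape, target_shape):
    # How many times more elements the scale holds once it is expanded
    # to the full target shape.
    return prod(target_shape) // prod(scale_shape)


target = (32, 16)
for name, scale in [
    ("rowwise", (32, 1)),
    ("single scale", (1, 1)),
    ("blockwise (hypothetical 8x8 blocks)", (4, 2)),
]:
    print(
        name,
        is_rowwise_scale(scale, target),
        is_trivially_expandable(scale, target),
        expansion_factor(scale, target),
    )
```

For the `[32, 1]` scale in the summary, expanding to `[32, 16]` stores 16 times as many scale elements as the broadcastable form. The broader check also accepts shapes such as `(1, 1)` that the rowwise check rejects, while a true blockwise scale such as `(4, 2)` fails both checks and would still need expanding.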
drisspg
approved these changes
Sep 8, 2025