[Inductor][float8] Support qlinear for float8 in inductor#2565
Add fp8 dequant promotion
@jerryzh168 Could you help review this PR?
@shiyang-weng the registration PR was reverted in #2672; we'd need to land that again without breaking BC.
Hi @jerryzh168, could you please provide a reproducer so that we can fix that? Thanks.
Yeah, this is the test: ao/test/dtypes/test_affine_quantized_float.py, line 735 in 418593c
Pull Request Overview
This PR adds float8_e4m3fn support to PyTorch Inductor for qlinear operations, implementing quantization patterns specifically for FP8 data types. The implementation handles differences in FP8 quantization API requirements, including tensor-based scales and modified quantize/dequantize operations.
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| torchao/quantization/pt2e/inductor_passes/x86.py | Adds FP8 quantization support with new patterns, updates existing functions to handle FP8 operations, and modifies view operation handling |
| test/quantization/pt2e/test_x86inductor_fusion.py | Adds comprehensive test coverage for FP8 quantization patterns and refactors test helpers to support FP8 |
| x_zp = kwargs["x_zp"] if "x_zp" in kwargs else None | ||
| w_zp = kwargs["w_zp"] if "w_zp" in kwargs else None |
[nitpick] The extraction of qparams has inconsistent patterns. The first two use tuple unpacking while x_zp and w_zp use conditional extraction. For better maintainability and consistency, consider using the same pattern for all parameters.
| x_zp = kwargs["x_zp"] if "x_zp" in kwargs else None | |
| w_zp = kwargs["w_zp"] if "w_zp" in kwargs else None | |
| x_zp = kwargs.get("x_zp") | |
| w_zp = kwargs.get("w_zp") |
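The two forms behave the same for optional keys, because `dict.get` returns `None` when a key is absent. A minimal sketch, using a hypothetical `extract_zero_points` helper that is not part of the PR:

```python
def extract_zero_points(**kwargs):
    """Hypothetical helper: pull optional zero points out of pattern-match kwargs."""
    # dict.get returns None for a missing key, matching the conditional form
    # `kwargs["x_zp"] if "x_zp" in kwargs else None`.
    x_zp = kwargs.get("x_zp")
    w_zp = kwargs.get("w_zp")
    return x_zp, w_zp


# Zero points absent: both values come back as None.
print(extract_zero_points(x_scale=2.0, w_scale=2.0))
# Zero points present: they are returned unchanged.
print(extract_zero_points(x_zp=0, w_zp=1))
```

The `.get` form also avoids looking each key up twice.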
| is_tensor_overload, | ||
| is_fp8, | ||
| ) in linear_weight_prepack_cases: | ||
| if is_fp8 and not is_tensor_overload: |
[nitpick] This skip condition appears in multiple places (lines 1429 and 1506). Consider extracting this logic into a helper function or constant to avoid code duplication and improve maintainability.
| if is_fp8 and not is_tensor_overload: | |
| if _should_skip_fp8_case(is_fp8, is_tensor_overload): |
| if output_dtype == torch.float8_e4m3fn: | ||
| # For float8, torchao.quantize_affine_float8 requires tensor as scale | ||
| # Support scale node is full firstly | ||
| assert kwargs["o_inv_scale"].target is torch.ops.aten.full.default |
The assertion assumes kwargs["o_inv_scale"] is always a node object, but there's no validation that it has a target attribute. This could cause an AttributeError if the object doesn't have this attribute.
| assert kwargs["o_inv_scale"].target is torch.ops.aten.full.default | |
| assert hasattr(kwargs["o_inv_scale"], "target") and kwargs["o_inv_scale"].target is torch.ops.aten.full.default, ( | |
| "Expected kwargs['o_inv_scale'] to be a node object with 'target' attribute set to torch.ops.aten.full.default" | |
| ) |
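The same guard can be sketched with `getattr` and a default, which keeps the check to one attribute lookup. This is an illustration only: `check_scale_node` is a hypothetical helper, and `ATEN_FULL_DEFAULT` is a stand-in sentinel for `torch.ops.aten.full.default` so the snippet runs without PyTorch installed:

```python
from types import SimpleNamespace

# Stand-in for torch.ops.aten.full.default (assumption, to avoid importing torch).
ATEN_FULL_DEFAULT = object()


def check_scale_node(kwargs, expected_target=ATEN_FULL_DEFAULT):
    """Hypothetical helper: validate that o_inv_scale is a node built by aten.full."""
    node = kwargs["o_inv_scale"]
    # getattr with a default avoids AttributeError for objects without `target`.
    target = getattr(node, "target", None)
    if target is not expected_target:
        raise AssertionError(
            "Expected kwargs['o_inv_scale'] to be a node whose target is "
            f"aten.full.default, got {type(node).__name__}"
        )
    return node


# A node-like object with the expected target passes the check.
check_scale_node({"o_inv_scale": SimpleNamespace(target=ATEN_FULL_DEFAULT)})

# A bare Python float has no `target`, so the check fails with a clear message.
try:
    check_scale_node({"o_inv_scale": 2.0})
except AssertionError as err:
    print(err)
```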
| # check if scale created by torch.tensor | ||
| return ( | ||
| len(node.all_input_nodes) == 2 | ||
| and node.all_input_nodes[1].target == torch.tensor |
[nitpick] Using torch.tensor as a target comparison might be fragile since it's comparing against a function object. Consider using a more robust method to identify tensor creation nodes, such as checking the function name or using a more specific target.
| and node.all_input_nodes[1].target == torch.tensor | |
| and torch.fx.node._qualified_name(node.all_input_nodes[1].target) == "torch.tensor" |
| class FP8QDQLinear(torch.nn.Module): | ||
| def __init__(self, in_features, out_features, has_bias): | ||
| super().__init__() | ||
| self.qtype = torch.float8_e4m3fn | ||
| self.weight = torch.randn((out_features, in_features)).to(self.qtype) | ||
| self.weight_scale = 2.0 | ||
| self.scale = 2.0 | ||
| self.bias = None | ||
| if has_bias: | ||
| self.bias = torch.randn((out_features,)) |
[nitpick] The hardcoded scale values (2.0) should be configurable parameters or documented constants to improve maintainability and make the test more flexible.
| if is_fp8: | ||
| # fp8_convert_ not support dynamic and qat yet | ||
| assert not is_dynamic | ||
| assert not is_qat |
[nitpick] This assertion pattern appears multiple times in the test file (lines 206-208 and 1954-1957). Consider extracting this validation into a helper function to reduce code duplication.
This PR is used to support FP8 in PyTorch.
CC @mingfeima for review |
jerryzh168 left a comment
Do you need to change the quant flow code to produce this op?
I'd recommend doing this by defining a new observer, using this API:
ao/test/quantization/pt2e/test_quantize_pt2e.py
Line 2315 in c96f2dd
* quantize_affine_float8/dequantize_affine_float8 not decomposed on inductor
* remove redundant unittest.skipIf
* fix rebase issue
* change dispatch key to a flag decomposed
* support scaled_mm on inductor
* fix rebase issue
* support dequant promotion for fp8
* add ut
* remove redundant codes
* fix lint
* resolve conflict
* change to use qlinear
* add ut
* fix lint
* support fp8 quant_lift_up
* add reshape into _VIEW_METHOD_OPS
* add quant_input_check
* fix lint
* refine ut
* remove fp8 dynamic quant ut
* fix output_scale issue
* add float8_e4m3fn to dtype_list
* refine code
* refine code
* fix bugs
* add comment
* merge main
* change to use non-decomposed q/dq
* fix lint
* add version check
* change version
* fix attention bug; update ut
* add liftup oplist
For float8_e4m3fn, support
on inductor.
For FP8, there are following issues
Based on these issues,