[Quant] add FP8 support in quantize ops#153601
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/153601
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit a88490a with merge base fe285b9 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
Xia-Weiwen
left a comment
There was a problem hiding this comment.
My suggestion on PR title: [Quant] add FP8 support in quantize ops
| @@ -3,6 +3,7 @@ | |||
| from typing import Optional | |||
There was a problem hiding this comment.
we are deprecating these, can you add to torchao? or maybe it's already supported in torchao
https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py
There was a problem hiding this comment.
Yeah. Looks like there isn't such an issue with the ops in Torchao.
Done |
For support pytorch/ao#2228 > What we want to do now is to enable FP8 quantization in PyTorch. And similar as INT8 quantization, we need to insert quantize and dequantize ops into the graph. > > However we met problems with these q/dq ops both in the PyTorch core and Torchao. > > PyTorch core: > > The quantize_per_tensor op does not support FP8. We want to fix it via #153601. And as you commented, the op is deprecated. > Torchao: > > In the fusion pass in Inductor, we want to match the pattern fp8_weight -> torchao.dequantize_affine_float8 -> fp32_op and fuse it as fp8_weight -> weight_pack -> fp8_op. We have done so for INT8 PT2E quantization. However, the pattern matching pass is applied after a constant folding pass in Inductor: > https://github.com/pytorch/pytorch/blob/100ec0b34aeff2b948dae33937857d0c86cf1646/torch/_inductor/fx_passes/freezing_patterns.py#L69C1-L74C1 > After constant_fold(gm), the pattern will be folded as fp32_weight -> fp32_op. Then the original pattern cannot be found any more and the FP8 semantics is lost since the pattern is entirely in fp32 now. > For INT8, the int8_weight -> quantized_decomposed.dequantize_per_channel -> fp32_op pattern won't be folded because we mark quantized_decomposed.dequantize_per_channel impure so that it won't be folded: https://github.com/pytorch/pytorch/blob/100ec0b34aeff2b948dae33937857d0c86cf1646/torch/_inductor/constant_folding.py#L139C1-L149C1 . But for the torchao.dequantize_affine_float8, we cannot do this because > It is an op from Torchao, which is unknown to the constant folder > It is decomposed to smaller ops, so we cannot put it in the list as a single op. > So, we think an easy and short-term solution is to modify the ops in PyTorch core via #153601. > However, if we want to resolve the issue with Torchao, we need to > Add a method in the constant folder in Inductor to allow registration of impure ops Based on [Jansel‘s reply](pytorch/ao#2228 (comment)), add dont constant fold flag on this patch Pull Request resolved: #154945 Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel Co-authored-by: Jason Ansel <jansel@jansel.net>
For support pytorch/ao#2228 > What we want to do now is to enable FP8 quantization in PyTorch. And similar as INT8 quantization, we need to insert quantize and dequantize ops into the graph. > > However we met problems with these q/dq ops both in the PyTorch core and Torchao. > > PyTorch core: > > The quantize_per_tensor op does not support FP8. We want to fix it via pytorch#153601. And as you commented, the op is deprecated. > Torchao: > > In the fusion pass in Inductor, we want to match the pattern fp8_weight -> torchao.dequantize_affine_float8 -> fp32_op and fuse it as fp8_weight -> weight_pack -> fp8_op. We have done so for INT8 PT2E quantization. However, the pattern matching pass is applied after a constant folding pass in Inductor: > https://github.com/pytorch/pytorch/blob/100ec0b34aeff2b948dae33937857d0c86cf1646/torch/_inductor/fx_passes/freezing_patterns.py#L69C1-L74C1 > After constant_fold(gm), the pattern will be folded as fp32_weight -> fp32_op. Then the original pattern cannot be found any more and the FP8 semantics is lost since the pattern is entirely in fp32 now. > For INT8, the int8_weight -> quantized_decomposed.dequantize_per_channel -> fp32_op pattern won't be folded because we mark quantized_decomposed.dequantize_per_channel impure so that it won't be folded: https://github.com/pytorch/pytorch/blob/100ec0b34aeff2b948dae33937857d0c86cf1646/torch/_inductor/constant_folding.py#L139C1-L149C1 . But for the torchao.dequantize_affine_float8, we cannot do this because > It is an op from Torchao, which is unknown to the constant folder > It is decomposed to smaller ops, so we cannot put it in the list as a single op. > So, we think an easy and short-term solution is to modify the ops in PyTorch core via pytorch#153601. > However, if we want to resolve the issue with Torchao, we need to > Add a method in the constant folder in Inductor to allow registration of impure ops Based on [Jansel‘s reply](pytorch/ao#2228 (comment)), add dont constant fold flag on this patch Pull Request resolved: pytorch#154945 Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel Co-authored-by: Jason Ansel <jansel@jansel.net>
For support pytorch/ao#2228 > What we want to do now is to enable FP8 quantization in PyTorch. And similar as INT8 quantization, we need to insert quantize and dequantize ops into the graph. > > However we met problems with these q/dq ops both in the PyTorch core and Torchao. > > PyTorch core: > > The quantize_per_tensor op does not support FP8. We want to fix it via #153601. And as you commented, the op is deprecated. > Torchao: > > In the fusion pass in Inductor, we want to match the pattern fp8_weight -> torchao.dequantize_affine_float8 -> fp32_op and fuse it as fp8_weight -> weight_pack -> fp8_op. We have done so for INT8 PT2E quantization. However, the pattern matching pass is applied after a constant folding pass in Inductor: > https://github.com/pytorch/pytorch/blob/100ec0b34aeff2b948dae33937857d0c86cf1646/torch/_inductor/fx_passes/freezing_patterns.py#L69C1-L74C1 > After constant_fold(gm), the pattern will be folded as fp32_weight -> fp32_op. Then the original pattern cannot be found any more and the FP8 semantics is lost since the pattern is entirely in fp32 now. > For INT8, the int8_weight -> quantized_decomposed.dequantize_per_channel -> fp32_op pattern won't be folded because we mark quantized_decomposed.dequantize_per_channel impure so that it won't be folded: https://github.com/pytorch/pytorch/blob/100ec0b34aeff2b948dae33937857d0c86cf1646/torch/_inductor/constant_folding.py#L139C1-L149C1 . But for the torchao.dequantize_affine_float8, we cannot do this because > It is an op from Torchao, which is unknown to the constant folder > It is decomposed to smaller ops, so we cannot put it in the list as a single op. > So, we think an easy and short-term solution is to modify the ops in PyTorch core via #153601. > However, if we want to resolve the issue with Torchao, we need to > Add a method in the constant folder in Inductor to allow registration of impure ops Based on [Jansel‘s reply](pytorch/ao#2228 (comment)), add dont constant fold flag on this patch Pull Request resolved: #154945 Approved by: https://github.com/jansel Co-authored-by: Jason Ansel <jansel@jansel.net>
|
supported on torchao |
Quant used to be used for integers.
But now we want to use it for fp8.
This patch determine whether to round according to dtype.
cc @ezyang @SherlockNoMad @EikanWang @jgong5 @wenzhe-nrv @voznesenskym @penguinwu @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov