🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/97798
Note: links to docs will display an error until the docs builds have been completed.
❌ 2 Failures as of commit 9e6f507: new failures (the following jobs have failed).
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Summary:

This PR adds initial support for the 8-bit floating point types, matching the spec in https://arxiv.org/pdf/2209.05433.pdf. We are adding this to PyTorch to enable easier experimentation with these dtypes by the research community.

Note: we do not yet have an RFC on a full end-to-end experience in PyTorch for training and inference with float8, as it is too early to know what a good UEX will look like. Specifically, we need to better understand the performance and accuracy implications of scaling factor calculation to inform the UEX design.

Note: we are aware that there are other specifications of float8, and float8 is not yet in the IEEE spec. We are adding these dtypes to facilitate easier experimentation with the hardware which is currently available, and we are happy to follow up on other float8 flavors when there is a need.

Note: the scaling factor is outside of the `float8` dtypes, and is expected to be implemented by the user workflow.

Note: the current PR only has `torch.float8_e4m3fn`, to speed up initial review and implementation of any requested changes. Once this passes initial review, I will add `torch.float8_e5m2` in the same manner to this PR. It will be easiest to land the two dtypes together due to some Meta-only changes which will have to be done on the Meta-only version of this PR.
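Since the scaling factor lives outside the `float8` dtypes, a caller is expected to compute it themselves before casting. A minimal sketch of that user-side workflow is below; it is not part of the PR, the helper names (`amax_scale`, `to_float8_range`, `E4M3_MAX`) are ours, and `448.0` is the largest finite value of `float8_e4m3fn` per the linked paper.

```python
# Sketch of a user-managed scaling workflow (hypothetical helpers, not PR code).
# 448.0 is the maximum finite float8_e4m3fn value from arXiv:2209.05433.
E4M3_MAX = 448.0

def amax_scale(values):
    """Scale factor mapping the observed absolute maximum onto [-448, 448]."""
    amax = max(abs(v) for v in values)
    return amax / E4M3_MAX if amax > 0 else 1.0

def to_float8_range(values):
    """Divide by the scale so every value fits the e4m3fn range.

    With the dtype from this PR, the next step would be a cast such as
    ``(x / scale).to(torch.float8_e4m3fn)`` and a dequantize via
    ``x_fp8.to(torch.float32) * scale``; here we only model the scaling.
    """
    scale = amax_scale(values)
    return [v / scale for v in values], scale
```

For example, a tensor whose largest magnitude is 896 gets a scale of 2.0, so every scaled value lands within the representable range.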
The specific support added here is:

```python
# tensor creation with `torch.zeros`
x = torch.zeros(4, dtype=torch.float8_e4m3fn)

# conversion to and from float8 dtypes:
# - conversion between float32 and float8 uses the kernels copied from fbgemm
# - conversion between other dtypes and float8 converts to float32 as an
#   intermediate step
x = torch.randn(4, dtype=torch.float)
x_float8 = x.to(torch.float8_e4m3fn)
x_float32 = x_float8.to(torch.float32)

# printing out the tensor for debugging
# note: the values are converted to float32 before being printed
print(x_float8)

# numerical correctness of the format: TODO document
```

The following is not in this PR, but we plan to address it in the near future:
* ensure all of the functionality above works on CUDA

Test plan:
```
python test/test_quantization.py TestFloat8
```
@vkuzo, my understanding is that this feature is for research purposes, right? Are you aware of any vendors that have hardware acceleration for float8_e4m3fn?
Yes, this matches the fp8 format used by NVIDIA in their new Hopper architecture. Note: for now this is a prototype and we have not committed to landing this PR yet; we are still going through deliberations.
Small suggestion: can you put the same stuff into OpMathType.h, opmath is used more frequently
Hi, may I know what's the latest status of this PR, we'd like to try it on our FP8 implementation. |
hi @xiaolil1 , this PR is abandoned. We are thinking through the design of fp8 as an out-of-core dtype for now.
hi @vkuzo would that be a new design to support "out of core dtype"? Currently, I guess dtype cannot be extended. |
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10