
[wip] add the torch.float8_e4m3fn data type#97798

Closed
vkuzo wants to merge 1 commit into main from float8_dtype

Conversation

@vkuzo
Contributor

@vkuzo vkuzo commented Mar 28, 2023

Summary:

This PR adds initial support for 8-bit floating-point types, matching the spec in https://arxiv.org/pdf/2209.05433.pdf. We are adding this to PyTorch to enable easier experimentation with these dtypes by the research community.

Note: we do not yet have an RFC on a full end-to-end experience in PyTorch for training and inference with float8, as it is too early to know what a good UEX will look like. Specifically, we need to better understand the performance and accuracy implications of scaling factor calculation to inform the UEX design.

Note: we are aware that there are other specifications of float8, and float8 is not yet in the IEEE spec. We are adding these dtypes to facilitate easier experimentation with the hardware which is currently available, and we are happy to follow up on other float8 flavors when there is a need.

Note: the scaling factor is outside of the float8 dtypes, and is expected to be implemented by the user workflow.
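Since scaling lives in the user workflow rather than the dtype, a common recipe computes a per-tensor scale from the absolute maximum before casting. The sketch below is a minimal, torch-free illustration of that idea only; the helper names `compute_scale` and `scaled_cast` are hypothetical and not part of this PR:

```python
E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn

def compute_scale(amax: float) -> float:
    # Hypothetical helper: map the largest observed magnitude onto
    # the top of the float8 dynamic range.
    return E4M3_MAX / amax

def scaled_cast(values):
    # Hypothetical user-side workflow: compute amax, scale the values,
    # then (in real code) cast the result with x.to(torch.float8_e4m3fn).
    amax = max(abs(v) for v in values)
    scale = compute_scale(amax)
    scaled = [v * scale for v in values]
    return scaled, scale

scaled, scale = scaled_cast([0.5, -2.0, 1.0])
restored = [v / scale for v in scaled]  # dequantize by dividing out the scale
```

Here the amax element maps exactly to 448; real recipes may also clamp outliers or round the scale to a power of two.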

Note: the current PR only has torch.float8_e4m3fn, to speed up initial review and implementation of any requested changes. Once this passes initial review, I will add torch.float8_e5m2 in the same manner to this PR. It will be easiest to land the two dtypes together due to some Meta-only changes which will have to be done on the Meta-only version of this PR.

The specific support added here is:

```
# tensor creation with `torch.zeros`
x = torch.zeros(4, dtype=torch.float8_e4m3fn)

# conversion to and from float8 dtypes
# conversion between float32 and float8 uses the kernels copied from fbgemm
# conversion between other dtypes and float8 goes through float32 as an intermediate step
x = torch.randn(4, dtype=torch.float)
x_float8 = x.to(torch.float8_e4m3fn)
x_float32 = x_float8.to(torch.float32)

# printing out the tensor for debugging
# note: the values are converted to float32 before being printed
print(x_float8)

# numerical correctness of the format (see test cases)
```

The following is not in this PR but is planned for the near future:

  • ensure all of the functionality above works on CUDA

Test plan:

python test/test_quantization.py TestFloat8

Fixes #ISSUE_NUMBER

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10

@vkuzo vkuzo requested a review from jerryzh168 as a code owner March 28, 2023 17:16
@pytorch-bot

pytorch-bot bot commented Mar 28, 2023

❌ 2 failures as of commit 9e6f507.

@pytorch-bot pytorch-bot bot added the release notes: quantization release notes category label Mar 28, 2023
@github-actions github-actions bot added the module: cpu CPU specific problem (e.g., perf, algorithm) label Mar 28, 2023
@mingfeima
Collaborator

@vkuzo, my understanding is that this feature is for research purposes, right? Are you aware of any vendors that have hardware acceleration for float8_e4m3fn?

@vkuzo
Contributor Author

vkuzo commented Mar 29, 2023

Are you aware of any vendors that has hardware acceleration for float8_e4m3fn ?

Yes, this is matching the fp8 format used by NVIDIA in their new Hopper architecture.

Note: for now this is a prototype and we have not committed to landing this PR yet - going through deliberations.

Collaborator


Small suggestion: can you put the same stuff into OpMathType.h? opmath is used more frequently.

@xiaolil1
Contributor

Hi, may I know the latest status of this PR? We'd like to try it with our FP8 implementation.

@vkuzo
Contributor Author

vkuzo commented May 25, 2023

Hi, may I know the latest status of this PR? We'd like to try it with our FP8 implementation.

hi @xiaolil1 , this PR is abandoned. We are thinking through the design of fp8 as an out of core dtype for now.

@vkuzo vkuzo closed this May 25, 2023
@jgong5
Collaborator

jgong5 commented May 26, 2023

Hi, may I know the latest status of this PR? We'd like to try it with our FP8 implementation.

hi @xiaolil1 , this PR is abandoned. We are thinking through the design of fp8 as an out of core dtype for now.

Hi @vkuzo, would that be a new design to support an "out of core" dtype? Currently, I guess dtype cannot be extended.


7 participants