
Add torch.float8_e5m2 and torch.float8_e4m3 data types#104242

Closed
australopitek wants to merge 16 commits into pytorch:main from australopitek:float8

Conversation

@australopitek
Contributor

@australopitek australopitek commented Jun 27, 2023

Proposal of two float8 variants - e5m2 and e4m3 - based on https://arxiv.org/pdf/2209.05433.pdf
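For orientation, e5m2 splits the byte into 1 sign, 5 exponent, and 2 mantissa bits, while e4m3 uses 1/4/3; the paper's e4m3 drops infinities and keeps a single NaN mantissa pattern. A minimal pure-Python decoder sketch (illustrative only, not the PR's C++ implementation; the `fn` flag marking the no-infinity variant is a name chosen here):

```python
def decode_fp8(byte, exp_bits, man_bits, fn=False):
    """Decode one fp8 byte; fn=True selects the no-infinity (e4m3-style) variant."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> man_bits) & ((1 << exp_bits) - 1)
    man = byte & ((1 << man_bits) - 1)
    bias = (1 << (exp_bits - 1)) - 1
    max_exp = (1 << exp_bits) - 1
    if not fn and exp == max_exp:
        # IEEE-style (e5m2): all-ones exponent encodes inf/NaN
        return sign * float("inf") if man == 0 else float("nan")
    if fn and exp == max_exp and man == (1 << man_bits) - 1:
        # the fn variant reclaims the inf encodings; only S.1111.111 is NaN
        return float("nan")
    if exp == 0:
        return sign * man * 2.0 ** (1 - bias - man_bits)  # subnormal
    return sign * (1 + man / (1 << man_bits)) * 2.0 ** (exp - bias)

# largest finite values from the paper: 448 for e4m3, 57344 for e5m2
print(decode_fp8(0x7E, 4, 3, fn=True))  # 448.0
print(decode_fp8(0x7B, 5, 2))           # 57344.0
```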

Hide all Float8 operator implementations behind an `#if !defined(C10_MOBILE)` guard to keep the Android build size almost unchanged

TODO:

  • Refactor duplicated code
  • Cleanup unbalanced pragma pop in dtype utils
  • Add native implementation on the CUDA side

Co-authored-by: Nikita Shulga nshulga@meta.com

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @EikanWang

@pytorch-bot

pytorch-bot bot commented Jun 27, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/104242

Note: Links to docs will display an error until the docs builds have been completed.

✅ 1 Unrelated Failure

As of commit a4fa283:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: linalg_frontend release notes category label Jun 27, 2023
@github-actions github-actions bot added module: cpu CPU specific problem (e.g., perf, algorithm) NNC release notes: quantization release notes category labels Jun 27, 2023
@vadimkantorov
Contributor

tangentially related: #69364

Collaborator

@albanD left a comment

This looks amazing! Thanks a lot for sharing all this hard work to help!
I'll let @vkuzo pick it up from here!

Contributor

@vkuzo left a comment

this looks great and it's a lot more complete and battle tested than the previous time we tried this! I had a couple of initial questions inline.

@australopitek , as @albanD mentioned we finally agreed to add these to core. We owe you a more detailed review next week, but at a high level, would you be up for getting this PR to green CI / ready to land, or do you prefer that we take it as is and finish on our side?

Contributor

is this intended to be used with scale? The current autocast design is stateless, so scaled float8 doesn't cleanly fit in. If this is intended to be used without scale, I'd love to learn how one might achieve competitive accuracy.

Contributor Author

You're right, this shouldn't be there. We were exploring possibilities to use fp8 with autocast, but finally decided to use custom solutions.
In fact, I see that AT_FORAUTOCAST_SCALAR_TYPES is not used anywhere; the only usages were removed quite a long time ago. I will clean this up during further work.

Contributor

just curious, are you folks using these unscaled arithmetic operations on float8 for real use cases, or is this just for completeness?

Contributor Author

This is just for completeness and simple tests. We use scaled float8 operations.

Contributor

nice! I didn't read line by line yet, but high level question - what's your level of confidence on the correctness of these functions? Have these been matched up with NVIDIA's casts or https://github.com/pytorch/FBGEMM/blob/277677039bae25b2570a73013b03bfaa9d2a523e/fbgemm_gpu/include/fbgemm_gpu/quantize_ops_utils.h ?

Contributor Author

It was quite extensively tested against Gaudi2, which supports these variants natively. However, in the next commits I'll add comparative tests for all bit configurations against some reference implementation.
Can you guarantee that the fbgemm implementation is fully aligned with the specification? Also, do you know if there is an open fp8_e5m2 reference implementation I could use for testing?

Contributor

the fbgemm implementation has been used extensively at Meta but it hasn't been matched up against real hardware afaik.

Your test cases look good, I'd say if future changes on numerics are needed that can be done in future PRs and not block this one. LGTM.

@australopitek
Contributor Author

@vkuzo ,
I'll be happy to bring this PR to completion. I'll add more tests and fix the current CI issues, probably mainly alignment with coding guidelines/linters, since I didn't focus on them so far.
I'm waiting for your full review.

Contributor

@vkuzo left a comment

Overall this looks great to me and this fits into how we want to extend the ecosystem around float8. Thanks @australopitek for getting this PR to a really good state! I only have a couple of nit comments:

  1. can we match MLIR (https://mlir.llvm.org/doxygen/classmlir_1_1FloatType.html) and XLA and go with Float8_E4M3FN instead of Float8_E4M3?
  2. can we also get a review from @ezyang or @albanD on the framework parts of the PR?

Using your current implementation of the casts as an initial reference LGTM, we can adjust as needed in future PRs.

@australopitek
Contributor Author

@vkuzo ,
sure, I'll adjust the naming to match MLIR. I'll try to push all changes and corrections this week.

Contributor

We probably will need to replace this with a primitive at some point as this is highly unlikely to be optimal.

Contributor Author

Changed to a helper method:

```cpp
inline C10_HOST_DEVICE bool Float8_e4m3fn::isnan() const {
  return (x & 0b01111111) == 0b01111111;
}
```
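The same bit test can be mirrored in a few lines of Python to sanity-check the pattern (NaN iff all seven non-sign bits are set; the function name here is just for illustration):

```python
def e4m3fn_isnan(byte):
    # NaN iff all non-sign bits are set, matching the C++ helper above
    return (byte & 0b01111111) == 0b01111111

print(e4m3fn_isnan(0x7F), e4m3fn_isnan(0xFF), e4m3fn_isnan(0x7E))  # True True False
```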

Contributor

b8/h8 are funny names to call these lol.

Contributor Author

Are they funny enough to need changing, or can I leave them as they are?

Contributor

This says that when you combine Float8_e5m2 and Float8_e4m3, it promotes to Float8_e4m3. This seems suspicious to me, as it's inconsistent with float16 and bfloat16 promotion, which promotes to float32 per the RFC at #43049 Additionally, JAX currently doesn't support promotion between these two types. It would be safer to disallow promotion to start and let folks figure out what they want to do later.
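The safer option suggested here (disallow implicit promotion between the two fp8 formats rather than silently picking a winner) can be sketched as a tiny lookup table, pure Python with hypothetical names:

```python
# hypothetical promotion-table sketch: same-type pairs promote to themselves,
# float16 x bfloat16 promotes to float32 (per RFC #43049), and the two fp8
# formats have no entry, so mixing them is an error
PROMOTE = {
    ("float16", "bfloat16"): "float32",
    ("float8_e5m2", "float8_e5m2"): "float8_e5m2",
    ("float8_e4m3fn", "float8_e4m3fn"): "float8_e4m3fn",
}

def promote(a, b):
    key = (a, b) if (a, b) in PROMOTE else (b, a)
    if key not in PROMOTE:
        raise TypeError(f"no implicit promotion between {a} and {b}")
    return PROMOTE[key]

print(promote("bfloat16", "float16"))  # float32
```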

Contributor Author

Done.

Contributor

@vkuzo, folks pointed out to me on Twitter https://twitter.com/ezyang/status/1665399428809629696 that there are multiple e4m3 variants; GraphCore has its own separate NaN variant. In JAX, these variants are distinguished, see https://github.com/jax-ml/ml_dtypes#specifications-of-implemented-floating-point-formats for their naming convention. Since I don't think we're expecting people to actually use these dtype names directly, it seems... OK to keep inline with the LLVM/JAX naming conventions?

Contributor

I agree. I think this should be named Float8_e4m3fn to match other frameworks, since there is no single float8 standard yet; I commented on that yesterday (on the PR itself, not on the code).

Contributor Author

Done.

Contributor

I didn't carefully audit these

Collaborator

Looks good.
Checked with:

```python
import jax.numpy as np
vals = np.arange(0, 2**8, dtype=np.uint8).reshape(-1, 4)

print("e4m3 limits")
a = np.array([0x08, 0xFE, 0x7E, 0x20, 0x30, 0x7F, 0x01], dtype=np.uint8)
print(a.view(np.float8_e4m3fn))
print(vals.view(np.float8_e4m3fn))

print("e5m2 limits")
a = np.array([0x4, 0x7B, 0xFB, 0x34, 0x38, 0x7C, 0x01], dtype=np.uint8)
print(a.view(np.float8_e5m2))
print(vals.view(np.float8_e5m2))
```

Contributor

I just want to point out that this is completely unnecessary; as uint8 is byte sized you can just read it directly out of the pointer, no need to memcpy lol

Contributor Author

Removed.

Contributor

Don't forget to fix the missing EOL.

Contributor

I did only a cursory inspection of this file

Collaborator

Looks good

Contributor

Code motion only, right?

Contributor Author

Yes.

Collaborator

@albanD left a comment

Looks great. A few extra comments but Ed pointed out the main things.

Collaborator

return false; haha

Contributor Author

Nice catch, done :)

Collaborator

Looks good


Collaborator

You can use parametrize from torch.testing._internal.common_utils

Contributor Author

Done.

Collaborator

We don't directly use pytest in tests, but unittest.
You can take https://github.com/pytorch/pytorch/blob/a2049dbf8f921c925b70bd3714c1f5dbcd8200ed/test/quantization/core/experimental/test_linear.py as an example for the class creation and the __main__ clause.
Also, all plain asserts can be replaced with self.assert*

Contributor Author

Done.

Collaborator

Not sure what this is testing? Aren't 0s supposed to be the same?
Also once you use our test framework, you can do self.assertEqual(x, x8) !

Contributor Author

It's just a first sanity test. I can remove it if it's too basic.

Contributor

Does anyone remember what the original idea behind ALL_TYPES_AND_COMPLEX_AND was?
But perhaps it is time to add ABSOLUTELY_ALL_TYPES and ALL_TYPES_AND_ALL_FLOATS? (I can't imagine anyone wanting to add support for F8_E5M2 but not for F8_E4M3.)

Contributor

to clarify, this is not blocking the PR - we can figure out how to clean up these abstractions independently

@australopitek
Contributor Author

@vkuzo ,
I have pushed the next commit with post-review changes and linter corrections. A few CI jobs still fail with two kinds of errors. The first is https://github.com/pytorch/pytorch/actions/runs/5508059225/jobs/10038844710?pr=104242, which I have a solution for; the second is https://github.com/pytorch/pytorch/actions/runs/5508059225/jobs/10038847090?pr=104242, which I think is not related to my change.

Since I'm not very fluent with GitHub, is there a way to run these failing jobs locally, without pushing yet another commit?

@malfet
Contributor

malfet commented Jul 12, 2023

@australopitek you can pull the Docker image from ghcr.io/pytorch/ci-image (most likely you'll need the ASAN container) and run .ci/pytorch/build.sh inside it.

Also, one can simply press the re-run button to see what is going on.

It looks like all clang-based builds are failing...

@vkuzo
Contributor

vkuzo commented Jul 13, 2023

@australopitek, this looks great! A Meta employee (myself or @malfet can help) will have to land this since there are Meta-only changes which are needed (unrelated to any discussion in this PR, just some internal code which needs to sync with all existing dtypes in PyTorch). Please let us know what a good handoff point would be, CI is mostly green and we are happy to help with the remaining issues if you'd like - let us know!

@australopitek
Contributor Author

@vkuzo , @malfet , I'd be glad if you helped with these last bits. From the CI logs I see there are mainly two kinds of issues left, but I'm off for the next two weeks due to holidays. I'll have limited access to this account, but will be able to log in occasionally.
How/when can we do the handoff?

@malfet
Contributor

malfet commented Jul 14, 2023

@australopitek I can probably take over the PR, or open a new one and keep you as the author.

@PaliC
Contributor

PaliC commented Jul 20, 2023

@pytorchbot revert -m "breaks lint (run lintrunner and remerge)" -c ignoredsignal

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot
Collaborator

@australopitek your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Jul 20, 2023
…)"

This reverts commit a980413.

Reverted #104242 on behalf of https://github.com/PaliC due to breaks lint (run lintrunner and remerge) ([comment](#104242 (comment)))
@PaliC PaliC reopened this Jul 20, 2023
@PaliC
Copy link
Copy Markdown
Contributor

PaliC commented Jul 20, 2023

@pytorchbot merge -f "already landed internally"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

jerryzh168 pushed a commit to jerryzh168/pytorch that referenced this pull request Jul 20, 2023
…ch#104242)"

This reverts commit a980413.

Reverted pytorch#104242 on behalf of https://github.com/PaliC due to breaks lint (run lintrunner and remerge) ([comment](pytorch#104242 (comment)))
jerryzh168 pushed a commit to jerryzh168/pytorch that referenced this pull request Jul 20, 2023
Proposal of two float8 variants - e5m2 and e4m3 - based on https://arxiv.org/pdf/2209.05433.pdf

Hide all Float8 operator implementations behind `#if !defined(C10_MOBILE)` guard to keep Android build size almost unchanged

TODO:
 - Refactor duplicated code
 - Cleanup unbalanced pragma pop in dtype utils
 - Add native implementation on the CUDA side

Co-authored-by: Nikita Shulga <nshulga@meta.com>

Pull Request resolved: pytorch#104242
Approved by: https://github.com/albanD
@albanD albanD mentioned this pull request Jul 20, 2023
pytorchmergebot pushed a commit that referenced this pull request Jul 20, 2023
Forward fix of the lint issues introduced by #104242
We are forward fixing as this PR contains Meta internal changes that would be tricky to revert smoothly.

Pull Request resolved: #105675
Approved by: https://github.com/jerryzh168, https://github.com/albanD, https://github.com/atalman
vkuzo added a commit to meta-pytorch/float8_experimental that referenced this pull request Jul 21, 2023
Summary:

Now that pytorch/pytorch#104242 landed, we can
stop emulation - this simplifies the code quite a bit.

Test Plan:

```
python float8_playground/test.py
```


```cpp
if (f_bits >= fp8_max) {
  // NaN - all exponent and mantissa bits set to 1
  result = f_bits > fp32_inf ? UINT8_C(0x7F) : UINT8_C(0x7C);
```
Contributor

Can I ask if PyTorch is only implementing the non-saturation mode for casts, as specified here?

The current default behavior for ONNX is to saturate: https://github.com/onnx/onnx/blob/main/docs/Operators.md#attributes-7.

Contributor Author

You're right, currently it implements only the non-saturation mode and, as a consequence, is not aligned with the float8 implementation from FBGEMM: https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/include/fbgemm_gpu/quantize_ops_utils.h

@vkuzo, @albanD, was it intended to have saturation enabled by default when casting to float8 dtypes? If so, it needs to be added, because currently values out of the float8 range are cast to inf/nan. I can do it next week after vacation.

Also, should it be configurable via a parameter of the .to() operator or some other way?

Contributor

Thanks for the comment! During PR review we punted these questions to future PRs just to get the initial dtype in, I guess now it's the time to discuss them :)

I would make an educated guess that saturating by default would make sense if people are doing delayed scaling, but let me think on that / talk to some of the research teams using fp8.

For if/how to make this configurable...let me also ask around on what would make sense.

Collaborator

In terms of consistency, when given integer inputs, we raise an error for too large numbers.
For floating point inputs, all our dtypes will return inf.

So I think inf is the right call here.

I think that `a_half = a.clamp(max=torch.finfo(torch.half).max).to(dtype=torch.half)` is a good way to get the behavior you want if needed.
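The two overflow policies under discussion can be contrasted in a small sketch (pure Python, using the e5m2 maximum; a real cast would also round in-range values, which is omitted here):

```python
import math

FP8_E5M2_MAX = 57344.0  # largest finite e5m2 value

def handle_overflow(x, saturate):
    # non-saturating: out-of-range magnitudes become +/-inf (IEEE-style)
    # saturating: clamp to the largest finite value (the FBGEMM / ONNX default)
    if abs(x) <= FP8_E5M2_MAX:
        return x
    return math.copysign(FP8_E5M2_MAX if saturate else math.inf, x)

print(handle_overflow(1e6, saturate=False))  # inf
print(handle_overflow(1e6, saturate=True))   # 57344.0
```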

Contributor

@albanD , there is some nuance here worth thinking about. In https://arxiv.org/pdf/2209.05433.pdf which these dtypes are expected to match, in Section 2, the authors say "Details of the heuristics to select the scaling factors are beyond the scope of this paper, but the general idea is to choose a scaling factor such that the maximum magnitude in the tensor becomes close to the maximum representable magnitude in the corresponding format. Values that overflow are then saturated to the maximum representable value." Specifically, if someone is using delayed scaling and there is an outlier, saturation is probably practically what they want.

NVIDIA also provides an enum for saturation behavior for fp8 (https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH__FP8__MISC.html), which isn't there for wider floating point dtypes.

IMO a reasonable default behavior here should be what is practical for fp8 training with delayed scaling, and we'll have to work out the inconsistencies of how this is different with what we have for wider fp dtypes.

Contributor Author

What do you mean by private API? Will it be exposed to the user on the python level? Or only on the level of hardware provider's PyTorch plugin?

Contributor

It will be exposed to Python with a name preceded by an underscore, but there will initially be no BC guarantees.

Contributor Author

Thank you for the information.

Collaborator

Hi @vkuzo @albanD @australopitek ,

May I know the current status of enabling fp8 delayed scaling and saturation in PyTorch?

For fp8 delayed scaling, the scale is per-tensor. So, is the implementation similar to the quantized operators? It cannot just use CPU dispatch, right? Because the scaling factor needs to be stored in the Tensor. Please correct me if anything is wrong.

And you mentioned fp8 training with delayed scaling; what about fp8 inference? In https://github.com/NVIDIA/TransformerEngine, fp8 inference also uses delayed scaling, as training does. Will you take this into consideration when enabling fp8?

Thanks so much! Hope to hear from you guys!

cc @mingfeima

Contributor

Hi @yanbing-j , please feel free to message @gchanan on PyTorch slack for context on your question. I am on parental leave until early December but Greg can help on this while I'm out.

@malfet
Contributor

malfet commented Jul 26, 2023

@pytorchbot revert -m "breaks lint (run lintrunner and remerge)" -c ignoredsignal

Not sure I understand why you've added ignoredsignal category here, as lint on PR CI was green. Perhaps landrace would be a more appropriate category.

```cpp
 * Convert a 32-bit floating-point number in IEEE single-precision format to a
 * 8-bit floating-point number in fp8 E4M3FN format, in bit representation.
 */
inline C10_HOST_DEVICE uint8_t fp8e4m3fn_from_fp32_value(float f) {
```
Collaborator

yanbing-j commented Aug 22, 2023

Hi @australopitek @vkuzo , may I know where the implementation of the conversion between fp8 and fp32 is from? Are there any reference materials? Thanks!

Contributor Author

@yanbing-j it was based on the float16 datatype implementation https://github.com/pytorch/pytorch/blob/main/c10/util/Half.h with alignment to the float8 bit configuration.
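Since e5m2 uses the same 5-bit exponent and bias as float16, the Half.h lineage can be illustrated by going through the machine's float16 conversion and then rounding away 8 mantissa bits. This is only a sketch: double rounding through float16 can differ from a direct fp32 cast in rare cases, and values beyond float16 range would need separate handling.

```python
import struct

def fp32_to_fp8e5m2_bits(f):
    # reuse the hardware float16 conversion (same exponent layout as e5m2),
    # then drop 8 mantissa bits with round-to-nearest-even
    h = struct.unpack("<H", struct.pack("<e", f))[0]
    h += 0x7F + ((h >> 8) & 1)  # RNE: add half-ulp, break ties to even
    return (h >> 8) & 0xFF

print(hex(fp32_to_fp8e5m2_bits(1.0)))      # 0x3c
print(hex(fp32_to_fp8e5m2_bits(57344.0)))  # 0x7b
```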

Contributor

@yanbing-j Also see the casting rules tabulated here: https://onnx.ai/onnx/technical/float8.html

Collaborator

@yanbing-j it was based on the float16 datatype implementation https://github.com/pytorch/pytorch/blob/main/c10/util/Half.h with alignment to float8 bits configuration.

What about fp8_max? Why do we need to support special values here?

Collaborator

@yanbing-j Also see the casting rules tabulated here: https://onnx.ai/onnx/technical/float8.html

Thanks. What is the value of FLT_MAX? Does this support the special values (Section 3.1 in https://arxiv.org/pdf/2209.05433.pdf)?

Contributor Author

@yanbing-j,
The part with fp8_max was probably inspired by the Eigen implementation https://gitlab.com/libeigen/eigen/-/blob/master/Eigen/src/Core/arch/Default/Half.h#L625. I don't remember now why I chose this one, but I guess it's analogous to PyTorch's implementation.

Special values are implemented according to the mentioned paper, and they're tested in https://github.com/pytorch/pytorch/blob/main/test/quantization/core/experimental/test_float8.py#L117 for both dtypes.

Collaborator

@australopitek Thank you so much. It really helps.

pytorchmergebot pushed a commit that referenced this pull request Nov 15, 2023
This PR relates to the feature in [this feature submission](https://docs.google.com/document/d/1pF2T1xz54IPg1jG7FhykbrpbcJZVelQw0v8vBaoLkfs/edit). It has been based on #104242 which adds similar float8 types.

These new types added in this PR are described in the paper at https://arxiv.org/abs/2206.02915. A brief description and comparison of the types with other float8 types can be also found in the [OpenXLA RFC](https://github.com/openxla/stablehlo/blob/main/rfcs/20230321-fp8_fnuz.md).

Pull Request resolved: #107586
Approved by: https://github.com/seemethere, https://github.com/malfet

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged module: cpu CPU specific problem (e.g., perf, algorithm) NNC open source release notes: linalg_frontend release notes category release notes: quantization release notes category Reverted
