Add torch.float8_e5m2 and torch.float8_e4m3 data types #104242
australopitek wants to merge 16 commits into pytorch:main
Conversation
🔗 Helpful Links: 🧪 see artifacts and rendered test results at hud.pytorch.org/pr/104242
Note: links to docs will display an error until the docs builds have been completed.
✅ 1 Unrelated Failure. As of commit a4fa283: UNSTABLE. The following job failed but was likely due to flakiness present on trunk and has been marked as unstable.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
tangentially related: #69364
vkuzo
left a comment
this looks great and it's a lot more complete and battle tested than the previous time we tried this! I had a couple of initial questions inline.
@australopitek , as @albanD mentioned we finally agreed to add these to core. We owe you a more detailed review next week, but at a high level, would you be up for getting this PR to green CI / ready to land, or do you prefer that we take it as is and finish on our side?
c10/core/ScalarType.h
is this intended to be used with scale? The current autocast design is stateless, so scaled float8 doesn't cleanly fit in. If this is intended to be used without scale, I'd love to learn on how one might achieve competitive accuracy.
You're right, this shouldn't be there. We were exploring possibilities to use fp8 with autocast, but finally decided to use custom solutions.
In fact, I see that AT_FORAUTOCAST_SCALAR_TYPES is not used anywhere; the only usages were removed quite a long time ago. I will clean this up during further work.
c10/util/Float8_e4m3-inl.h
just curious, are you folks using these unscaled arithmetic operations on float8 for real use cases, or is this just for completeness?
This is just for completeness and simple tests. We use scaled float8 operations.
c10/util/Float8_e4m3.h
nice! I didn't read line by line yet, but high level question - what's your level of confidence on the correctness of these functions? Have these been matched up with NVIDIA's casts or https://github.com/pytorch/FBGEMM/blob/277677039bae25b2570a73013b03bfaa9d2a523e/fbgemm_gpu/include/fbgemm_gpu/quantize_ops_utils.h ?
It was quite extensively tested against Gaudi2, which supports these variants natively. However, in the next commits I'll add comparative tests for all bit configurations against some reference implementation.
Can you guarantee that the fbgemm implementation is fully aligned with the specification? Also, do you know if there is some open fp8_e5m2 reference implementation I could use for testing?
the fbgemm implementation has been used extensively at Meta but it hasn't been matched up against real hardware afaik.
Your test cases look good, I'd say if future changes on numerics are needed that can be done in future PRs and not block this one. LGTM.
@vkuzo ,
vkuzo
left a comment
Overall this looks great to me and this fits into how we want to extend the ecosystem around float8. Thanks @australopitek for getting this PR to a really good state! I only have a couple of nit comments:
- can we match MLIR (https://mlir.llvm.org/doxygen/classmlir_1_1FloatType.html) and XLA and go with Float8_E4M3FN instead of Float8_E4M3?
- can we also get a review from @ezyang or @albanD on the framework parts of the PR?
Using your current implementation of the casts as an initial reference LGTM, we can adjust as needed in future PRs.
@vkuzo ,
aten/src/ATen/NumericUtils.h
We probably will need to replace this with a primitive at some point as this is highly unlikely to be optimal.
Changed to a helper method:

```cpp
inline C10_HOST_DEVICE bool Float8_e4m3fn::isnan() const {
  return (x & 0b01111111) == 0b01111111;
}
```
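The bit trick above can be sanity-checked outside C++; here is a minimal Python mirror (the function name `e4m3fn_isnan` is illustrative, not part of the PR):

```python
# In e4m3fn the only NaN encoding per sign is exponent and mantissa all ones,
# so masking off the sign bit and comparing against 0b1111111 suffices.
def e4m3fn_isnan(bits: int) -> bool:
    """True if the 8-bit e4m3fn pattern encodes NaN (S.1111.111)."""
    return (bits & 0b0111_1111) == 0b0111_1111

print(e4m3fn_isnan(0x7F))  # True: positive NaN
print(e4m3fn_isnan(0xFF))  # True: negative NaN
print(e4m3fn_isnan(0x7E))  # False: 448.0, the e4m3fn maximum
```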
c10/core/ScalarType.h
b8/h8 are funny names to call these lol.
Are they funny enough to need changing, or can I leave them?
c10/core/ScalarType.h
This says that when you combine Float8_e5m2 and Float8_e4m3, it promotes to Float8_e4m3. This seems suspicious to me, as it's inconsistent with float16 and bfloat16 promotion, which promotes to float32 per the RFC at #43049 Additionally, JAX currently doesn't support promotion between these two types. It would be safer to disallow promotion to start and let folks figure out what they want to do later.
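To illustrate the suggestion, here is a toy sketch of such a promotion table (the `promote` helper and the string names are illustrative, not PyTorch's actual implementation): mixing the two float8 variants raises, while float16 with bfloat16 promotes to float32 per the RFC.

```python
# Toy promotion policy: disallow cross-float8 promotion instead of silently
# picking one of the two 8-bit formats.
F8 = {"float8_e5m2", "float8_e4m3fn"}

def promote(a: str, b: str) -> str:
    if a == b:
        return a
    if a in F8 and b in F8:
        # safer to disallow and let users cast explicitly
        raise TypeError(f"no implicit promotion between {a} and {b}")
    if {a, b} == {"float16", "bfloat16"}:
        return "float32"  # consistent with the float16/bfloat16 rule
    raise NotImplementedError(f"promotion of {a} and {b} not modeled here")

print(promote("float16", "bfloat16"))  # float32
```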
c10/core/ScalarType.h
@vkuzo, folks pointed out to me on Twitter https://twitter.com/ezyang/status/1665399428809629696 that there are multiple e4m3 variants; GraphCore has its own separate NaN variant. In JAX, these variants are distinguished, see https://github.com/jax-ml/ml_dtypes#specifications-of-implemented-floating-point-formats for their naming convention. Since I don't think we're expecting people to actually use these dtype names directly, it seems... OK to keep inline with the LLVM/JAX naming conventions?
I agree, I think this should be named Float8_e4m3fn to match other frameworks and since there is no single float8 standard yet, I commented on that yesterday (on the PR itself, not on the code).
c10/util/Float8_e4m3-inl.h
I didn't carefully audit these
Looks good.
Checked with:

```python
import jax.numpy as np
vals = np.arange(0, 2**8, dtype=np.uint8).reshape(-1, 4)
print("e4m3 limits")
a = np.array([0x08, 0xFE, 0x7E, 0x20, 0x30, 0x7F, 0x01], dtype=np.uint8)
print(a.view(np.float8_e4m3fn))
print(vals.view(np.float8_e4m3fn))
print("e5m2 limits")
a = np.array([0x4, 0x7B, 0xFB, 0x34, 0x38, 0x7C, 0x01], dtype=np.uint8)
print(a.view(np.float8_e5m2))
print(vals.view(np.float8_e5m2))
```
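For readers without jax at hand, the same bit patterns can be decoded with a small dependency-free sketch (the `decode_fp8` helper is an assumption for illustration, not code from the PR): it treats e4m3fn's all-ones exponent as finite except for the single NaN pattern, and e5m2's all-ones exponent as IEEE-style inf/NaN.

```python
def decode_fp8(bits: int, exp_bits: int, man_bits: int, fn: bool) -> float:
    """Decode an 8-bit e<E>m<M> pattern; fn=True selects the 'fn' (finite) variant."""
    sign = -1.0 if bits & 0x80 else 1.0
    exp = (bits >> man_bits) & ((1 << exp_bits) - 1)
    man = bits & ((1 << man_bits) - 1)
    bias = (1 << (exp_bits - 1)) - 1
    max_exp = (1 << exp_bits) - 1
    if fn:
        # e4m3fn: only all-ones exponent AND mantissa is NaN; no infinity
        if exp == max_exp and man == (1 << man_bits) - 1:
            return float("nan")
    elif exp == max_exp:
        # e5m2 follows IEEE conventions: inf for zero mantissa, NaN otherwise
        return sign * float("inf") if man == 0 else float("nan")
    if exp == 0:  # subnormal numbers (and signed zero)
        return sign * man * 2.0 ** (1 - bias - man_bits)
    return sign * (1.0 + man / (1 << man_bits)) * 2.0 ** (exp - bias)

print(decode_fp8(0x7E, 4, 3, fn=True))   # 448.0, the e4m3fn maximum
print(decode_fp8(0x7B, 5, 2, fn=False))  # 57344.0, the e5m2 maximum
```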
torch/csrc/utils/byte_order.cpp
I just want to point out that this is completely unnecessary; as uint8 is byte sized you can just read it directly out of the pointer, no need to memcpy lol
don't forget to fix the missing eol
c10/util/Float8_e5m2-inl.h
I did only a cursory inspection of this file
c10/util/floating_point_utils.h
albanD
left a comment
Looks great. A few extra comments but Ed pointed out the main things.
aten/src/ATen/NumericUtils.h
Nice catch, done :)
You can use parametrize from torch.testing._internal.common_utils
We don't directly use pytest in tests but unittest.
You can take https://github.com/pytorch/pytorch/blob/a2049dbf8f921c925b70bd3714c1f5dbcd8200ed/test/quantization/core/experimental/test_linear.py as an example for the class creation and the __main__ clause.
Also, all plain asserts can be replaced with self.assert*.
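A minimal sketch of the pattern being described (the class and test names are illustrative, not the PR's actual test file):

```python
import unittest

class TestFloat8Basics(unittest.TestCase):
    def test_e5m2_max_bits(self):
        # 0x7B is the e5m2 maximum normal: (1 + 3/4) * 2**(30 - 15) == 57344.0
        bits = 0x7B
        exp, man = (bits >> 2) & 0x1F, bits & 0x3
        self.assertEqual((1 + man / 4) * 2.0 ** (exp - 15), 57344.0)

if __name__ == "__main__":
    unittest.main(exit=False)
```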
Not sure what this is testing? Aren't 0s supposed to be the same?
Also once you use our test framework, you can do self.assertEqual(x, x8) !
Just a first sanity test. I can remove it if it's too basic.
aten/src/ATen/Dispatch.h
Does anyone remember what the original idea behind ALL_TYPES_AND_COMPLEX_AND was?
But perhaps it is time to add ABSOLUTELY_ALL_TYPES and ALL_TYPES_AND_ALL_FLOATS? (As I can't imagine who would want to add support for F8_E5M2 but not for F8_E4M3.)
to clarify, this is not blocking the PR - we can figure out how to clean up these abstractions independently
@vkuzo, since I'm not very fluent with GitHub, is there a way to run these failing jobs locally, without pushing yet another commit?
@australopitek you can pull docker from ghcr.io/pytorch/ci-image (most likely you'll need the ASAN container) and run. Also, one can simply press the re-run button to see what is going on. It looks like all
@australopitek, this looks great! A Meta employee (myself or @malfet can help) will have to land this since there are Meta-only changes which are needed (unrelated to any discussion in this PR, just some internal code which needs to sync with all existing dtypes in PyTorch). Please let us know what a good handoff point would be. CI is mostly green and we are happy to help with the remaining issues if you'd like. Let us know!
@australopitek I can probably take over the PR or open a new one and keep you as an author.
@pytorchbot revert -m "breaks lint (run lintrunner and remerge)" -c ignoredsignal
@pytorchbot successfully started a revert job. Check the current status here.
@australopitek your PR has been successfully reverted.
…)" This reverts commit a980413. Reverted #104242 on behalf of https://github.com/PaliC due to breaks lint (run lintrunner and remerge) ([comment](#104242 (comment)))
@pytorchbot merge -f "already landed internally"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Proposal of two float8 variants - e5m2 and e4m3 - based on https://arxiv.org/pdf/2209.05433.pdf
Hide all Float8 operator implementations behind `#if !defined(C10_MOBILE)` guard to keep Android build size almost unchanged
TODO:
- Refactor duplicated code
- Clean up unbalanced pragma pop in dtype utils
- Add native implementation on the CUDA side
Co-authored-by: Nikita Shulga <nshulga@meta.com>
Pull Request resolved: pytorch#104242
Approved by: https://github.com/albanD
Forward fix of the lint issues introduced by #104242 We are forward fixing as this PR contains Meta internal changes that would be tricky to revert smoothly. Pull Request resolved: #105675 Approved by: https://github.com/jerryzh168, https://github.com/albanD, https://github.com/atalman
Summary: Now that pytorch/pytorch#104242 landed, we can stop emulation; this simplifies the code quite a bit.
Test Plan:
```
python float8_playground/test.py
```
```cpp
if (f_bits >= fp8_max) {
  // NaN - all exponent and mantissa bits set to 1
  result = f_bits > fp32_inf ? UINT8_C(0x7F) : UINT8_C(0x7C);
```
Can I ask if PyTorch is only implementing the non-saturation mode for casts, as specified here?
The current default behavior for ONNX is to saturate: https://github.com/onnx/onnx/blob/main/docs/Operators.md#attributes-7.
You're right, currently it implements only the non-saturation mode and, in consequence, is not aligned with the float8 implementation from FBGEMM https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/include/fbgemm_gpu/quantize_ops_utils.h
@vkuzo, @albanD, was it intended to have saturation enabled by default in casting to float8 dtypes? If so, it needs to be added, because currently values out of float8 range are cast to inf/nan. I can do it next week after vacation.
Also, should it be configurable via a parameter of the .to() operator or some other way?
Thanks for the comment! During PR review we punted these questions to future PRs just to get the initial dtype in, I guess now it's the time to discuss them :)
I would make an educated guess that saturating by default would make sense if people are doing delayed scaling, but let me think on that / talk to some of the research teams using fp8.
For if/how to make this configurable...let me also ask around on what would make sense.
In terms of consistency, when given integer inputs, we raise an error for too large numbers.
For floating point inputs, all our dtypes will return inf.
So I think inf is the right call here.
I think that `a_half = a.clamp(max=torch.finfo(torch.half).max).to(dtype=torch.half)` is a good way to get the behavior you want if needed.
@albanD , there is some nuance here worth thinking about. In https://arxiv.org/pdf/2209.05433.pdf which these dtypes are expected to match, in Section 2, the authors say "Details of the heuristics to select
the scaling factors are beyond the scope of this paper, but the general idea is to choose a scaling factor such that the
maximum magnitude in the tensor becomes close to the maximum representable magnitude in the corresponding format. Values that overflow are then saturated to the maximum representable value." Specifically, if someone is using delayed scaling and there is an outlier, saturation is probably practically what they want.
NVIDIA also provides an enum for saturation behavior for fp8 (https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH__FP8__MISC.html), which isn't there for wider floating point dtypes.
IMO a reasonable default behavior here should be what is practical for fp8 training with delayed scaling, and we'll have to work out the inconsistencies of how this is different with what we have for wider fp dtypes.
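To make the two policies concrete, here is a small sketch (the `cast_overflow` helper and its constants are illustrative, using the paper's e4m3fn maximum of 448.0 and e5m2 maximum of 57344.0; this is not PyTorch's cast code):

```python
import math

E4M3FN_MAX = 448.0   # e4m3fn has no infinity, so overflow is NaN or saturate
E5M2_MAX = 57344.0   # e5m2 keeps IEEE-style inf/NaN

def cast_overflow(x: float, fmt_max: float, saturate: bool, has_inf: bool) -> float:
    """Model only the overflow branch of a float -> fp8 cast."""
    if math.isnan(x):
        return float("nan")
    if abs(x) <= fmt_max:
        return x  # in-range values would be rounded by a real cast
    if saturate:
        return math.copysign(fmt_max, x)      # ONNX-style saturating default
    return math.copysign(float("inf"), x) if has_inf else float("nan")

print(cast_overflow(1000.0, E4M3FN_MAX, saturate=True, has_inf=False))   # 448.0
print(cast_overflow(1000.0, E4M3FN_MAX, saturate=False, has_inf=False))  # nan
print(cast_overflow(-1e9, E5M2_MAX, saturate=False, has_inf=True))       # -inf
```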
What do you mean by private API? Will it be exposed to the user on the python level? Or only on the level of hardware provider's PyTorch plugin?
It will be exposed to Python with a name preceded by an underscore, but there will initially be no BC guarantees.
Thank you for the information.
Hi @vkuzo @albanD @australopitek ,
May I know the current status of enabling fp8 delayed scaling and saturation in PyTorch?
For fp8 delayed scaling, it is a per-tensor one. So, is the implementation similar to the Quantized operators? It cannot just use the CPU dispatch, right? Because the scaling factor needs to be stored in the Tensor. Please correct me if anything is wrong.
And you mentioned fp8 training with delayed scaling; what about fp8 inference? In https://github.com/NVIDIA/TransformerEngine, fp8 inference also uses delayed scaling, like training. Will you take this into consideration for fp8 enabling?
Thanks so much! Hope to hear from you guys!
cc @mingfeima
Hi @yanbing-j , please feel free to message @gchanan on PyTorch slack for context on your question. I am on parental leave until early December but Greg can help on this while I'm out.
Not sure I understand why you've added
```cpp
 * Convert a 32-bit floating-point number in IEEE single-precision format to a
 * 8-bit floating-point number in fp8 E4M3FN format, in bit representation.
 */
inline C10_HOST_DEVICE uint8_t fp8e4m3fn_from_fp32_value(float f) {
```
Hi @australopitek @vkuzo , may I know where is the implementation of conversion between fp8 and fp32 from? Is there any reference materials? Thanks!
@yanbing-j it was based on the float16 datatype implementation https://github.com/pytorch/pytorch/blob/main/c10/util/Half.h with alignment to float8 bits configuration.
@yanbing-j Also see the casting rules tabulated here: https://onnx.ai/onnx/technical/float8.html
> @yanbing-j it was based on the float16 datatype implementation https://github.com/pytorch/pytorch/blob/main/c10/util/Half.h with alignment to float8 bits configuration.

What about fp8_max? Why do we need to support special values here?
> @yanbing-j Also see the casting rules tabulated here: https://onnx.ai/onnx/technical/float8.html

Thanks. What is the value of FLT_MAX? Does this support the special values (Section 3.1 in https://arxiv.org/pdf/2209.05433.pdf)?
@yanbing-j,
The part with fp8_max was probably inspired by the Eigen implementation https://gitlab.com/libeigen/eigen/-/blob/master/Eigen/src/Core/arch/Default/Half.h#L625. Now I don't remember why I chose this one, but I guess it's analogous to PyTorch's implementation.
Special values are implemented according to the mentioned paper and are tested in https://github.com/pytorch/pytorch/blob/main/test/quantization/core/experimental/test_float8.py#L117 for both dtypes.
@australopitek Thank you so much. It really helps.
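For context, the role of the fp8_max comparison in the quoted snippet can be sketched in a few lines of Python (the threshold value and helper names here are assumptions for illustration, not the PR's exact constants): any fp32 magnitude at or past the threshold maps to the e5m2 inf pattern 0x7C, unless the input was NaN, in which case a NaN pattern such as 0x7F is emitted.

```python
import struct
from typing import Optional

def fp32_bits(f: float) -> int:
    """Reinterpret a float as its 32-bit IEEE-754 pattern."""
    return struct.unpack("<I", struct.pack("<f", f))[0]

FP32_INF = 0x7F800000
# assumed threshold: first fp32 magnitude that no longer fits in e5m2 (max 57344.0)
FP8_MAX = fp32_bits(61440.0)

def e5m2_overflow_branch(f: float) -> Optional[int]:
    """Return the e5m2 special pattern for out-of-range inputs, else None."""
    f_bits = fp32_bits(f) & 0x7FFFFFFF  # drop the sign
    if f_bits >= FP8_MAX:
        # NaN inputs have f_bits above the inf pattern; mirror the quoted code
        return 0x7F if f_bits > FP32_INF else 0x7C
    return None

print(hex(e5m2_overflow_branch(1e6)))   # 0x7c: overflows to inf
print(e5m2_overflow_branch(1.0))        # None: in range
```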
This PR relates to the feature in [this feature submission](https://docs.google.com/document/d/1pF2T1xz54IPg1jG7FhykbrpbcJZVelQw0v8vBaoLkfs/edit). It has been based on #104242 which adds similar float8 types. These new types added in this PR are described in the paper at https://arxiv.org/abs/2206.02915. A brief description and comparison of the types with other float8 types can be also found in the [OpenXLA RFC](https://github.com/openxla/stablehlo/blob/main/rfcs/20230321-fp8_fnuz.md). Pull Request resolved: #107586 Approved by: https://github.com/seemethere, https://github.com/malfet
Proposal of two float8 variants - e5m2 and e4m3 - based on https://arxiv.org/pdf/2209.05433.pdf
Hide all Float8 operator implementations behind `#if !defined(C10_MOBILE)` guard to keep Android build size almost unchanged
TODO:
Co-authored-by: Nikita Shulga nshulga@meta.com
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @EikanWang