Conversation
Follow-up to #165098: adds bf16 support for the backward pass. To avoid BC-breaking changes and precision loss, we upcast the parameters to fp32 after the op is called, and downcast the gradients to bf16 before returning. For testing, we upcast to fp32 before calling the reference function.
Follow-up to #165098: adds bf16 support for the backward pass. To avoid BC-breaking changes and precision loss, we upcast the parameters to fp32 after the op is called, and downcast the gradients to bf16 before returning. For testing, we upcast to fp32 before calling the reference function. We increase the tolerance to 1e-2 for bf16 inputs because of a difference in casting behavior between Python's `x.to(torch.bfloat16)` and C++'s `x.to(at::kBFloat16)` (after comparing intermediate tensors, we found that the numerics diverge only after the final cast). We don't explicitly cast in the C++ op but rather let autograd/the optimizer handle it.
🔗 Helpful links: 🧪 see artifacts and rendered test results at hud.pytorch.org/pr/165362.
Note: links to docs will display an error until the docs builds have completed. ✅ You can merge normally! (1 unrelated failure.) As of commit 6d643da with merge base fbe0d20. FLAKY: the following job failed but was likely due to flakiness present on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```cpp
*/
float scale_val = scale[0].item<float>();

bool is_bfloat16 = (X.scalar_type() == at::kBFloat16);
```
So we shouldn't cast fp16?
We enabled bf16 support for per_tensor alongside per_channel in #165325, so if we want to enable fp16, we can do it in a separate PR for both of these ops?
```cpp
auto dScale = dScale_vec.sum().unsqueeze(0).to(scale_.device());
auto dZeroPoint = dZeroPoint_vec.sum().unsqueeze(0).to(zero_point_.device());

return std::make_tuple(dX, dScale, dZeroPoint);
```
All of these should be `std::move`d into `make_tuple`, btw.
Should we make this change in a follow-up PR, since this PR doesn't actually touch this code?
@liangel-02 has imported this pull request. If you are a Meta employee, you can view this in D84639869.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Adding bf16 support for the backward pass of `torch._fake_quantize_learnable_per_tensor_affine()`. Note that for testing, we modified the seed to avoid increasing the tolerance due to cases where the difference between Python and C++ downcasting causes tensor mismatches (e.g. 27.87704 vs 27.8408 before downcasting, 27.7500 vs 27.8750 after downcasting for the Python vs C++ op). Pull Request resolved: pytorch#165362 Approved by: https://github.com/andrewor14
Stack from ghstack (oldest at bottom):