
bf16 support for per tensor backward #165362

Closed
liangel-02 wants to merge 20 commits into main from gh/liangel-02/2/head

Conversation

@liangel-02
Contributor

@liangel-02 liangel-02 commented Oct 13, 2025

Adding bf16 support for the backward pass of `torch._fake_quantize_learnable_per_tensor_affine()`.

Note that for testing, we modified the seed to avoid increasing tolerance in cases where the difference between Python and C++ downcasting causes tensor mismatches (e.g. 27.87704 vs 27.8408 before downcasting, 27.7500 vs 27.8750 after downcasting, for the Python vs C++ op).
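The kind of mismatch described above is easy to reproduce with a stdlib-only sketch of bfloat16 rounding (an illustration of the numerics, not code from this PR): bf16 keeps only 8 mantissa bits, so near 28 the representable values sit on a 0.125-wide grid, and two fp32 results that differ by a few hundredths can collapse onto the same or adjacent bf16 values.

```python
import struct

def to_bf16(x: float) -> float:
    # Round a Python float (fp32 precision) to the nearest bfloat16 value
    # using round-to-nearest-even on the low 16 bits of the fp32 encoding.
    # NaN/inf handling is omitted for brevity.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (bits >> 16) & 1                      # tie-breaking bit
    bits = (bits + 0x7FFF + lsb) & 0xFFFFFFFF   # add rounding bias
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

# Both fp32 intermediates from the PR description land on the same bf16 grid point:
print(to_bf16(27.87704))  # 27.875
print(to_bf16(27.8408))   # 27.875
# ...while a slightly smaller value snaps to the adjacent grid point, 0.125 away:
print(to_bf16(27.8))      # 27.75
```

This is why a seemingly tiny divergence before the final downcast can turn into a full 0.125 step (here, 27.75 vs 27.875) after it, forcing either a larger tolerance or a different seed in the test.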

Stack from ghstack (oldest at bottom):

Follow up to #165098 - adding bf16 support for the backward pass. To avoid BC breaking changes/losing precision, we upcast the parameters to fp32 after the op gets called, and downcast the gradients to bf16 before returning.

For testing, we upcast to fp32 before calling the reference function.
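The upcast/downcast pattern can be sketched as follows. This is a hypothetical Python illustration of the approach, not the actual ATen kernel; `backward_with_upcast` and `compute_grads` are made-up names standing in for the op's backward entry point and its fp32 gradient math.

```python
import torch

def backward_with_upcast(grad_out, x, scale, zero_point, compute_grads):
    """Sketch: run the backward math in fp32, return grads in the input dtype."""
    orig_dtype = x.dtype
    if orig_dtype == torch.bfloat16:
        # Upcast everything to fp32 so the gradient math keeps full precision.
        grad_out, x = grad_out.float(), x.float()
        scale, zero_point = scale.float(), zero_point.float()
    dX, dScale, dZeroPoint = compute_grads(grad_out, x, scale, zero_point)
    if orig_dtype == torch.bfloat16:
        # Downcast the gradients back to bf16 before returning, so the
        # caller-visible dtypes match the inputs (no BC break).
        dX = dX.to(orig_dtype)
        dScale = dScale.to(orig_dtype)
        dZeroPoint = dZeroPoint.to(orig_dtype)
    return dX, dScale, dZeroPoint
```

Under this scheme the precision loss happens exactly once, at the final downcast, which is also where the Python-vs-C++ casting differences mentioned above show up in tests.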




[ghstack-poisoned]
Follow up to #165098 - adding bf16 support for the backward pass. To avoid BC breaking changes/losing precision, we upcast the parameters to fp32 after the op gets called, and downcast the gradients to bf16 before returning.

For testing, we upcast to fp32 before calling the reference function. We increase the tolerance to 1e-2 for bf16 inputs because of a difference in casting behavior between Python's `x.to(torch.bfloat16)` and C++'s `x.to(at::kBFloat16)` (after comparing intermediate tensors, we found that the numerics diverge only after the final cast). We don't explicitly cast in the C++ op but rather let autograd/the optimizer handle it.

[ghstack-poisoned]
@pytorch-bot

pytorch-bot bot commented Oct 13, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/165362

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 6d643da with merge base fbe0d20:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the `release notes: quantization` label Oct 13, 2025
liangel-02 added a commit that referenced this pull request Oct 13, 2025
ghstack-source-id: d03e828
Pull Request resolved: #165362
liangel-02 added a commit that referenced this pull request Oct 13, 2025
ghstack-source-id: fb073e2
Pull Request resolved: #165362
liangel-02 added a commit that referenced this pull request Oct 13, 2025
ghstack-source-id: 7a6f0c0
Pull Request resolved: #165362
@liangel-02 liangel-02 marked this pull request as draft October 13, 2025 21:10
liangel-02 added a commit that referenced this pull request Oct 14, 2025
ghstack-source-id: 44b30c2
Pull Request resolved: #165362
@liangel-02 liangel-02 marked this pull request as ready for review October 14, 2025 14:31
@liangel-02 liangel-02 requested a review from andrewor14 October 14, 2025 14:32
liangel-02 added a commit that referenced this pull request Oct 14, 2025
ghstack-source-id: b9eb103
Pull Request resolved: #165362
@liangel-02 liangel-02 changed the base branch from gh/liangel-02/2/base to main October 14, 2025 14:33
@liangel-02 liangel-02 changed the base branch from main to gh/liangel-02/2/base October 14, 2025 14:36
liangel-02 added a commit that referenced this pull request Oct 14, 2025
ghstack-source-id: 9776242
Pull Request resolved: #165362
*/
float scale_val = scale[0].item<float>();

bool is_bfloat16 = (X.scalar_type() == at::kBFloat16);
Collaborator

So we shouldn't cast fp16

Contributor Author

we enabled bf16 support for per_tensor alongside per_channel in PR #165325, so if we want to enable fp16, we can do it in a separate PR for both of these ops?

auto dScale = dScale_vec.sum().unsqueeze(0).to(scale_.device());
auto dZeroPoint = dZeroPoint_vec.sum().unsqueeze(0).to(zero_point_.device());

return std::make_tuple(dX, dScale, dZeroPoint);
Collaborator

All of these should be `std::move`d into `make_tuple`, btw

Contributor Author

should we make this change in a follow up PR since this PR doesn't actually touch this code?

@liangel-02 liangel-02 changed the base branch from gh/liangel-02/2/base to main October 14, 2025 20:05
@meta-codesync

meta-codesync bot commented Oct 14, 2025

@liangel-02 has imported this pull request. If you are a Meta employee, you can view this in D84639869.

@liangel-02
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the `ciflow/trunk` (Trigger trunk jobs on your pull request) label Oct 16, 2025
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
Adding bf16 for the backward pass of `torch._fake_quantize_learnable_per_tensor_affine()`.

Note that for testing, we modified the seed to avoid increasing tolerance due to cases where difference in Python vs CPP downcasting causes tensor mismatches. (e.g. 27.87704 vs  27.8408 before downcasting, 27.7500 vs 27.8750 after downcasting for Python vs CPP op)

Pull Request resolved: pytorch#165362
Approved by: https://github.com/andrewor14
zhudada0120 pushed a commit to zhudada0120/pytorch that referenced this pull request Oct 22, 2025
@github-actions github-actions bot deleted the gh/liangel-02/2/head branch November 16, 2025 02:20

Labels

ciflow/trunk: Trigger trunk jobs on your pull request
Merged
release notes: quantization (release notes category)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants