
[Pytorch] Add option to CPU Blas GEMM to avoid output downcast #154012

Closed
cyrusd98 wants to merge 1 commit into pytorch:main from cyrusd98:export-D75023858

Conversation

@cyrusd98
Contributor

@cyrusd98 cyrusd98 commented May 21, 2025

Summary:
Dot product for a single output element consists of 3 steps (both input vectors have elements of type scalar_t):

  1. elementwise vector multiply (scalar_t x scalar_t -> opmath_t)
  2. vector reduction to a scalar value (opmath_t -> opmath_t)
  3. optional downcast if opmath_t != out_t

The current BLAS kernel performs steps 1 and 2 correctly, but in step 3 it always downcasts to scalar_t even when opmath_t == out_t (and then upcasts back to out_t), which results in precision loss. This diff fixes that precision loss in the BlasKernel.
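To see why the unconditional downcast loses precision, here is a NumPy sketch (not the actual ATen kernel; fp16 stands in for scalar_t and fp32 for both opmath_t and out_t):

```python
import numpy as np

# Sketch only: scalar_t = float16 (inputs), opmath_t = float32 (accumulation),
# out_t = float32 (output), as in fp16 attention with fp32 output on CPU.
a = np.array([2.5, 2.5, 2.5, 2.5, 0.001], dtype=np.float16)
b = np.ones(5, dtype=np.float16)

# Steps 1-2: elementwise multiply and reduce in opmath_t.
acc = np.float32(0.0)
for x, y in zip(a, b):
    acc += np.float32(x) * np.float32(y)

# Old behavior: step 3 always downcast to scalar_t, then upcast to out_t.
lossy = np.float32(np.float16(acc))
# Fixed behavior: opmath_t == out_t here, so the downcast is skipped.
exact = acc

# The fp16 round trip rounds the small 0.001 contribution away (fp16 spacing
# near 10 is 2^-7), so lossy collapses to 10.0 while exact keeps the tail.
```

The gap grows with reduction length: the longer the dot product, the more low-order bits the fp32 accumulator holds that a single fp16 round trip discards.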

Test Plan: Attention CI passes

Differential Revision: D75023858

topic: not user facing

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jerryzh168

@pytorch-bot

pytorch-bot bot commented May 21, 2025

This appears to be a diff that was exported from phabricator, but the PR author does not have sufficient permissions to run CI. @cyrusd98, please do step 2 of internal wiki to get write access so you do not need to get CI approvals in the future. If you think this is a mistake, please contact the Pytorch Dev Infra team.

@linux-foundation-easycla

linux-foundation-easycla bot commented May 21, 2025

CLA Signed

The committers listed above are authorized under a signed CLA.

  • ✅ login: cyrusd98 / name: Cyrus Daruwala (3199dc9)

@pytorch-bot pytorch-bot bot added the module: cpu CPU specific problem (e.g., perf, algorithm) label May 21, 2025
@pytorch-bot

pytorch-bot bot commented May 21, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/154012

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 1 Pending

As of commit 3199dc9 with merge base 86a1603:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D75023858

@drisspg drisspg added intel This tag is for PR from Intel module: numerical-stability Problems related to numerical stability of operations release notes: nn release notes category module: sdpa All things related to torch.nn.functional.scaled_dot_product_attention labels May 21, 2025
@drisspg drisspg requested review from CaoE, Valentine233 and aditew01 May 21, 2025 19:08
Collaborator

@Valentine233 Valentine233 left a comment


LGTM, thanks for the fix.

@Valentine233 Valentine233 requested a review from mingfeima May 22, 2025 01:41
Collaborator

@aditew01 aditew01 left a comment


Thanks for the PR. LGTM!

@CaoE
Collaborator

CaoE commented May 23, 2025

Thanks for the fix.

@cyrusd98 cyrusd98 force-pushed the export-D75023858 branch from d57c495 to 3199dc9 on May 23, 2025 22:09
@cyrusd98
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label May 24, 2025
@pytorchmergebot
Collaborator

Merge failed

Reason: Approvers from one of the following sets are needed:

  • superuser (pytorch/metamates)
  • Core Reviewers (mruberry, lezcano, Skylion007, ngimel, peterbell10, ...)
  • Core Maintainers (soumith, gchanan, ezyang, dzhulgakov, malfet, ...)
Raised by workflow job

Failing merge rule: Core Maintainers

malfet added a commit that referenced this pull request Sep 2, 2025
Followup after #154012

Since the introduction of `gemm_no_downcast_stub` it's no longer
necessary to allocate a temporary array and then manually implement the
`beta` logic in the codebase

ghstack-source-id: 47d17e8
Pull Request resolved: #162001
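The simplification that follow-up describes can be sketched in NumPy (hypothetical helper names, not the actual CPUBlas code; fp16 stands in for scalar_t and fp32 for the opmath/output type):

```python
import numpy as np

def gemm_via_temp_buffer(a, b, c, alpha, beta):
    # Old caller-side pattern (sketch): the kernel could only produce a
    # scalar_t (fp16) result, so the caller ran GEMM into a temporary
    # buffer and applied the beta update manually in fp32.
    tmp = (alpha * (a.astype(np.float32) @ b.astype(np.float32))).astype(np.float16)
    return beta * c + tmp.astype(np.float32)

def gemm_no_downcast(a, b, c, alpha, beta):
    # New pattern (sketch): the kernel keeps the fp32 accumulator and
    # applies beta itself; no temporary array, no manual beta step, and
    # the product term never routes through fp16.
    return beta * c + alpha * (a.astype(np.float32) @ b.astype(np.float32))
```

`gemm_via_temp_buffer` and `gemm_no_downcast` are illustrative names; the point is only that the second form applies beta inside the kernel and never narrows the fp32 accumulator to fp16.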
pytorchmergebot pushed a commit that referenced this pull request Sep 3, 2025
Followup after #154012

Fixes CPU part of #160841

Pull Request resolved: #161999
Approved by: https://github.com/drisspg
pytorchmergebot pushed a commit that referenced this pull request Sep 3, 2025
Followup after #154012

Since the introduction of `gemm_no_downcast_stub` it's no longer necessary to allocate temporary array and then manually implement the `beta` logic in the codebase
Pull Request resolved: #162001
Approved by: https://github.com/drisspg
ghstack dependencies: #161999
pytorchmergebot pushed a commit that referenced this pull request Sep 4, 2025
pytorchmergebot pushed a commit that referenced this pull request Sep 5, 2025
@jeanschmidt
Contributor

@pytorchbot revert -m "Breaks ADS internal tests, see D81845017" -c ghfirst

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot
Collaborator

Reverting PR 154012 failed

Reason: Command git -C /home/runner/work/pytorch/pytorch revert --no-edit cfbd99fdfd7282c8969f123d5819a47d408ce78a returned non-zero exit code 1

Auto-merging aten/src/ATen/native/CPUBlas.cpp
CONFLICT (content): Merge conflict in aten/src/ATen/native/CPUBlas.cpp
Auto-merging aten/src/ATen/native/CPUBlas.h
Auto-merging aten/src/ATen/native/cpu/BlasKernel.cpp
CONFLICT (content): Merge conflict in aten/src/ATen/native/cpu/BlasKernel.cpp
error: could not revert cfbd99fdfd7... [Pytorch] Add option to CPU Blas GEMM to avoid output downcast (#154012)
hint: After resolving the conflicts, mark them with
hint: "git add/rm <pathspec>", then run
hint: "git revert --continue".
hint: You can instead skip this commit with "git revert --skip".
hint: To abort and get back to the state before "git revert",
hint: run "git revert --abort".
hint: Disable this message with "git config set advice.mergeConflict false"
Raised by workflow job

@jeanschmidt
Contributor

forget about this try revert, i followed the wrong link

daisyden pushed a commit to daisyden/pytorch that referenced this pull request Sep 8, 2025
pytorchmergebot pushed a commit that referenced this pull request Sep 11, 2025
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
mansiag05 pushed a commit to mansiag05/pytorch that referenced this pull request Sep 22, 2025
cleonard530 pushed a commit to cleonard530/pytorch that referenced this pull request Sep 22, 2025
dsashidh pushed a commit to dsashidh/pytorch that referenced this pull request Sep 26, 2025

Labels

  • ciflow/trunk Trigger trunk jobs on your pull request
  • fb-exported
  • intel This tag is for PR from Intel
  • Merged
  • module: cpu CPU specific problem (e.g., perf, algorithm)
  • module: numerical-stability Problems related to numerical stability of operations
  • module: sdpa All things related to torch.nn.functional.scaled_dot_product_attention
  • release notes: nn release notes category

8 participants