
Accelerate SDPA on Arm CPUs: Unroll exp_sum and max_mul kernels#177009

Open
fadara01 wants to merge 4 commits into gh/fadara01/11/base from gh/fadara01/11/head

Conversation

@fadara01 (Collaborator) commented Mar 10, 2026

Stack from ghstack (oldest at bottom):

We noticed that (e.g. in Whisper with B=8, num_heads=20, seqlen=1500, head_dim=64) 40% of our time in scaled-dot-product attention is spent in the `_exp_reduce_sum_fusion_kernel`.
Some of that overhead was addressed by #176881, which introduced a faster exp.
We build on top of that here and squeeze more performance out of SDPA by unrolling the exp_sum and max_mul kernels for better instruction-level parallelism (ILP).

The unrolling pattern used here is already present in other SDPA helper kernels like [`_scale_attn_mask_fusion_kernel`](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp#L53).

While using `VectorizedN` for unrolling, I noticed that:

- we don't have a fast path to convert `VectorizedN<float>` to `VectorizedN<bfloat16>` for NEON, so I added that. We already have one for SVE.
- we don't have a fast path to short-circuit identity conversions for `VectorizedN` (e.g. `VectorizedN<float>` to `VectorizedN<float>`), so I added that too.

Performance

Using [this SDPA benchmark](https://gist.github.com/fadara01/5357a52299a3722587f6691d145e71e9), here are the scaled-dot-product attention speedups achieved with 16 Neoverse-V2 cores:

| B | Hq | Hkv | Lq | Lk | D | causal | gqa | Speedup from #176881 vs current | Speedup from #176881 and this PR vs current |
|---:|---:|---:|---:|---:|---:|---|---|---:|---:|
| 1 | 32 | 8 | 2048 | 2048 | 128 | True | True | +9.48% | +14.91% |
| 1 | 32 | 8 | 1 | 2048 | 128 | False | True | -1.42% | -2.79% |
| 1 | 16 | 16 | 6400 | 6400 | 80 | False | False | +5.18% | +11.60% |
| 1 | 20 | 20 | 1500 | 1500 | 64 | False | False | +6.63% | +11.80% |
| 8 | 20 | 20 | 1500 | 1500 | 64 | False | False | +9.31% | +17.12% |

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jerryzh168 @aditew01

[ghstack-poisoned]
@pytorch-bot commented Mar 10, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/177009

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit 253a1ea with merge base 0951602 (image):

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the module: cpu CPU specific problem (e.g., perf, algorithm) label Mar 10, 2026
@pytorch-bot commented Mar 10, 2026

This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@fadara01 (Collaborator, Author)

@pytorchbot label "topic: not user facing"

@Skylion007 (Collaborator) left a comment


Some nits

sandy-gags pushed a commit to sandy-gags/pytorch that referenced this pull request Mar 13, 2026
@fadara01 (Collaborator, Author)

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 16, 2026
@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

@yangw-dev (Contributor)

@ezyang do you mind taking a look at this, since you are auto-assigned to this change?

@aditew01 (Collaborator)

maybe we need to guard the introduced logic under `!defined(C10_MOBILE)`

EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
Pull Request resolved: pytorch#177009
Approved by: https://github.com/jgong5, https://github.com/Skylion007
ghstack dependencies: pytorch#176881
EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
…ls (pytorch#177009)"

This reverts commit e5c56db.

Reverted pytorch#177009 on behalf of https://github.com/yangw-dev due to sorry it seems this breaks internal tests xplat/caffe2/aten/src/ATen/cpu/vec/vec128/vec128_convert.h:390:17: error: use of undeclared identifier 'convert_float_half', D96767295. please reach out to meta internal folks to resolve this ([comment](pytorch#177009 (comment)))
@fadara01 (Collaborator, Author)

@malfet - has Meta's internal PyTorch been modified to work with this PR?
Are we okay to re-land this?

@malfet (Contributor)

malfet commented Mar 30, 2026

@fadara01 there is no Meta-internal PR associated with this one, but let me have a look at what is going on

AaronWang04 pushed a commit to AaronWang04/pytorch that referenced this pull request Mar 31, 2026
AaronWang04 pushed a commit to AaronWang04/pytorch that referenced this pull request Mar 31, 2026
@ezyang (Contributor)

ezyang commented Apr 7, 2026

⏺ Bash(git show main:aten/src/ATen/cpu/vec/vec128/vec128_convert.h | head -10)
  ⎿  #pragma once
     #include <ATen/cpu/vec/vec_base.h>
     #include <ATen/cpu/vec/vec_convert.h>
     … +7 lines

⏺ OK so the original guard was defined(__aarch64__) &&
  !defined(CPU_CAPABILITY_SVE256) — no C10_MOBILE check. That means the existing
  code in this block was available on mobile. Adding !defined(C10_MOBILE) to the
  whole block would be a regression for mobile.

  I should revert my change and instead wrap just the two new specializations.

⏺ Update(aten/src/ATen/cpu/vec/vec128/vec128_convert.h)
  ⎿  Added 1 line, removed 1 line
       4
       5  namespace at::vec {
       6  inline namespace CPU_CAPABILITY {
        7 -#if !defined(C10_MOBILE) && defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256)
        7 +#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256))
       8
       9  // Enable auto-vectorization for clang-17+
      10  // GCC-12 has a bug: gcc.gnu.org/bugzilla/show_bug.cgi?id=117001

⏺ Update(aten/src/ATen/cpu/vec/vec128/vec128_convert.h)
  ⎿  Added 2 lines
      373    }
      374  };
      375
      376 +#if !defined(C10_MOBILE)
      377  template <>
      378  struct VecConvert<BFloat16, 1, float, 2> {
      379    static inline VectorizedN<BFloat16, 1> apply(
     ...
      392      return result;
      393    }
      394  };
      395 +#endif // !defined(C10_MOBILE)
      396
      397  #endif // defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256)
      398  } // namespace CPU_CAPABILITY

⏺ Done. Wrapped the two new VecConvert specializations in #if
  !defined(C10_MOBILE) since convert_float_half and convert_float_bfloat16 are
  defined in headers that are guarded by !defined(C10_MOBILE). The rest of the
  existing code in the block remains available on mobile as before.

@aditew01 (Collaborator)

aditew01 commented Apr 7, 2026

@ezyang does this PR need any additional changes, based on

> Done. Wrapped the two new VecConvert specializations in `#if !defined(C10_MOBILE)` since `convert_float_half` and `convert_float_bfloat16` are defined in headers that are guarded by `!defined(C10_MOBILE)`. The rest of the existing code in the block remains available on mobile as before.

or can we trigger a merge as is?

@ezyang (Contributor)

ezyang commented Apr 7, 2026

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

@pytorchmergebot (Collaborator)

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / macos-py3-arm64 / build

Details for Dev Infra team (raised by workflow job)

@aditew01 (Collaborator)

aditew01 commented Apr 7, 2026

@pytorchbot merge -i

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged while ignoring the following 1 checks: trunk / macos-py3-arm64 / build

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

nklshy-aws pushed a commit to nklshy-aws/pytorch that referenced this pull request Apr 7, 2026
Pull Request resolved: pytorch#177009
Approved by: https://github.com/jgong5, https://github.com/Skylion007, https://github.com/malfet
ghstack dependencies: pytorch#176881
@atalman (Contributor)

atalman commented Apr 8, 2026

@pytorchmergebot revert -c ghfirst -m "Failing internally"

@pytorchmergebot (Collaborator)

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot (Collaborator)

@fadara01 your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Apr 8, 2026
…ls (#177009)"

This reverts commit e8dab08.

Reverted #177009 on behalf of https://github.com/atalman due to Failing internally ([comment](#177009 (comment)))
@atalman (Contributor)

atalman commented Apr 8, 2026

Error:

vec128_convert.h:381:17: error: no matching function for call to 'convert_from_float'
  381 |     result[0] = convert_float_bfloat16(src[0], src[1]);
      |                 ^~~~~~~~~~~~~~~~~~~~~~
ATen/cpu/vec/vec_convert.h:153:29: note: candidate template ignored: couldn't infer template argument 'scalar_t'
  153 | inline Vectorized<scalar_t> convert_from_float(
      |                             ^
In file included from /executorch/kernels/optimized/cpu/op_exp.cpp:11:

Looks like it's pointing to: https://github.com/pytorch/executorch/blob/main/kernels/optimized/cpu/op_exp.cpp#L11

And one more error:

vec128_convert.h:390:17: error: use of undeclared identifier 'convert_float_half'
  390 |     result[0] = convert_float_half(src[0], src[1]);
      |                 ^~~~~~~~~~~~~~~~~~
In file included from /executorch/kernels/optimized/cpu/op_exp.cpp:13:

@aditew01 (Collaborator)

aditew01 commented Apr 9, 2026

@atalman is it possible to publish a reproducer, or point to the possible fixes that can go into the PR? I can't make much of the published error trace.


Labels

ci-no-td (Do not run TD on this PR), ciflow/linux-aarch64 (linux aarch64 CI workflow), ciflow/trunk (Trigger trunk jobs on your pull request), Merged, module: cpu (CPU specific problem, e.g. perf, algorithm), open source, Reverted, topic: not user facing (topic category)


10 participants