
Accelerate SDPA on Arm CPUs: Implement fast exp in AdvSIMD vectorizer#176881

Closed
fadara01 wants to merge 5 commits into gh/fadara01/10/base from gh/fadara01/10/head

Conversation

@fadara01
Collaborator

@fadara01 fadara01 commented Mar 9, 2026

Stack from ghstack (oldest at bottom):

Similar to #151441 - this PR adds an AdvSIMD (ASIMD) implementation of a fast exponential, intended for cases where outputs will be downcast to FP16 / BF16 (e.g. attention softmax).

Implementation is similar to the existing exp_u20, but:

  • uses a third-degree polynomial approximation for exp(r) instead of a
    fifth-degree one, with re-tuned coefficients
  • does not split ln2 into high / low parts
  • clamps exp(x) to 0 for x < -87.346351f and to inf for x > 88.3762589f

Overall, this lets us drop 4 FMLAs (out of 7 in the current implementation) and the nasty if branch.
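As a rough illustration of the scheme above, here is a scalar C++ sketch. The coefficients are the plain Taylor values 1/2 and 1/6 for illustration only, not the re-tuned ones from this PR, and the branches stand in for the lane-wise min/max clamps of the vector code:

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>

// Scalar sketch of the fast-exp scheme: range-reduce by ln2, evaluate a
// short polynomial for exp(r), rebuild 2^n via the FP32 exponent bits.
// Coefficients are illustrative Taylor values, NOT the PR's re-tuned ones.
inline float fast_exp_sketch(float x) {
    // The vector code clamps via lane-wise selects; branches suffice here.
    if (x < -87.346351f) return 0.0f;
    if (x > 88.3762589f) return INFINITY;
    const float inv_ln2 = 1.4426950409f; // 1 / ln(2)
    const float ln2 = 0.6931471806f;
    // n = round(x / ln2), r = x - n * ln2, so exp(x) = 2^n * exp(r)
    float n = std::roundf(x * inv_ln2); // ties away from zero, like vrndaq_f32
    float r = x - n * ln2;              // |r| <= ln2 / 2
    // exp(r) ~ 1 + r + r^2 * (c2 + c3 * r): a third-degree polynomial
    const float c2 = 0.5f, c3 = 1.0f / 6.0f;
    float poly = 1.0f + r + r * r * (c2 + c3 * r);
    // 2^n: place the biased exponent n + 127 into the FP32 exponent field
    int32_t bits = (static_cast<int32_t>(n) + 127) << 23;
    float scale;
    std::memcpy(&scale, &bits, sizeof(scale));
    return scale * poly;
}
```

Even these untuned coefficients keep the relative error around 5e-4 near the worst case |r| = ln2/2; re-tuning (as the PR does) tightens the worst-case error at the downcast precision.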

Accuracy

Following a similar approach to #151441, I tested accuracy for fast exp using [this](https://gist.github.com/fadara01/1f0a7bc4b63193102b5f6b0283a9b02f) script. The script iterates over all possible FP32 bit patterns and calculates the ULP distance between:

  • fexp_u20 with inputs in FP32, outputs converted to BF16/FP16
  • std::exp with inputs in FP32, outputs converted to BF16/FP16

The accuracy script shows that fast exp is:

  • for FP16: accurate within a maximum of 1 FP16 ULP
  • for BF16: accurate within a maximum of 1 BF16 ULP for inputs in [-87.346351, 88.376] - we clamp inputs outside this range to 0 / inf
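For reference, a ULP distance of this kind can be computed by remapping bit patterns onto a monotone integer scale; `bf16_ulp_distance` below is a hypothetical helper in that spirit, not code from the script:

```cpp
#include <cstdint>
#include <cstdlib>

// Hypothetical ULP-distance helper for BF16 bit patterns. Sign-magnitude
// encodings are remapped to a monotone integer scale so that subtracting
// two mapped values counts the representable numbers between them.
inline int32_t bf16_ulp_distance(uint16_t a, uint16_t b) {
    auto ordered = [](uint16_t u) -> int32_t {
        int32_t mag = static_cast<int32_t>(u & 0x7FFF);
        // Flip negative values below zero so ordering is monotone overall.
        return (u & 0x8000) ? 0x8000 - mag : 0x8000 + mag;
    };
    return std::abs(ordered(a) - ordered(b));
}
```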

Performance

Using [this SDPA benchmark](https://gist.github.com/fadara01/5357a52299a3722587f6691d145e71e9), here are the scaled-dot-product attention speedups achieved with 16 Neoverse-V2 cores:

| B | Hq | Hkv | Lq | Lk | D | causal | gqa | Speedup vs current |
|---:|---:|---:|---:|---:|---:|---|---|---:|
| 1 | 32 | 8 | 2048 | 2048 | 128 | True | True | +9.48% |
| 1 | 32 | 8 | 1 | 2048 | 128 | False | True | -1.42% (noise) |
| 1 | 16 | 16 | 6400 | 6400 | 80 | False | False | +5.18% |
| 1 | 20 | 20 | 1500 | 1500 | 64 | False | False | +6.63% |
| 8 | 20 | 20 | 1500 | 1500 | 64 | False | False | +9.31% |

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jerryzh168 @aditew01 @drisspg @liangel-02 @howardzhang-cv

[ghstack-poisoned]
@pytorch-bot

pytorch-bot bot commented Mar 9, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/176881

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit abffea0 with merge base 0951602:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the module: cpu CPU specific problem (e.g., perf, algorithm) label Mar 9, 2026
fadara01 added a commit that referenced this pull request Mar 9, 2026
fast exponential intended for cases where outputs will be downcasted to
FP16 / BF16 (e.g. attention softmax).
Accurate within 1 ULP for FP16
Accurate within 1 ULP for BF16 for inputs in [-87.683, 88.376] & clamps
inputs outside this range to 0 / inf.

Implementation is similar to exp_u20, but:
  - uses a third degree polynomial approximation for exp(r) instead of a
    fifth degree one, with coefficients retuned.
  - does not split natural log (ln) into high / low parts
  - clamps exp(x) to 0 for x < -87.6831131f and inf for x > 88.3762589f


ghstack-source-id: f66dad7
Pull-Request: #176881
@pytorch-bot

pytorch-bot bot commented Mar 9, 2026

This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

[ghstack-poisoned]
fadara01 added a commit that referenced this pull request Mar 9, 2026
ghstack-source-id: c1837db
Pull-Request: #176881
@fadara01
Collaborator Author

fadara01 commented Mar 9, 2026

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Mar 9, 2026
@fadara01 fadara01 requested review from aditew01 and nikhil-arm March 9, 2026 14:07
@fadara01 fadara01 added ciflow/linux-aarch64 linux aarch64 CI workflow module: sdpa All things related to torch.nn.functional.scaled_dot_product_attention labels Mar 9, 2026
[ghstack-poisoned]
fadara01 added a commit that referenced this pull request Mar 9, 2026
ghstack-source-id: 76b5508
Pull-Request: #176881
@fadara01
Collaborator Author

fadara01 commented Mar 9, 2026

Hi @Skylion007 - thanks for your review!

I addressed your comments :)

[ghstack-poisoned]
[ghstack-poisoned]
// exp(r) ~ poly(r) = r + r^2 * (c3 + c2 * r)

// n = round(x / ln2), r = x - n*ln2
float32x4_t n = vrndaq_f32(vmulq_f32(values, inv_ln2));
@jgong5
Collaborator

Shall we use vrndnq_f32 which is RNE here?

Collaborator Author

@fadara01 fadara01 Mar 10, 2026


Thanks for the review @jgong5 :)
I used vrndaq_f32 intentionally here for consistency with the existing NEON exp_u20 implementation in PyTorch, which also uses vrndaq_f32. It also matches the AArch64 AdvSIMD expf implementation in glibc.
I kept the same tie-breaking behavior in this kernel, and the ULP measurements I collected were with that behavior.
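For concreteness, the difference between the two intrinsics can be illustrated with their scalar analogues (roundf rounds ties away from zero like vrndaq_f32; rintf under the default rounding mode rounds ties to even like vrndnq_f32); they only disagree on exact .5 ties:

```cpp
#include <cmath>

// Scalar stand-ins for the two intrinsics under discussion:
// vrndaq_f32 ~ roundf (ties rounded away from zero)
// vrndnq_f32 ~ rintf  (ties rounded to even under default FE_TONEAREST)
inline float round_ties_away(float x) { return std::roundf(x); }
inline float round_ties_to_even(float x) { return std::rintf(x); }
```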

Do you have strong opinions on this?

@jgong5
Collaborator

Do you have strong opinions on this?

No, keeping the original behavior sounds good to me.

@Nicoshev
Contributor

Thanks @fadara01 for this faster exp implementation.

What's the ULP of the new algorithm when the output remains as FP32?

Couldn't this regress the quality of other applications using exp?
Wouldn't it be better to introduce a new function name and call it specifically for the SDPA use case?

@fadara01
Copy link
Copy Markdown
Collaborator Author

fadara01 commented Mar 10, 2026

What's the ULP of the new algorithm when the output remains as FP32?

@Nicoshev I did not measure that, and I don't think it matters. We'll only ever use this fexp_u20 when we intend to downcast outputs to FP16 / BF16 - currently it's only used here in attention.

The semantics of this exp were defined by #151441.

Couldn't this regress quality of other applications using exp?

Nope, those would use the existing exp_u20 - this fexp_u20 is only used for SDPA.

Isn't better to introduce a new function name and call it specifically on the SDPA use case ?

Yes, this is exactly what fexp_u20 is! See the first reply :)

@Nicoshev
Contributor

Ah ok, sounds good, thanks for the clarification

@fadara01 fadara01 requested a review from jgong5 March 10, 2026 15:49
@fadara01
Collaborator Author

@Skylion007 @jgong5 @Nicoshev I'd really appreciate it if you could give this another look and help get it merged!

sandy-gags pushed a commit to sandy-gags/pytorch that referenced this pull request Mar 12, 2026
ghstack-source-id: fe32c18
Pull-Request: pytorch/pytorch#176881
@jgong5
Collaborator

jgong5 commented Mar 15, 2026

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 15, 2026
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@fadara01 fadara01 changed the title Accelerate SDPA on Arm CPUs: Implement fast exp in ASIDM vectorizer Accelerate SDPA on Arm CPUs: Implement fast exp in AdvSIMD vectorizer Mar 16, 2026
pytorchmergebot pushed a commit that referenced this pull request Mar 16, 2026

We noticed that (e.g. in Whisper with B=8, num_heads=20, seqlen=1500, head_dim=64) 40% of our time in scaled-dot-product attention is spent in the `_exp_reduce_sum_fusion_kernel`.
Some of that overhead was addressed by #176881, which introduced a faster exp.
We build on top of that here, and squeeze more perf out of SDPA by unrolling the exp_sum and max_mul kernels for better ILP.

The unrolling pattern used here is already present in other SDPA helper kernels like [_scale_attn_mask_fusion_kernel](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp#L53)
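The shape of that pattern, reduced to a scalar sum for illustration (names here are hypothetical; the real kernels unroll over `VectorizedN` lanes rather than scalars):

```cpp
// Two-accumulator unroll: independent partial sums break the loop-carried
// add dependency so the two chains can execute in parallel (better ILP).
inline float unrolled_sum(const float* data, int n) {
    float s0 = 0.0f, s1 = 0.0f;
    int i = 0;
    for (; i + 1 < n; i += 2) {
        s0 += data[i];
        s1 += data[i + 1];
    }
    if (i < n) s0 += data[i]; // odd-length tail
    return s0 + s1;
}
```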

While using `VectorizedN` for unrolling, I noticed that:
- we don't have a fast path to convert `VectorizedN<float>` to  `VectorizedN<bfloat16>` for NEON, so I added that. We already have that for SVE.
- we don't have a fast path to short-circuit identity conversions for `VectorizedN` (e.g. `VectorizedN<float>` to `VectorizedN<float>`), so I added that too
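The identity short-circuit amounts to a specialization along the lines of the sketch below (types and names are stand-ins, not PyTorch's actual `VectorizedN` API):

```cpp
#include <array>

// Generic conversion: per-element casts between two element types.
template <typename Dst, typename Src, int N>
struct ConvertSketch {
    static std::array<Dst, N> apply(const std::array<Src, N>& v) {
        std::array<Dst, N> out{};
        for (int i = 0; i < N; ++i) out[i] = static_cast<Dst>(v[i]);
        return out;
    }
};

// Identity specialization: same source and destination type, so the
// input is returned unchanged with no per-lane conversion work.
template <typename T, int N>
struct ConvertSketch<T, T, N> {
    static std::array<T, N> apply(const std::array<T, N>& v) { return v; }
};
```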

## Performance

Using [this SDPA benchmark](https://gist.github.com/fadara01/5357a52299a3722587f6691d145e71e9), here are the scaled-dot-production-attention speedups achieved with 16 Neoverse-V2 cores:

| B | Hq | Hkv | Lq | Lk | D | causal | gqa | Speedup from #176881  vs current | Speedup from #176881 and this PR vs current |
|---:|---:|---:|---:|---:|---:|---|---|---:|---:|
| 1 | 32 | 8 | 2048 | 2048 | 128 | True | True | +9.48% | +14.91% |
| 1 | 32 | 8 | 1 | 2048 | 128 | False | True | -1.42%  | -2.79% |
| 1 | 16 | 16 | 6400 | 6400 | 80 | False | False | +5.18% | +11.60% |
| 1 | 20 | 20 | 1500 | 1500 | 64 | False | False | +6.63% | +11.80% |
| 8 | 20 | 20 | 1500 | 1500 | 64 | False | False | +9.31% | +17.12% |

Pull Request resolved: #177009
Approved by: https://github.com/jgong5, https://github.com/Skylion007
ghstack dependencies: #176881
pytorchmergebot pushed a commit that referenced this pull request Mar 18, 2026
Similar to #176881 and #151441 this PR adds an SVE fast exponential implementation, intended for cases where outputs will be downcasted to FP16 / BF16 (e.g. attention softmax).

Implementation is similar to exp_u20, but:
  - approximates exp(r) - 1 as r instead of r + 0.5 r^2
  - does not split natural log (ln) into high / low parts
  - avoids special case code by clamping exp(x) to 0 for x < -87.346 and inf for x > 88.717

## Accuracy

Tested in a similar fashion to #17688, iterating over all possible FP32 bit patterns and calculating the ULP between:
- `fexp_u20` with inputs in FP32, outputs converted to BF16/FP16
- `std::exp` with inputs in FP32, outputs converted to BF16/FP16

From the accuracy study above, this exp is:
- Accurate within a maximum of 1 ULP for FP16
- Accurate within a maximum of 1 ULP for BF16 for inputs in [-87.346, max_float]; inputs below -87.346 are clamped to zero

## Performance

Using [this SDPA benchmark](https://gist.github.com/fadara01/5357a52299a3722587f6691d145e71e9), here are the scaled-dot-production-attention speedups achieved with 16 Neoverse-V1 cores (with SVE256):

| B | Hq | Hkv | Lq | Lk | D | causal | gqa | Speedup vs current |
|---:|---:|---:|---:|---:|---:|---|---|---:|
| 1 | 32 | 8 | 2048 | 2048 | 128 | True | True | +7.20% |
| 1 | 32 | 8 | 1 | 2048 | 128 | False | True | +0.38% (noise) |
| 1 | 16 | 16 | 6400 | 6400 | 80 | False | False | +4.32% |
| 1 | 20 | 20 | 1500 | 1500 | 64 | False | False | +3.38% |
| 8 | 20 | 20 | 1500 | 1500 | 64 | False | False | +6.35% |

Pull Request resolved: #177645
Approved by: https://github.com/Skylion007
ryanzhang22 pushed a commit to ryanzhang22/pytorch that referenced this pull request Mar 19, 2026
EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
AaronWang04 pushed a commit to AaronWang04/pytorch that referenced this pull request Mar 31, 2026
AaronWang04 pushed a commit to AaronWang04/pytorch that referenced this pull request Mar 31, 2026
ghstack dependencies: pytorch#176881
AaronWang04 pushed a commit to AaronWang04/pytorch that referenced this pull request Mar 31, 2026
pytorchmergebot pushed a commit that referenced this pull request Apr 7, 2026
Pull Request resolved: #177009
Approved by: https://github.com/jgong5, https://github.com/Skylion007, https://github.com/malfet
ghstack dependencies: #176881
nklshy-aws pushed a commit to nklshy-aws/pytorch that referenced this pull request Apr 7, 2026
Labels

ciflow/linux-aarch64 · ciflow/trunk · Merged · module: cpu · module: sdpa · open source · topic: not user facing
