Accelerate SDPA on Arm CPUs: Implement fast exp in AdvSIMD vectorizer #176881
fadara01 wants to merge 5 commits into gh/fadara01/10/base from
Conversation
🔗 Helpful Links 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/176881
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit abffea0 with merge base 0951602. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Fast exponential intended for cases where outputs will be downcast to
FP16 / BF16 (e.g. attention softmax).
Accurate within 1 ULP for FP16.
Accurate within 1 ULP for BF16 for inputs in [-87.683, 88.376]; inputs
outside this range are clamped to 0 / inf.
Implementation is similar to exp_u20, but:
- uses a third-degree polynomial approximation for exp(r) instead of a
fifth-degree one, with coefficients re-tuned.
- does not split the natural log (ln) into high / low parts
- clamps exp(x) to 0 for x < -87.6831131f and to inf for x > 88.3762589f
ghstack-source-id: f66dad7
Pull-Request: #176881
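As a rough illustration of the scheme described above (range reduction via n = round(x / ln2), a low-degree polynomial for exp(r), and clamping), here is a hedged scalar C++ sketch. The polynomial coefficients below are plain Taylor terms for illustration only; the PR uses re-tuned constants and NEON intrinsics, so treat this as a reference model, not the actual kernel:

```cpp
#include <cmath>

// Scalar reference sketch of the fast-exp scheme (NOT the PR's kernel):
// clamp, reduce x = n*ln2 + r with |r| <= ln2/2, approximate exp(r)
// with a degree-3 polynomial, then scale by 2^n.
// Coefficients are illustrative Taylor terms, not the re-tuned ones.
float fast_exp_sketch(float x) {
    if (x < -87.6831131f) return 0.0f;      // clamp low  -> 0
    if (x >  88.3762589f) return INFINITY;  // clamp high -> inf
    const float ln2     = 0.69314718056f;
    const float inv_ln2 = 1.44269504089f;
    float n = std::round(x * inv_ln2);      // n = round(x / ln2)
    float r = x - n * ln2;                  // reduced argument
    // exp(r) ~ 1 + r + r^2 * (c3 + c2 * r), with c3 = 1/2, c2 = 1/6
    const float c3 = 0.5f, c2 = 1.0f / 6.0f;
    float poly = 1.0f + r + r * r * (c3 + c2 * r);
    return std::ldexp(poly, static_cast<int>(n));  // exp(x) = 2^n * exp(r)
}
```

With raw Taylor coefficients, the worst-case relative error over |r| <= ln2/2 is on the order of r^4/24 ≈ 6e-4; re-tuning the coefficients (as the PR does) is what keeps the result within 1 ULP after the FP16/BF16 downcast.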
@pytorchbot label "topic: not user facing"
Hi @Skylion007 - thanks for your review! I addressed your comments :)
```cpp
// exp(r) ~ poly(r) = r + r^2 * (c3 + c2 * r)
// n = round(x / ln2), r = x - n*ln2
float32x4_t n = vrndaq_f32(vmulq_f32(values, inv_ln2));
```
Shall we use vrndnq_f32, which is RNE (round-to-nearest-even), here?
Thanks for the review @jgong5 :)
I used vrndaq_f32 intentionally here for consistency with the existing NEON exp_u20 implementation in PyTorch, which also uses vrndaq_f32. It also matches the AArch64 AdvSIMD expf implementation in glibc.
I kept the same tie-breaking behavior in this kernel, and the ULP measurements I collected were with that behavior.
Do you have strong opinions on this?
> Do you have strong opinions on this?
No, keeping the original behavior sounds good to me.
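For context, the two intrinsics differ only on exact halfway cases. A scalar C++ analogue (assuming the default FE_TONEAREST floating-point environment): std::round rounds ties away from zero like vrndaq_f32, while std::nearbyint rounds ties to even like vrndnq_f32:

```cpp
#include <cmath>

// vrndaq_f32-like: ties round away from zero (0.5 -> 1, 2.5 -> 3).
float round_ties_away(float x) { return std::round(x); }

// vrndnq_f32-like: ties round to even (0.5 -> 0, 2.5 -> 2),
// assuming the default FE_TONEAREST rounding mode.
float round_ties_even(float x) { return std::nearbyint(x); }
```

For non-tie inputs both functions agree, which is why the choice barely moves the ULP numbers in practice.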
Thanks @fadara01 for this faster exp implementation. What's the ULP of the new algorithm when the output remains as FP32? Couldn't this regress quality of other applications using exp?
@Nicoshev I did not measure that, and I don't think it matters. We'll only ever use this where outputs are downcast to FP16 / BF16. The semantics of this exp were defined by #151441
Nope, these would use the existing
Yes, this is exactly what
Ah ok, sounds good, thanks for the clarification
@Skylion007 @jgong5 @Nicoshev I'd really appreciate it if you could give this another look and help get it merged!
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
We noticed that (e.g. in whisper with: B=8, num_heads=20, seqlen=1500, head_dim=64), 40% of our time in scaled-dot-product-attention is spent in the `_exp_reduce_sum_fusion_kernel`. Some of that overhead was addressed by #176881, which introduces a faster exp. We build on top of that here and squeeze more perf out of SDPA by unrolling the exp_sum and max_mul kernels for better ILP. The unrolling pattern used here is already present in other SDPA helper kernels like [_scale_attn_mask_fusion_kernel](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp#L53).

While using `VectorizedN` for unrolling, I noticed that:
- we don't have a fast path to convert `VectorizedN<float>` to `VectorizedN<bfloat16>` for NEON, so I added that. We already have that for SVE.
- we don't have a fast path to short-circuit identity conversions for `VectorizedN` (e.g. `VectorizedN<float>` to `VectorizedN<float>`), so I added that too.

## Performance

Using [this SDPA benchmark](https://gist.github.com/fadara01/5357a52299a3722587f6691d145e71e9), here are the scaled-dot-product-attention speedups achieved with 16 Neoverse-V2 cores:

| B | Hq | Hkv | Lq | Lk | D | causal | gqa | Speedup from #176881 vs current | Speedup from #176881 and this PR vs current |
|---:|---:|---:|---:|---:|---:|---|---|---:|---:|
| 1 | 32 | 8 | 2048 | 2048 | 128 | True | True | +9.48% | +14.91% |
| 1 | 32 | 8 | 1 | 2048 | 128 | False | True | -1.42% | -2.79% |
| 1 | 16 | 16 | 6400 | 6400 | 80 | False | False | +5.18% | +11.60% |
| 1 | 20 | 20 | 1500 | 1500 | 64 | False | False | +6.63% | +11.80% |
| 8 | 20 | 20 | 1500 | 1500 | 64 | False | False | +9.31% | +17.12% |

Pull Request resolved: #177009
Approved by: https://github.com/jgong5, https://github.com/Skylion007
ghstack dependencies: #176881
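To illustrate the unrolling idea in scalar form (a hedged sketch: the real kernels operate on `VectorizedN` SIMD lanes, and `exp_sum_unrolled` is a made-up name), keeping four independent partial sums removes the serial dependency on a single accumulator, which is what buys the extra ILP:

```cpp
#include <cmath>
#include <cstddef>

// Hedged scalar sketch of the exp + reduce-sum fusion with 4x unrolling.
// Four independent accumulators let the exp results retire in parallel
// instead of serializing on one running sum.
float exp_sum_unrolled(const float* in, float* out, size_t n) {
    float s0 = 0.f, s1 = 0.f, s2 = 0.f, s3 = 0.f;
    size_t i = 0;
    for (; i + 4 <= n; i += 4) {
        out[i + 0] = std::exp(in[i + 0]); s0 += out[i + 0];
        out[i + 1] = std::exp(in[i + 1]); s1 += out[i + 1];
        out[i + 2] = std::exp(in[i + 2]); s2 += out[i + 2];
        out[i + 3] = std::exp(in[i + 3]); s3 += out[i + 3];
    }
    for (; i < n; ++i) {  // scalar tail for n not divisible by 4
        out[i] = std::exp(in[i]);
        s0 += out[i];
    }
    return (s0 + s1) + (s2 + s3);  // combine partial sums at the end
}
```

Note that splitting the accumulator changes the floating-point summation order, which is generally acceptable for a softmax normalizer.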
Similar to #176881 and #151441, this PR adds an SVE fast exponential implementation, intended for cases where outputs will be downcast to FP16 / BF16 (e.g. attention softmax).

Implementation is similar to exp_u20, but:
- approximates exp(r) - 1 as r instead of r + 0.5 r^2
- does not split the natural log (ln) into high / low parts
- avoids special-case code by clamping exp(x) to 0 for x < -87.346 and inf for x > 88.717

## Accuracy

Tested in a similar fashion to #17688, iterating over all possible FP32 bit patterns and calculating the ULP between:
- `fexp_u20` with inputs in FP32, outputs converted to BF16/FP16
- `std::exp` with inputs in FP32, outputs converted to BF16/FP16

From the accuracy study above, this exp is:
- accurate within a maximum of 1 ULP for FP16
- accurate within a maximum of 1 ULP for BF16 for inputs in [-87.346, max_float]; inputs < -87.346 are clamped to zero.

## Performance

Using [this SDPA benchmark](https://gist.github.com/fadara01/5357a52299a3722587f6691d145e71e9), here are the scaled-dot-product-attention speedups achieved with 16 Neoverse-V1 cores (with SVE256):

| B | Hq | Hkv | Lq | Lk | D | causal | gqa | Speedup vs current |
|---:|---:|---:|---:|---:|---:|---|---|---:|
| 1 | 32 | 8 | 2048 | 2048 | 128 | True | True | +7.20% |
| 1 | 32 | 8 | 1 | 2048 | 128 | False | True | +0.38% (noise) |
| 1 | 16 | 16 | 6400 | 6400 | 80 | False | False | +4.32% |
| 1 | 20 | 20 | 1500 | 1500 | 64 | False | False | +3.38% |
| 8 | 20 | 20 | 1500 | 1500 | 64 | False | False | +6.35% |

Pull Request resolved: #177645
Approved by: https://github.com/Skylion007
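The bit-pattern ULP comparison above can be sketched as follows (hedged: `to_bf16_bits` truncates rather than rounding to nearest-even, and both helper names are made up; the linked gist is the authoritative script). Because same-sign finite floats are ordered the same way as their bit patterns, the ULP distance after a BF16 downcast is just the difference between the 16-bit patterns:

```cpp
#include <cstdint>
#include <cstdlib>
#include <cstring>

// Round a float to BF16 by truncation (keep the top 16 bits).
// Note: real BF16 conversion typically rounds to nearest-even.
uint16_t to_bf16_bits(float f) {
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    return static_cast<uint16_t>(bits >> 16);
}

// ULP distance after the BF16 downcast, valid for same-sign finite
// values: their bit patterns are monotonic in the value they encode.
int bf16_ulp_distance(float approx, float reference) {
    return std::abs(static_cast<int>(to_bf16_bits(approx)) -
                    static_cast<int>(to_bf16_bits(reference)));
}
```

The full script would loop this over all 2^32 FP32 bit patterns (and the FP16 equivalent) to find the maximum distance.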
Accelerate SDPA on Arm CPUs: Implement fast exp in AdvSIMD vectorizer (pytorch#176881)

Similar to pytorch#151441 - this PR adds an ASIMD implementation for fast exponential intended for cases where outputs will be downcast to FP16 / BF16 (e.g. attention softmax).

Implementation is similar to the existing exp_u20, but:
- uses a third-degree polynomial approximation for exp(r) instead of a fifth-degree one, with coefficients re-tuned.
- does not split the natural log (ln) into high / low parts
- clamps exp(x) to 0 for x < -87.346351f and inf for x > 88.3762589f

Overall, this allows us to ditch 4 FMLAs (out of 7 in the current impl) and the nasty if branch.

## Accuracy

Following a similar approach to pytorch#151441, I tested accuracy for fast exp using [this](https://gist.github.com/fadara01/1f0a7bc4b63193102b5f6b0283a9b02f) script. The script iterates over all possible FP32 bit patterns and calculates the ULP between:
- `fexp_u20` with inputs in FP32, outputs converted to BF16/FP16
- `std::exp` with inputs in FP32, outputs converted to BF16/FP16

The accuracy script shows that fast exp is:
- for FP16: accurate within a maximum of 1 FP16 ULP
- for BF16: accurate within a maximum of 1 BF16 ULP for inputs in [-87.346351, 88.376]; inputs outside this range are clamped to 0 / inf.

## Performance

Using [this SDPA benchmark](https://gist.github.com/fadara01/5357a52299a3722587f6691d145e71e9), here are the scaled-dot-product-attention speedups achieved with 16 Neoverse-V2 cores:

| B | Hq | Hkv | Lq | Lk | D | causal | gqa | Speedup vs current |
|---:|---:|---:|---:|---:|---:|---|---|---:|
| 1 | 32 | 8 | 2048 | 2048 | 128 | True | True | +9.48% |
| 1 | 32 | 8 | 1 | 2048 | 128 | False | True | -1.42% (noise) |
| 1 | 16 | 16 | 6400 | 6400 | 80 | False | False | +5.18% |
| 1 | 20 | 20 | 1500 | 1500 | 64 | False | False | +6.63% |
| 8 | 20 | 20 | 1500 | 1500 | 64 | False | False | +9.31% |

Pull Request resolved: pytorch#176881
Approved by: https://github.com/Skylion007, https://github.com/jgong5
Stack from ghstack (oldest at bottom):

Similar to #151441 - this PR adds an ASIMD implementation for fast exponential intended for cases where outputs will be downcast to FP16 / BF16 (e.g. attention softmax).

Implementation is similar to the existing exp_u20, but:
- uses a third-degree polynomial approximation for exp(r) instead of a fifth-degree one, with coefficients re-tuned.
- does not split the natural log (ln) into high / low parts
- clamps exp(x) to 0 for x < -87.346351f and inf for x > 88.3762589f

Overall, this allows us to ditch 4 FMLAs (out of 7 in the current impl) and the nasty if branch.

## Accuracy

Following a similar approach to #151441, I tested accuracy for fast exp using [this](https://gist.github.com/fadara01/1f0a7bc4b63193102b5f6b0283a9b02f) script. The script iterates over all possible FP32 bit patterns and calculates the ULP between:
- `fexp_u20` with inputs in FP32, outputs converted to BF16/FP16
- `std::exp` with inputs in FP32, outputs converted to BF16/FP16

The accuracy script shows that fast exp is:
- for FP16: accurate within a maximum of 1 FP16 ULP
- for BF16: accurate within a maximum of 1 BF16 ULP for inputs in [-87.346351, 88.376]; inputs outside this range are clamped to 0 / inf.

## Performance

Using [this SDPA benchmark](https://gist.github.com/fadara01/5357a52299a3722587f6691d145e71e9), here are the scaled-dot-product-attention speedups achieved with 16 Neoverse-V2 cores:

| B | Hq | Hkv | Lq | Lk | D | causal | gqa | Speedup vs current |
|---:|---:|---:|---:|---:|---:|---|---|---:|
| 1 | 32 | 8 | 2048 | 2048 | 128 | True | True | +9.48% |
| 1 | 32 | 8 | 1 | 2048 | 128 | False | True | -1.42% (noise) |
| 1 | 16 | 16 | 6400 | 6400 | 80 | False | False | +5.18% |
| 1 | 20 | 20 | 1500 | 1500 | 64 | False | False | +6.63% |
| 8 | 20 | 20 | 1500 | 1500 | 64 | False | False | +9.31% |

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jerryzh168 @aditew01 @drisspg @liangel-02 @howardzhang-cv