Accelerate SDPA: Implement fast exp in SVE vectorizer #177645

fadara01 wants to merge 1 commit into gh/fadara01/13/base from
Conversation
Fast exponential intended for cases where outputs will be downcasted to FP16 / BF16 (e.g. attention softmax). Accurate within 1 ULP for FP16. Accurate within 1 ULP for BF16 for inputs in [-87.346, max_float]; clamps inputs < -87.346 to zero.

Implementation is similar to exp_u20, but:
- approximates exp(r) - 1 as r instead of r + 0.5 r^2
- does not split the natural log constant (ln 2) into high / low parts
- avoids special-case code by clamping exp(x) to 0 for x < -87.346 and to inf for x > 88.717

ghstack-source-id: da328b4
Pull-Request: #177645
See artifacts and rendered test results at hud.pytorch.org/pr/177645
✅ No Failures as of commit ba6d46e with merge base f8e48d2. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label "topic: not user facing"

Hi @Skylion007 / @jgong5 - this is similar to #176881 but for the SVE vectorizer - I'd appreciate your review on this :)

@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

Merge failed. Reason: 1 mandatory check(s) failed. Dig deeper by viewing the failures on hud.

@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Similar to pytorch#176881 and pytorch#151441, this PR adds an SVE fast exponential implementation, intended for cases where outputs will be downcasted to FP16 / BF16 (e.g. attention softmax).

Implementation is similar to exp_u20, but:
- approximates exp(r) - 1 as r instead of r + 0.5 r^2
- does not split the natural log constant (ln 2) into high / low parts
- avoids special-case code by clamping exp(x) to 0 for x < -87.346 and to inf for x > 88.717

## Accuracy

Tested in a similar fashion to pytorch#17688, by iterating over all possible FP32 bit patterns and calculating the ULP distance between:
- `fexp_u20` with inputs in FP32, outputs converted to BF16/FP16
- `std::exp` with inputs in FP32, outputs converted to BF16/FP16

From the accuracy study above, this exp is:
- accurate within a maximum of 1 ULP for FP16
- accurate within a maximum of 1 ULP for BF16 for inputs in [-87.346, max_float], and clamps inputs < -87.346 to zero

## Performance

Using [this SDPA benchmark](https://gist.github.com/fadara01/5357a52299a3722587f6691d145e71e9), here are the scaled-dot-product-attention speedups achieved with 16 Neoverse-V1 cores (with SVE256):

| B | Hq | Hkv | Lq | Lk | D | causal | gqa | Speedup vs current |
|---:|---:|---:|---:|---:|---:|---|---|---:|
| 1 | 32 | 8 | 2048 | 2048 | 128 | True | True | +7.20% |
| 1 | 32 | 8 | 1 | 2048 | 128 | False | True | +0.38% (noise) |
| 1 | 16 | 16 | 6400 | 6400 | 80 | False | False | +4.32% |
| 1 | 20 | 20 | 1500 | 1500 | 64 | False | False | +3.38% |
| 8 | 20 | 20 | 1500 | 1500 | 64 | False | False | +6.35% |

Pull Request resolved: pytorch#177645
Approved by: https://github.com/Skylion007
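The numerics described above can be illustrated with a scalar sketch. This is not the SVE kernel: the table-based range reduction below (an assumed K = 64 steps per octave) is chosen only so that the single-term approximation exp(r) - 1 ≈ r stays well within FP16/BF16 precision; the clamping thresholds and the linear remainder term are the parts taken from the PR description.

```python
import math

LN2 = math.log(2.0)
K = 64  # assumed reduction granularity; with |r| <= ln2/(2K), the error of
        # exp(r) ~= 1 + r is about r^2/2 < 1.5e-5, below FP16/BF16 precision
TBL = [2.0 ** (j / K) for j in range(K)]  # illustrative 2^(j/K) table

def fast_exp(x: float) -> float:
    # Clamp instead of special-case handling, as the PR describes:
    # exp(x) -> 0 for x < -87.346, inf for x > 88.717.
    if x < -87.346:
        return 0.0
    if x > 88.717:
        return math.inf
    n = round(x * K / LN2)      # write x = (n/K)*ln2 + r
    r = x - n * (LN2 / K)       # remainder, |r| <= ln2/(2K)
    e, j = divmod(n, K)         # integer exponent and fractional-step index
    return math.ldexp(TBL[j], e) * (1.0 + r)  # exp(r) - 1 approximated as r
```

For example, `fast_exp(1.0)` agrees with `math.exp(1.0)` to roughly 1e-5 relative error, which is comfortably inside one ULP once the result is downcasted to FP16 or BF16.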
Stack from ghstack (oldest at bottom):
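The per-value ULP metric used in the accuracy study can be sketched in scalar Python. `fp16_bits`, `ordered`, and `ulp_distance_fp16` are illustrative helper names, not code from the PR; the same idea, swept over every FP32 bit pattern (and with a BF16 variant), reproduces the methodology described above.

```python
import struct

def fp16_bits(x: float) -> int:
    # Round x to IEEE-754 binary16 and return its 16-bit pattern.
    return struct.unpack('<H', struct.pack('<e', x))[0]

def ordered(bits: int) -> int:
    # Map sign-magnitude FP16 bits onto one monotonic integer scale,
    # so adjacent representable values differ by exactly 1.
    return bits if bits < 0x8000 else 0x8000 - bits

def ulp_distance_fp16(a: float, b: float) -> int:
    # Number of representable FP16 values separating a and b after rounding.
    return abs(ordered(fp16_bits(a)) - ordered(fp16_bits(b)))
```

For instance, `ulp_distance_fp16(1.0, 1.0 + 2**-10)` is 1, since 1 + 2^-10 is the next FP16 value after 1.0.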
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jerryzh168 @aditew01