Implement fast exp for AVX2 and AVX512 for the flash attention #151441
timocafe wants to merge 1 commit into pytorch:main
Conversation
🔗 Helpful links: 🧪 see artifacts and rendered test results at hud.pytorch.org/pr/151441
Note: links to docs will display an error until the doc builds have completed. ⏳ No failures, 1 pending as of commit 1464c24 with merge base 6f23f53. This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
@pytorchbot label "topic: not user facing" |
|
Thanks for your optimization! We need to run more validation before the PR lands, such as the three dynamo suites and the LZ models; I will follow up on that. cc @mingfeima @leslie-fang-intel |
Force-pushed 18bb13b to 5dfa3c6.
|
The validation on the dynamo suites is done, and we do not see any obvious accuracy/perf change for any dtype, including bf16/fp16/fp32. Next, we will check accuracy/perf on Stable Diffusion, Llama3.1-8b and ViT. cc @timocafe |
|
The model validation is still WIP, due to limited machine availability and the large dataset needed for the accuracy runs. |
|
The model validation result is ready.
Thanks for your work! @timocafe cc @mingfeima |
|
Can you give a little more color on the precision impact? It says 1 ULP in hybrid mode; is the setting here hybrid mode? Does that change for different input q, k, v dtypes? Or, since this is exp, are the scores always in fp32? |
|
Hi, indeed the current CPU flash-attention implementation performs exp(x) in float32 (template parameter input T1=float32, output T2=float16), and, very importantly, the results are stored/cast to BF16/FP16. Consequently, we can be moderately aggressive about the precision of the polynomial evaluation. When I quoted 1 ULP, it came from a side benchmark I ran:
I get 1 ULP, except for numbers beyond the boundary limit, and in that range PT::exp is already so-so. If you are interested, I can provide the benchmark code in a small GitHub project. To conclude: fexp_u20 is only valid in the f32/f16 mode; if the flash attention is computed f32/f32, I fall back on the current implementation. Concerning the implementation, I am aware that adding a dedicated exp for a single function may be too much (but transformers are so popular). I could have merged both versions and branched with constexpr, but that would need an extra argument to select the implementation, and the consequences could ripple widely, so I made the simple choice of adding a separate exp, to stay as unobtrusive as possible. In an ideal world, we could select the precision of any elementary function depending on the workload, parameters, etc. Beyond flash attention, this exp could probably be used in many places in PyTorch, especially on f32/f16 paths; more or less all activation functions are concerned, but I do not know the full framework well enough. |
|
Any update on the PR? |
|
Hello Meta, with summer coming, any hope of acceptance? |
|
@timocafe Please rebase and make sure all the CIs pass. |
Force-pushed 54178a3 to e3ae0f6.
Re-ran the tests and they still time out; the PR cannot be landed because of this issue. I suggest rebasing and triggering all the CIs again. |
- Implement Fast Exponential Computation on SIMD Architectures, A. Cristiano I. Malossi, Yves Ineichen, Costas Bekas, and Alessandro Curioni.
- AVX2 and AVX512, float only; up to 20% faster than the current implementation for mixed-precision flash attention.
- The legacy implementation is kept for the other types.

Precision: 1 ULP, only valid in hybrid mode fp32 -> f16, due to the cast during the store operation in the flash attention.

Benchmark: machine Xeon 6972P, results in TOPs, Python forward pass of flash attention.

numhead 16, head dimension 64

|Seq. L.| PT | fexp |
|-------|------|------|
| 512 | 0.8 | 1.3 |
| 1024 | 1.7 | 1.7 |
| 2048 | 6 | 6.1 |
| 4096 | 16 | 16.8 |
| 8192 | 30.6 | 32.3 |
| 16384 | 40 | 40.8 |
| 32768 | 44.9 | 51.4 |
| 65536 | 45.8 | 54.4 |

numhead 16, head dimension 128

|Seq. L.| PT | fexp |
|-------|------|------|
| 512 | 2.5 | 4.1 |
| 1024 | 3.3 | 4 |
| 2048 | 11.4 | 10.5 |
| 4096 | 27.4 | 28.4 |
| 8192 | 44.4 | 46 |
| 16384 | 64.2 | 68.1 |
| 32768 | 77.8 | 83 |
| 65536 | 82.1 | 88.1 |

numhead 16, head dimension 256

|Seq. L.| PT | fexp |
|-------|------|------|
| 512 | 1.7 | 3.4 |
| 1024 | 4.2 | 6.5 |
| 2048 | 14.6 | 16.1 |
| 4096 | 30.1 | 31.1 |
| 8192 | 60 | 62 |
| 16384 | 83.3 | 87.3 |
| 32768 | 98.7 | 106 |
| 65536 | 102.2| 107.1|

Follow-up commits: fix a typo and a compiler issue; retrigger the CI by amending the commit message due to the modification of 09e8ff9 4 days ago; formatting; modify the sve/neon/arm backend.
|
Rebased and pushed; the CI was canceled with a mysterious message: "check labels: Canceling since a higher priority waiting request for Check Labels-151441-false exists". |
That is because you don't have permission to trigger CIs. I've launched them now. |
|
@pytorchbot merge |
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team. |
Merge failed. Reason: 2 mandatory check(s) failed. Dig deeper by viewing the failures on hud. |
|
@pytorchbot merge |
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team. |
…176881) Similar to #151441, this PR adds an ASIMD implementation of fast exponential intended for cases where outputs will be downcast to FP16/BF16 (e.g. attention softmax). The implementation is similar to the existing exp_u20, but:
- uses a third-degree polynomial approximation for exp(r) instead of a fifth-degree one, with re-tuned coefficients
- does not split the natural log (ln) into high/low parts
- clamps exp(x) to 0 for x < -87.346351f and inf for x > 88.3762589f

Overall, this allows us to ditch 4 FMLAs (out of 7 in the current impl) and the nasty if branch.

## Accuracy
Following a similar approach to #151441, I tested accuracy for fast exp using [this](https://gist.github.com/fadara01/1f0a7bc4b63193102b5f6b0283a9b02f) script. The script iterates over all possible FP32 bit patterns and calculates the ULP distance between:
- `fexp_u20` with inputs in FP32, outputs converted to BF16/FP16
- `std::exp` with inputs in FP32, outputs converted to BF16/FP16

The accuracy script shows that fast exp is:
- for FP16: accurate within a maximum of 1 FP16 ULP
- for BF16: accurate within a maximum of 1 BF16 ULP for inputs in [-87.346351, 88.376]; inputs outside this range are clamped to 0 / inf

## Performance
Using [this SDPA benchmark](https://gist.github.com/fadara01/5357a52299a3722587f6691d145e71e9), here are the scaled-dot-product-attention speedups achieved with 16 Neoverse-V2 cores:

| B | Hq | Hkv | Lq | Lk | D | causal | gqa | Speedup vs current |
|---:|---:|---:|---:|---:|---:|---|---|---:|
| 1 | 32 | 8 | 2048 | 2048 | 128 | True | True | +9.48% |
| 1 | 32 | 8 | 1 | 2048 | 128 | False | True | -1.42% (noise) |
| 1 | 16 | 16 | 6400 | 6400 | 80 | False | False | +5.18% |
| 1 | 20 | 20 | 1500 | 1500 | 64 | False | False | +6.63% |
| 8 | 20 | 20 | 1500 | 1500 | 64 | False | False | +9.31% |

Pull Request resolved: #176881
Approved by: https://github.com/Skylion007, https://github.com/jgong5
Similar to #176881 and #151441, this PR adds an SVE fast exponential implementation, intended for cases where outputs will be downcast to FP16/BF16 (e.g. attention softmax). The implementation is similar to exp_u20, but:
- approximates exp(r) - 1 as r instead of r + 0.5 r^2
- does not split the natural log (ln) into high/low parts
- avoids special-case code by clamping exp(x) to 0 for x < -87.346 and inf for x > 88.717

## Accuracy
Tested in a similar fashion to #176881, by iterating over all possible FP32 bit patterns and calculating the ULP distance between:
- `fexp_u20` with inputs in FP32, outputs converted to BF16/FP16
- `std::exp` with inputs in FP32, outputs converted to BF16/FP16

From the accuracy study above, this exp is:
- accurate within a maximum of 1 ULP for FP16
- accurate within a maximum of 1 ULP for BF16 for inputs in [-87.346, max_float]; inputs < -87.346 are clamped to zero

## Performance
Using [this SDPA benchmark](https://gist.github.com/fadara01/5357a52299a3722587f6691d145e71e9), here are the scaled-dot-product-attention speedups achieved with 16 Neoverse-V1 cores (with SVE256):

| B | Hq | Hkv | Lq | Lk | D | causal | gqa | Speedup vs current |
|---:|---:|---:|---:|---:|---:|---|---|---:|
| 1 | 32 | 8 | 2048 | 2048 | 128 | True | True | +7.20% |
| 1 | 32 | 8 | 1 | 2048 | 128 | False | True | +0.38% (noise) |
| 1 | 16 | 16 | 6400 | 6400 | 80 | False | False | +4.32% |
| 1 | 20 | 20 | 1500 | 1500 | 64 | False | False | +3.38% |
| 8 | 20 | 20 | 1500 | 1500 | 64 | False | False | +6.35% |

Pull Request resolved: #177645
Approved by: https://github.com/Skylion007
Implement fexp for avx2 and avx512

Malossi et al. propose a clever exp that uses the IEEE representation with fine control of the precision, especially useful for mixed-precision computation of the flash attention.

A. Cristiano I. Malossi, Yves Ineichen, Costas Bekas, and Alessandro Curioni

AVX2 and AVX512, float only, up to 20% faster for mixed-precision flash attention than the current implementation.

Precision: 1 ULP, only valid in hybrid mode fp32 -> f16, due to the cast during the store operation in the flash attention.

Benchmark: machine Xeon 6972P, results in TOPs, Python forward pass of flash attention, numhead 16, head dimensions 64/128/256 (tables in the PR description above).
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jerryzh168