Implement fast exp for AVX2 and AVX512 for the flash attention #151441

Closed
timocafe wants to merge 1 commit into pytorch:main from timocafe:tewart_fexp

Conversation

@timocafe
Contributor

@timocafe timocafe commented Apr 16, 2025

Implement fexp for avx2 and avx512

Cristiano et al. propose a clever exp using the IEEE representation with fine control of the precision, especially useful
for mixed-precision computation in the flash attention.

  • Implement Fast Exponential Computation on SIMD Architectures
    A. Cristiano I. Malossi, Yves Ineichen, Costas Bekas, and Alessandro Curioni
  • AVX2 and AVX512, float only; up to 20% faster for mixed-precision flash attention
    than the current implementation.
  • For the other types, the legacy implementation is used.
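For intuition, the heart of the paper's approach is that exp(x) = 2^(x/ln 2), and the integer part of x/ln 2 can be written directly into the exponent field of an IEEE-754 float while the fractional part fills the mantissa. A minimal scalar sketch of my own (NOT the PR's AVX2/AVX512 code):

```python
import math
import struct

# Minimal scalar sketch of the IEEE-754 trick behind fast exp
# (Schraudolph-style; the paper refines it with polynomial corrections).
# Illustration only, NOT the PR's AVX2/AVX512 implementation.

A = (1 << 23) / math.log(2.0)  # scales x/ln2 into the exponent field
B = 127 << 23                  # single-precision exponent bias, pre-shifted

def fast_exp(x: float) -> float:
    """Crude exp(x) = 2**(x/ln2) assembled directly from float bits.

    The mantissa is filled linearly, i.e. 2**f is approximated by 1+f,
    giving up to ~6% relative error; tuned polynomial corrections (as in
    the paper) shrink this toward 1 ULP."""
    i = int(A * x + B)
    return struct.unpack("<f", struct.pack("<I", i & 0xFFFFFFFF))[0]
```

`fast_exp(0.0)` returns exactly 1.0; over moderate inputs the relative error of this crude version stays below roughly 6.2%, which is what the polynomial correction in the paper is there to fix.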

Precision

A 1 ULP bound, only valid in hybrid mode (fp32 -> f16) due to the cast during the
store operation in the flash attention.

Benchmark

Machine Xeon 6972P, results in TOPs, Python forward pass flash attention

numhead 16, Head dimension 64

| Seq. L. | PT   | fexp |
|---------|------|------|
| 512     | 0.8  | 1.3  |
| 1024    | 1.7  | 1.7  |
| 2048    | 6    | 6.1  |
| 4096    | 16   | 16.8 |
| 8192    | 30.6 | 32.3 |
| 16384   | 40   | 40.8 |
| 32768   | 44.9 | 51.4 |
| 65536   | 45.8 | 54.4 |

numhead 16, Head dimension 128

| Seq. L. | PT   | fexp |
|---------|------|------|
| 512     | 2.5  | 4.1  |
| 1024    | 3.3  | 4    |
| 2048    | 11.4 | 10.5 |
| 4096    | 27.4 | 28.4 |
| 8192    | 44.4 | 46   |
| 16384   | 64.2 | 68.1 |
| 32768   | 77.8 | 83   |
| 65536   | 82.1 | 88.1 |

numhead 16, Head dimension 256

| Seq. L. | PT    | fexp  |
|---------|-------|-------|
| 512     | 1.7   | 3.4   |
| 1024    | 4.2   | 6.5   |
| 2048    | 14.6  | 16.1  |
| 4096    | 30.1  | 31.1  |
| 8192    | 60    | 62    |
| 16384   | 83.3  | 87.3  |
| 32768   | 98.7  | 106   |
| 65536   | 102.2 | 107.1 |

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jerryzh168

@pytorch-bot

pytorch-bot bot commented Apr 16, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/151441

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 1 Pending

As of commit 1464c24 with merge base 6f23f53:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the module: cpu CPU specific problem (e.g., perf, algorithm) label Apr 16, 2025
@timocafe
Contributor Author

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Apr 16, 2025
@drisspg drisspg added the module: sdpa All things related to torch.nn.functional.scaled_dot_product_attentiion label Apr 16, 2025
@Valentine233
Collaborator

Thanks for your optimization! We need to do more validation before the PR lands, such as the three dynamo suites and LZ models, which I will follow up on. cc @mingfeima @leslie-fang-intel

@albanD albanD requested a review from drisspg April 17, 2025 13:39
@albanD albanD added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Apr 17, 2025
@mingfeima mingfeima moved this to In Progress in PyTorch Intel Apr 18, 2025
@timocafe timocafe force-pushed the tewart_fexp branch 2 times, most recently from 18bb13b to 5dfa3c6 Compare April 22, 2025 07:59
@Valentine233
Collaborator

Valentine233 commented Apr 24, 2025

The validation for dynamo suites is done, and we do not see obvious accuracy/perf change for all dtypes, including bf16/fp16/fp32. Next, we will check the accuracy/perf on Stable Diffusion, Llama3.1-8b and VIT. cc @timocafe

@Valentine233
Collaborator

The model validation is still WIP, due to the lack of machine and the large dataset for accuracy.

@Valentine233
Collaborator

Valentine233 commented May 14, 2025

The model validation result is ready.

  • Accuracy: All is good.
  • Performance: For stable diffusion v2.1, we see an improvement of 5% for BF16 and 4% for FP16. No other obvious impacts.

Thanks for your work! @timocafe cc @mingfeima

@mingfeima mingfeima requested a review from malfet May 19, 2025 06:47
@timocafe
Contributor Author

@malfet @drisspg any update for my PR ?

@drisspg
Contributor

drisspg commented Jun 7, 2025

Can you give a little more color on the precision impact? It says 1 ULP in hybrid mode; is the setting here hybrid mode? Does that change for different input q, k, v dtypes? Or, since this is exp, are the scores always in fp32?

@timocafe
Contributor Author

timocafe commented Jun 8, 2025

Hi, indeed the current CPU implementation of flash attention performs exp(x) in float32 (the template parameters: input T1=float32, output T2=float16) and, very importantly, the results are stored/cast to BF16/F16; consequently, we can be moderately aggressive about the precision of the polynomial evaluation.

Regarding the 1 ULP claim, I ran the following benchmark on the side:

  • generate random float32 numbers
  • perform the PT exp and my "new" exp
  • convert the f32 results to c10::Half
  • compute the ULP of PT::exp and my::exp against std::exp

I get 1 ULP, except for numbers beyond the boundary limit, but in those cases PT::exp is already so-so. For example:

x:66.7371 tim::exp ulp: 1 pt::exp ulp: 0
x:-12.9124 tim::exp ulp: 1 pt::exp ulp: 0
x:-69.4785 tim::exp ulp: 1 pt::exp ulp: 0
x:80.2029 tim::exp ulp: 1 pt::exp ulp: 0 <---- the ULP is compared with std::exp
x:-88.8228 tim::exp ulp: 29 pt::exp ulp: 29 <------ the ULP is already so-so against std::exp outside the valid range.
x:75.7617 tim::exp ulp: 1 pt::exp ulp: 0
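The ULP comparison described above can be sketched as follows (my own scalar illustration, not the author's benchmark code; `half_bits_ordered` and `half_ulp_distance` are hypothetical helper names). After casting both results to half precision, the ULP distance is the gap between their bit patterns mapped onto a monotonically ordered integer scale:

```python
import struct

def half_bits_ordered(x: float) -> int:
    """Round x to IEEE-754 half precision and map its bit pattern onto a
    monotonically ordered integer scale (sign-magnitude -> signed)."""
    h = struct.unpack("<H", struct.pack("<e", x))[0]  # fp16 bits
    return h if h < 0x8000 else -(h - 0x8000)

def half_ulp_distance(a: float, b: float) -> int:
    """ULP distance between a and b after casting both to fp16.

    Inputs must be finite and within fp16 range; struct raises
    OverflowError otherwise."""
    return abs(half_bits_ordered(a) - half_bits_ordered(b))
```

For example, `half_ulp_distance(1.0, 1.0 + 2**-10)` is 1, since those are adjacent fp16 values; comparing std::exp and a fast exp this way after the f32 -> f16 cast produces the kind of numbers shown above.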

If you are interested, I can provide the benchmark code in a tiny GitHub project.

To conclude, this fexp_u20 is only valid in f32/f16 mode; if the flash attention is computed in f32/f32, I fall back on the current implementation.

Concerning the implementation, I am aware that implementing a dedicated exp for only one specific function may be too much (but transformers are so popular). I could have merged both and done constexpr branching, but I would need an extra argument to select the implementation, and the consequences could be large, so I made the basic choice of adding an extra exp to keep the change as small as possible. In an ideal world, we could select the precision of any elementary function according to the workload, parameters, etc.

Beyond the scope of flash attention, this exp could probably be used in many places in PT, especially for f32/f16, but I do not have knowledge of the full framework; more or less all activation functions are concerned.

@Valentine233
Collaborator

Any update for the PR?

@timocafe
Contributor Author

Hello Meta, with summer coming, any hope of acceptance?

@Valentine233
Collaborator

cc @Skylion007 @drisspg @malfet

@Valentine233 Valentine233 requested a review from Skylion007 July 2, 2025 01:37
@Valentine233
Collaborator

@timocafe Please rebase and make sure all the CIs pass.

@timocafe timocafe force-pushed the tewart_fexp branch 2 times, most recently from 54178a3 to e3ae0f6 Compare July 3, 2025 07:34
@Valentine233
Collaborator

The error came from the Mac build; I forgot to add the corresponding wrapper. Difficult to check, because the ARM builds are only performed during the merge operations ...

I did the change; the current failure is due to an unknown timeout error.

Re-ran the test and still hit the timeout. The PR cannot be landed because of this issue. I suggest rebasing and triggering all the CIs again.

- Implement Fast Exponential Computation on SIMD Architectures
  A. Cristiano I. Malossi, Yves Ineichen, Costas Bekas, and Alessandro Curioni
- AVX2 and AVX512, float only; up to 20% faster for mixed-precision flash attention
  than the current implementation.
- For the other types, the legacy implementation is used.

Precision: 1 ULP, only valid in hybrid mode (fp32 -> f16) due to the cast during the
store operation in the flash attention.

Benchmark:

Machine Xeon 6972P, results in TOPs, Python forward pass flash attention

numhead 16, Head dimension 64

|Seq. L.| PT   | fexp |
|-------|------|------|
| 512   | 0.8  | 1.3  |
| 1024  | 1.7  | 1.7  |
| 2048  | 6    | 6.1  |
| 4096  | 16   | 16.8 |
| 8192  | 30.6 | 32.3 |
| 16384 | 40   | 40.8 |
| 32768 | 44.9 | 51.4 |
| 65536 | 45.8 | 54.4 |

numhead 16, Head dimension 128

|Seq. L.| PT   | fexp |
|-------|------|------|
| 512   | 2.5  | 4.1  |
| 1024  | 3.3  | 4    |
| 2048  | 11.4 | 10.5 |
| 4096  | 27.4 | 28.4 |
| 8192  | 44.4 | 46   |
| 16384 | 64.2 | 68.1 |
| 32768 | 77.8 | 83   |
| 65536 | 82.1 | 88.1 |

numhead 16, Head dimension 256

|Seq. L.| PT   | fexp |
|-------|------|------|
| 512   | 1.7  | 3.4  |
| 1024  | 4.2  | 6.5  |
| 2048  | 14.6 | 16.1 |
| 4096  | 30.1 | 31.1 |
| 8192  | 60   | 62   |
| 16384 | 83.3 | 87.3 |
| 32768 | 98.7 | 106  |
| 65536 | 102.2| 107.1|

Fix typo and compiler issue.

- retrigger the CI by amending the commit message
  due to the modification of 09e8ff9 4 days ago.
- Formatting
- modify the sve/neon/arm backend
@pytorch-bot pytorch-bot bot removed the ciflow/trunk Trigger trunk jobs on your pull request label Jul 9, 2025
@timocafe
Contributor Author

timocafe commented Jul 9, 2025

Rebased and pushed; the CI has been canceled by a mysterious message: "check labels: Canceling since a higher priority waiting request for Check Labels-151441-false exists"

@Valentine233 Valentine233 added the ciflow/trunk Trigger trunk jobs on your pull request label Jul 10, 2025
@Valentine233
Collaborator

Rebased and pushed; the CI has been canceled by a mysterious message: "check labels: Canceling since a higher priority waiting request for Check Labels-151441-false exists"

Because you don't have the permission to trigger CIs. I've launched them now.

@ezyang
Contributor

ezyang commented Jul 10, 2025

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: 2 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@Valentine233
Collaborator

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@github-project-automation github-project-automation bot moved this from In Progress to Done in PyTorch Intel Jul 10, 2025
pytorchmergebot pushed a commit that referenced this pull request Mar 15, 2026
…176881)

Similar to #151441 - this PR adds an ASIMD implementation for fast exponential intended for cases where outputs will be downcasted to FP16 / BF16 (e.g. attention softmax).

Implementation is similar to the existing exp_u20, but:
  - uses a third degree polynomial approximation for exp(r) instead of a
    fifth degree one, with coefficients re-tuned.
  - does not split natural log (ln) into high / low parts
  - clamps exp(x) to 0 for x < -87.346351f and inf for x > 88.3762589f

Overall, this allows us to ditch 4 FMLAs (out of 7 in the current impl) and the nasty if branch
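The scheme this commit describes (clamp instead of special-case branches, range reduction, a low-degree polynomial for exp(r), scaling via the exponent field) can be sketched in scalar Python. This is my own reconstruction under stated assumptions, not the PR's ASIMD code: only the clamp constants are taken from the commit text, and the polynomial coefficients are the plain Taylor ones rather than the re-tuned set.

```python
import math
import struct

LOG2E = 1.0 / math.log(2.0)
LN2 = math.log(2.0)

def fexp_sketch(x: float) -> float:
    """Scalar sketch: clamp, range-reduce, cubic polynomial, 2**n scaling.

    Plain Taylor coefficients, so accuracy here is ~1e-3 relative,
    looser than the re-tuned version described in the commit."""
    # Clamping replaces the special-case branches
    if x < -87.346351:
        return 0.0
    if x > 88.3762589:
        return math.inf
    # Range reduction: x = n*ln2 + r with |r| <= ln2/2
    n = round(x * LOG2E)
    r = x - n * LN2
    # Degree-3 polynomial approximation of exp(r) (Horner form)
    p = 1.0 + r * (1.0 + r * (0.5 + r * (1.0 / 6.0)))
    # Build 2**n by writing n+127 into the IEEE-754 exponent field
    scale = struct.unpack("<f", struct.pack("<I", (n + 127) << 23))[0]
    return p * scale
```

With |r| bounded by ln2/2, the cubic's relative error is about 1e-3, already close to FP16 resolution; re-tuning the coefficients over that interval (as the commit does) is what brings the downcast result within 1 ULP.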

## Accuracy

Following a similar approach to #151441, I tested accuracy for fast exp using [this](https://gist.github.com/fadara01/1f0a7bc4b63193102b5f6b0283a9b02f) script. The script iterates over all possible FP32 bit patterns and calculates ULP between:
- `fexp_u20` with inputs in FP32, outputs converted to BF16/FP16
- `std::exp` with inputs in FP32, outputs converted to BF16/FP16

Accuracy script shows that fast exp is:
- for FP16: accurate within a maximum of 1 FP16 ULPs
- for BF16: accurate within a maximum of 1 BF16 ULPs for inputs in [-87.346351, 88.376] - we clamp inputs outside this range to 0 / inf.

## Performance

Using [this SDPA benchmark](https://gist.github.com/fadara01/5357a52299a3722587f6691d145e71e9), here are the scaled-dot-product-attention speedups achieved with 16 Neoverse-V2 cores:

| B | Hq | Hkv | Lq | Lk | D | causal | gqa | Speedup vs current |
|---:|---:|---:|---:|---:|---:|---|---|---:|
| 1 | 32 | 8 | 2048 | 2048 | 128 | True | True | +9.48% |
| 1 | 32 | 8 | 1 | 2048 | 128 | False | True | -1.42% (noise) |
| 1 | 16 | 16 | 6400 | 6400 | 80 | False | False | +5.18% |
| 1 | 20 | 20 | 1500 | 1500 | 64 | False | False | +6.63% |
| 8 | 20 | 20 | 1500 | 1500 | 64 | False | False | +9.31% |

Pull Request resolved: #176881
Approved by: https://github.com/Skylion007, https://github.com/jgong5
pytorchmergebot pushed a commit that referenced this pull request Mar 18, 2026
Similar to #176881 and #151441 this PR adds an SVE fast exponential implementation, intended for cases where outputs will be downcasted to FP16 / BF16 (e.g. attention softmax).

Implementation is similar to exp_u20, but:
  - approximates exp(r) - 1 as r instead of r + 0.5 r^2
  - does not split natural log (ln) into high / low parts
  - avoids special case code by clamping exp(x) to 0 for x < -87.346 and inf for x > 88.717

## Accuracy

Tested in a similar fashion to #176881, iterating over all possible FP32 bit patterns and calculating the ULP between:
- `fexp_u20` with inputs in FP32, outputs converted to BF16/FP16
- `std::exp` with inputs in FP32, outputs converted to BF16/FP16

From the accuracy study above, this exp is:
- Accurate within a maximum of 1 ULP for FP16
- Accurate within a maximum of 1 ULP for BF16 for inputs in [-87.346, max_float] & clamps inputs < -87.346 to zero.

## Performance

Using [this SDPA benchmark](https://gist.github.com/fadara01/5357a52299a3722587f6691d145e71e9), here are the scaled-dot-product-attention speedups achieved with 16 Neoverse-V1 cores (with SVE256):

| B | Hq | Hkv | Lq | Lk | D | causal | gqa | Speedup vs current |
|---:|---:|---:|---:|---:|---:|---|---|---:|
| 1 | 32 | 8 | 2048 | 2048 | 128 | True | True | +7.20% |
| 1 | 32 | 8 | 1 | 2048 | 128 | False | True | +0.38% (noise) |
| 1 | 16 | 16 | 6400 | 6400 | 80 | False | False | +4.32% |
| 1 | 20 | 20 | 1500 | 1500 | 64 | False | False | +3.38% |
| 8 | 20 | 20 | 1500 | 1500 | 64 | False | False | +6.35% |

Pull Request resolved: #177645
Approved by: https://github.com/Skylion007
ryanzhang22 pushed a commit to ryanzhang22/pytorch that referenced this pull request Mar 19, 2026
EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
…ytorch#176881)

EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
AaronWang04 pushed a commit to AaronWang04/pytorch that referenced this pull request Mar 31, 2026
…ytorch#176881)

AaronWang04 pushed a commit to AaronWang04/pytorch that referenced this pull request Mar 31, 2026