
Conversation

@sushraja-msft
Contributor

Description

This change implements FlashAttention 2 in the WebGPU EP for the MHA operator.

Numbers from an Alderlake device show a 2.2x speedup for prefill. Considering that attention is about 50% of the prefill phase (the other 50% being MatMul), this implies roughly a 4x speedup for attention with this implementation, in line with the expected 2-4x gain of FlashAttention over regular attention.

Baseline
PS C:\onnxruntime> C:\model_benchmark\model_benchmark.exe -i C:\Phi-3.5-mini-instruct-onnx-web\Phi-3.5-mini-instruct-onnx-web\ -l 1000
Batch size: 1, prompt tokens: 1001, tokens to generate: 128
Prompt processing (time to first token):
        avg (us):       9.54997e+06   <<<<<
        avg (tokens/s): 104.817
        p50 (us):       9.49218e+06
        stddev (us):    251442
        n:              5 * 1001 token(s)
------
With FlashAttention 2
PS C:\onnxruntime> C:\model_benchmark\model_benchmark.exe -i C:\Phi-3.5-mini-instruct-onnx-web\Phi-3.5-mini-instruct-onnx-web\ -l 1000
Batch size: 1, prompt tokens: 1001, tokens to generate: 128
Prompt processing (time to first token):
        avg (us):       4.27937e+06     <<<<<
        avg (tokens/s): 233.913
        p50 (us):       4.27687e+06
        stddev (us):    5344.1
        n:              5 * 1001 token(s)
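
As a quick sanity check on these numbers (an illustrative calculation only, not part of model_benchmark): avg tokens/s is simply the 1001 prompt tokens divided by the average prompt-processing time, and the ratio of the two averages is the prefill speedup quoted above.

```python
# Illustrative check of the reported figures (values copied from the runs above).
prompt_tokens = 1001
for label, avg_us in [("baseline", 9.54997e6), ("FlashAttention 2", 4.27937e6)]:
    print(f"{label}: {prompt_tokens / (avg_us / 1e6):.3f} tokens/s")
# -> ~104.817 and ~233.913 tokens/s, matching the report;
#    9.54997e6 / 4.27937e6 ~= 2.23, i.e. the ~2.2x prefill speedup.
```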

Motivation and Context

On integrated GPUs memory bandwidth is at a premium. FlashAttention makes the softmax computation (and therefore the output attention vector computation) a running operation instead of materializing the full QK^T attention score matrix in memory. As a result, we see a significant improvement in prefill speed: roughly a 2.2x speedup measured here.
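
To make the "running operation" concrete, here is a minimal NumPy sketch of the online-softmax accumulation for a single query row, with a made-up tile size and function name; it illustrates the idea only and is not the WGSL shader in this change:

```python
import numpy as np

def flash_attention_row(q, K, V, tile=64):
    """Attention output for one query row q (shape [d]) against K, V (shape [n, d]),
    using a running softmax so the full [n]-sized score row is never kept in memory."""
    d = q.shape[0]
    m = -np.inf                 # running max of the scores seen so far
    l = 0.0                     # running sum of exp(score - m)
    acc = np.zeros(d)           # running, unnormalized output accumulator
    for start in range(0, K.shape[0], tile):
        Kt, Vt = K[start:start + tile], V[start:start + tile]
        s = Kt @ q / np.sqrt(d)             # scores for this tile only
        m_new = max(m, float(s.max()))
        scale = np.exp(m - m_new)           # rescale everything accumulated so far
        p = np.exp(s - m_new)
        l = l * scale + p.sum()
        acc = acc * scale + p @ Vt
        m = m_new
    return acc / l                          # == softmax(q @ K.T / sqrt(d)) @ V
```

Only one tile of scores is live at any time, so memory traffic scales with the tile size rather than with the full sequence length.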

This change uses techniques from co-operative matrix multiplication, where the registers of a subgroup are used for fast in-register matrix multiplies. Without the co-operative matrix multiply technique, Alderlake showed about a 6.0s prefill time.
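
For intuition on the co-operative matrix multiply, below is a small Python simulation under assumed parameters (a hypothetical subgroup size of 8 and square 8x8 tiles); the actual change does this in WGSL, where lanes exchange rows via subgroup operations rather than shared memory, and its tile shapes may differ:

```python
import numpy as np

SUBGROUP_SIZE = 8  # hypothetical; the real subgroup size is device-dependent

def cooperative_tile_matmul(A_tile, B_tile):
    """Simulate one subgroup computing C = A_tile @ B_tile with per-lane registers.
    Lane i keeps row i of A (and accumulates row i of C) in its registers; rows of B
    are shared between lanes by a simulated subgroup broadcast instead of shared memory."""
    n = SUBGROUP_SIZE
    assert A_tile.shape == (n, n) and B_tile.shape == (n, n)
    a_reg = [A_tile[lane].copy() for lane in range(n)]  # lane-private "registers"
    b_reg = [B_tile[lane].copy() for lane in range(n)]
    c_reg = [np.zeros(n) for _ in range(n)]
    for k in range(n):
        b_row_k = b_reg[k]  # broadcast: every lane reads lane k's B row
        for lane in range(n):
            c_reg[lane] += a_reg[lane][k] * b_row_k
    return np.stack(c_reg)   # equals A_tile @ B_tile

# e.g. np.allclose(cooperative_tile_matmul(A, B), A @ B) holds for random 8x8 A, B
```

The point is that each lane keeps its operands and accumulator in registers and only borrows the other lanes' rows when needed, avoiding round trips through workgroup memory.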

Tested on Alderlake and Tigerlake Intel integrated GPUs and an Nvidia 4070.

Future Work

  • Fine-tuning and profiling optimizations.
  • The current implementation is for prefill only. A generation-phase-optimized FA2 implementation is possible, but attention is only a tiny part of the generation phase.

@guschmue guschmue added the ep:WebGPU ort-web webgpu provider label Feb 4, 2025
guschmue previously approved these changes Feb 5, 2025
@sushraja-msft sushraja-msft force-pushed the user/sushraja/flash_attention2 branch from dc9d752 to 362e969 on February 5, 2025 22:12
@yuslepukhin
Member

This needs tests

@sushraja-msft
Contributor Author

> This needs tests

The multi-head attention tests cover this change: https://dev.azure.com/onnxruntime/onnxruntime/_build/results?buildId=1608375&view=logs&j=03d099f7-378a-53dd-cb87-de113827d886&t=f858033e-9d42-5a9f-0b4d-125063ba3686

[screenshot of the passing test results]

These tests were failing in earlier iterations of this change because attention_bias was handled incorrectly.

@guschmue guschmue merged commit 82840f6 into main Feb 7, 2025
98 checks passed
@guschmue guschmue deleted the user/sushraja/flash_attention2 branch February 7, 2025 00:32
ashrit-ms pushed a commit that referenced this pull request Feb 11, 2025
ashrit-ms added a commit that referenced this pull request Feb 11, 2025
### Description
This PR updates the win-ort-main branch to the tip of the main branch as of 2025-02-11.

### PR List
74c778e [WebNN EP] Automatically move input CPU tensors to ml-tensor
(#23073)
3775057 use correct total length to fix static kv_cache performance
(#23615)
3901e96 remove --use_vcpkg flag for Python-CUDA-Packaging-Pipeline
(#23631)
c610df5 Add python_requires to package metadata (#23604)
2d27d68 [QNN EP] Add QNN EP to ARM64X build targets (#23635)
e666503 [webgpu] no longer need pass-in gpu adapter for custom
context (#23593)
af679a0 Fix logic for selecting alternate name for blob (#23617)
e206950 [ARM CPU] Add fp16 mlas kernels for exp, tanh, softmax,
logsoftmax, softcap (#23597)
9ba5619 Update pybind and json to the latest (#23589)
c54736c Migrate iOS release pipeline to 1 ES (#23606)
3981326 Increase timeout for Windows TensorRT CI (#23625)
0274b7b fix on trtCudaVersion (#23616)
740e9ab update run CI script (#23621)
5ef1832 [WebGPU] Support PIX Capture for WebGPU EP (#23192)
0114551 Fix for C4267 warning (#23610)
002916a Validate the context_file_path before EP compile graphs
(#23611)
0887e36 [webgpu] Use pushErrorScope()/popErrorScope() once for an
inference run (#23438)
65008cb Auto-generated baselines by 1ES Pipeline Templates (#23603)
09e5724 [CUDA] Fix beam search of num_beams > 32 (#23599)
82840f6 Implement Flash Attention 2 for webgpu EP (#23576)
a6ea57b OpenVINO EP Weights Sharing Feature (#23553)
2c2ff4a [CUDA] Fix BeamSearchTest.DummyT5WithSequenceInputIds test
failure in Windows (#23596)
d981b15 [webgpu/js] Optimize resize webgpu op & fix precision issues
(#23591)
328a13c Enable VCPKG in more pipelines (#23590)
6728d60 [TensorRT EP] support TensorRT 10.8-GA (#23592)
d1fb58b Quantization tool: Allow user to override calibrator's
session EP (#23559)
649ced4 Enable user loading model with external data from memory
buffer (#23557)
544bdd6 Fix ConvTranspose for certain attribute combinations (#23488)
8f6ddf3 Delete extra cgmanifest entries and files (#23583)
5f6a315 Enable VCPKG in CI build (#23426)
e1e3f62 Bump lintrunner from 0.12.5 to 0.12.7 (#23326)
cd8775f Fix Node JS Samples (#23581)
6b4f9c4 [WebGPU EP] Batch Norm Implementation (#23525)
1fce51b Fix all instances of 4244 and 4267 warnings in OV EP code
(#23567)
c29ca1c Update QNN default version to 2.31 (#23573)
2fc75a4 [mobile] Add Android BrowserStack test project back (#23551)
9e18b6a [CUDA] Update nvcc flags (#23572)
b47e1e6 [QNN EP] Make offloading graph input/output quantization (to
CPU) the default (#23368)
75a9b40 [ROCm] Update CI to use rocm 6.3.2 (#23577)
26ff2b6 Bump ruff from 0.9.3 to 0.9.4 (#23563)
b2560a7 Update react-native to 0.72 (#23509)
faee912 [js] update JavaScript API to support QNN EP options (#23486)
816e8cb [EP Perf] Update env to ubuntu 22.04 (#23570)
cddc271 Use Eigen in Round implementation (#23571)
e8b0bdb Shape inference: ReduceMean dispatcher, quant_pre_process:
skip_symbolic_shape bugfix (#23558)
267b493 delete the supported domain version upper bounds (#23237)
bb7f961 remove log spam from cpuinfo (#23548)
169917b Use latest vcpkg commit in configuration, sync manifest with
deps.txt (#23554)
a9d4d08 Add of ReduceMax Gradient (#23501)
6bbf1bd [js/web] upgrade version of flatbuffers (#23545)
271c509 DP4AMatMul perf refinements (#23539)
cb69c59 Add fusions for SigLIP and Conformer-Encoder (#23528)
61fae9b Remove "--enable_pybind" from webgpu pipeline (#23550)
0bb4ea6 Update BiasGelu fusion and related ops (#23518)
4dde74a Add more details to BrowserStack script failure (#23520)
ead9d5c Set ANDROID_USE_LEGACY_TOOLCHAIN_FILE to false (#23544)
7e24088 Enable dlpack by default (#23110)
dc2f7a9 Add overload of `TryParseStringWithClassicLocale()` that uses
`std::from_chars()` (#23541)
5407c69 Fix the issue that the new generated EP context model not
able to find external data (#23537)
fbae88f [js/web] use the recommended workaround for Vite (#23531)
d5338da Fix tensor external data info length parsing issue. (#23526)
e3e4173 [ROCm EP] Fix transpose helper for gfx gridsize constraints
(#23527)
80bc1d2 Enable Ep context with external data for CPU nodes (#23498)
bf023ab [js/web] allow import .mjs/.wasm file (#23487)
655a23f [onnxruntime/build] Add new flag enable_generic_interface to
build primary EPs by default (#23342)
a770a8d Update RN to 0.71.19 (#23381)
1cf0ebd Delete Prefast workflow until the build failure is fixed
(#23510)
d2c5e24 Add of GlobalMaxPool Gradient (#23502)
ded8730 Remove thrust::unary_function (#23506)
8db97a6 [webgpu] Bump version of Dawn to b9b4a370 (#23494)
fdde2e2 Fix for gcc 13.3.1: Avoid creating a copy (#23500)
96ec1dd Bump ruff from 0.9.2 to 0.9.3 (#23496)
42f0c00 Adds the new System.Numerics.Tensors as an input/output type
when using dotnet 8.0 and up. (#23261)
97c2bbe Fix shape infer of onnx GroupNorm (#23477)
1fc9c48 Enable coremltools for Linux build (#23481)
13348c5 [ARM CPU] hgemm optimized for gqa (#23107)
c89a798 Enable opti on Microsoft.ML.OnnxRuntime with RelWithDebInfo
config (#23463)
d00ae32 Revert "[Mobile] Add BrowserStack Android MAUI Test (#23383)"
(#23474)
8b1d3b3 Align AvgPool ceil_mode on last value to torch (#16752)
06fc73b [TRT EP Perf Tool] Add annotations import to python script to
support annotations on Python 3.8 (#23466)

### Motivation and Context
This update includes the change to add QNN EP to ARM64X build targets.

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: Adrian Lizarraga <adlizarraga@microsoft.com>
Co-authored-by: Ti-Tai Wang <titaiwang@microsoft.com>
Co-authored-by: Caroline Zhu <wolfivyaura@gmail.com>
Co-authored-by: Grégoire <gregoire.verdier@gmail.com>
Co-authored-by: Jing Fang <126209182+fajin-corp@users.noreply.github.com>
Co-authored-by: Changming Sun <chasun@microsoft.com>
Co-authored-by: Yateng Hong <yatengh@microsoft.com>
Co-authored-by: Michael Sharp <51342856+michaelgsharp@users.noreply.github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Malik Shahzad Muzaffar <shahzad.malik.muzaffar@cern.ch>
Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
Co-authored-by: Dmitri Smirnov <yuslepukhin@users.noreply.github.com>
Co-authored-by: Corentin Maravat <101636442+cocotdf@users.noreply.github.com>
Co-authored-by: Jian Chen <cjian@microsoft.com>
Co-authored-by: Karim Vadsariya <karim.vadsariya@microsoft.com>
Co-authored-by: Lei Cao <jslhcl@gmail.com>
Co-authored-by: Karim Vadsariya <kvadsariya@microsoft.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Hector Li <hecli@microsoft.com>
Co-authored-by: Ted Themistokleous <107195283+TedThemistokleous@users.noreply.github.com>
Co-authored-by: Ted Themistokleous <tedthemistokleous@amd.com>
Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
Co-authored-by: Takeshi Watanabe <take-cheeze@users.noreply.github.com>
Co-authored-by: Xavier Dupré <xadupre@users.noreply.github.com>
Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
Co-authored-by: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com>
Co-authored-by: Sushanth Rajasankar <44513542+sushraja-msft@users.noreply.github.com>
Co-authored-by: PARK DongHa <luncliff@gmail.com>
Co-authored-by: George Wu <jywu@microsoft.com>
Co-authored-by: Xinpeng Dou <15529241576@163.com>
Co-authored-by: Jambay Kinley <jambaykinley@microsoft.com>
Co-authored-by: Yifan Li <109183385+yf711@users.noreply.github.com>
Co-authored-by: Gavin Kinsey <98115505+ms-gavinkinsey@users.noreply.github.com>
Co-authored-by: Prathik Rao <prathik.rao@gmail.com>
Co-authored-by: Jon Campbell <jcampbell@cephable.com>
Co-authored-by: Satya Kumar Jandhyala <satya.k.jandhyala@gmail.com>
Co-authored-by: Joshua Lochner <admin@xenova.com>
Co-authored-by: Ankit Maheshkar <ankit.maheshkar@intel.com>
Co-authored-by: jatinwadhwa921 <jatin.wadhwa@intel.com>
Co-authored-by: jatinwadhwa921 <110383850+jatinwadhwa921@users.noreply.github.com>
Co-authored-by: saurabh <saurabh1.kale@intel.com>
Co-authored-by: TejalKhade28 <tejal.khade@intel.com>
Co-authored-by: sfatimar <sahar.fatima@intel.com>
Co-authored-by: Javier E. Martinez <javier.e.martinez@intel.com>
Co-authored-by: Preetha Veeramalai <preetha.veeramalai@intel.com>
Co-authored-by: Eric Crawford <eric.r.crawford@intel.com>
Co-authored-by: microsoft-github-policy-service[bot] <77245923+microsoft-github-policy-service[bot]@users.noreply.github.com>
Co-authored-by: Jie Chen <jie.a.chen@intel.com>
Co-authored-by: shaoboyan091 <shaoboyan@microsoft.com>
Co-authored-by: David Hotham <david.hotham@microsoft.com>
Co-authored-by: Guenther Schmuelling <guschmue@microsoft.com>
Co-authored-by: Enrico Galli <enrico.galli@intel.com>
guschmue pushed a commit that referenced this pull request Mar 6, 2025
ashrit-ms pushed a commit that referenced this pull request Mar 17, 2025
