Conversation

@reeselevine
Collaborator

This PR adds an initial version of FlashAttention2 to the WebGPU backend. Along with the GPU code itself, it also adds a new preprocessor for WGSL shaders that should make it easier and less brittle to define new shaders going forward. Details below:

Shader setup

  • Most of the shaders are currently generated at build time using a relatively hacky Python script and template syntax that I wrote during initial development of the WebGPU backend. This probably won't continue to scale well, especially with the number of options for FlashAttention, so I decided it was time for a more general solution. Since there wasn't an existing preprocessor for WGSL that works with C++ code, I wrote one here: https://github.com/reeselevine/pre-wgsl. The preprocessor itself is a single file, pre_wgsl.hpp, and should continue to track any changes/features added to the main preprocessor repository.
  • To accommodate the various options for FlashAttention when compiling WGSL shaders, I added another new file to the WebGPU backend, ggml-webgpu-shader-lib.hpp, which generates the shader from a combination of structural parameters (e.g., head sizes) and performance parameters (e.g., KV tile sizes). My idea is to expand this library approach to use more sophisticated strategies going forward, and to move other shaders over to the preprocessor. From a performance perspective, shaders compile quickly, are mostly fixed for a given model, and are cached, so JIT compilation, at least for FlashAttention, seems like the right call to me (see the sketch after this list).
  • Some other minor changes to ggml-webgpu.cpp to handle the new FlashAttention code and JIT compilation.
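
To make the JIT approach concrete, here is a minimal sketch of the idea; it is not the actual ggml-webgpu-shader-lib.hpp or pre_wgsl.hpp API, the function and struct names are hypothetical, and a simple token substitution stands in for the real preprocessor.

// Hypothetical sketch: structural parameters (head size) and performance
// parameters (KV tile size) are substituted into a WGSL template, and the
// resulting source is compiled at runtime and cached per configuration.
#include <map>
#include <string>

struct flash_attn_params {
    int head_size;     // structural: fixed by the model
    int kv_tile_size;  // performance: tunable per device
};

// Replace every occurrence of each placeholder with its value.
static std::string expand_template(std::string src, const std::map<std::string, std::string> & defs) {
    for (const auto & kv : defs) {
        size_t pos = 0;
        while ((pos = src.find(kv.first, pos)) != std::string::npos) {
            src.replace(pos, kv.first.size(), kv.second);
            pos += kv.second.size();
        }
    }
    return src;
}

// Generate the WGSL for one FlashAttention configuration; the caller would
// cache the compiled pipeline keyed on these parameter values.
static std::string build_flash_attn_wgsl(const std::string & tmpl, const flash_attn_params & p) {
    return expand_template(tmpl, {
        { "{{HEAD_SIZE}}",    std::to_string(p.head_size)    },
        { "{{KV_TILE_SIZE}}", std::to_string(p.kv_tile_size) },
    });
}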

FlashAttention shader itself

  • For the most part this follows the FlashAttention2 paper, with a change to the online softmax to make it subgroup-size agnostic (a scalar sketch of the recurrence follows this list). It also uses direct global KV loads when the sizes divide evenly, since the KV tiles are not reused and pre-loading them into shared memory really slows things down (at least on my M3).
  • Even then, performance is not great right now (less than 50% of the speed of the equivalent Metal kernels). My testing shows that most of the slowdown comes down to the initial Q * K^T accumulation loop. I still need to do some debugging here to figure out whether the issue is something I'm doing wrong structurally or whether it's in the compilation from WGSL to Metal. This code also needs to be tested more thoroughly on other platforms. Perhaps someone who has written FlashAttention for one of the other backends could take a look at the shader and see if it looks reasonable, e.g., @jeffbolznv?
  • Otherwise, this passes all the backend tests on my machine, so I think it's in a good state to merge as an initial implementation that can be improved upon as time goes on.
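
For reference, here is a minimal scalar sketch of the tiled online-softmax recurrence the shader follows, for a single query row and head. This is just the math (assuming at least one KV position), not the actual WGSL, and it ignores subgroup distribution, f16 accumulation, masking, sinks, and logit softcap.

// Scalar FlashAttention2-style recurrence: process KV in tiles of size T,
// keeping a running max m, running denominator l, and unnormalized output o.
#include <algorithm>
#include <cmath>
#include <vector>

std::vector<float> flash_attn_row(const std::vector<float> & q,
                                  const std::vector<std::vector<float>> & K,
                                  const std::vector<std::vector<float>> & V,
                                  int T, float scale) {
    const int D = (int) q.size();
    const int N = (int) K.size();
    std::vector<float> o(D, 0.0f);
    float m = -INFINITY; // running max of the logits
    float l = 0.0f;      // running softmax denominator

    for (int start = 0; start < N; start += T) {
        const int end = std::min(start + T, N);
        // Q * K^T for this KV tile (the accumulation loop mentioned above)
        std::vector<float> s(end - start);
        float m_tile = -INFINITY;
        for (int j = start; j < end; ++j) {
            float dot = 0.0f;
            for (int d = 0; d < D; ++d) {
                dot += q[d] * K[j][d];
            }
            s[j - start] = dot * scale;
            m_tile = std::max(m_tile, s[j - start]);
        }
        // Online softmax: rescale the previous accumulators to the new max,
        // then fold in this tile's contributions.
        const float m_new = std::max(m, m_tile);
        const float alpha = std::exp(m - m_new);
        l *= alpha;
        for (int d = 0; d < D; ++d) {
            o[d] *= alpha;
        }
        for (int j = start; j < end; ++j) {
            const float p = std::exp(s[j - start] - m_new);
            l += p;
            for (int d = 0; d < D; ++d) {
                o[d] += p * V[j][d];
            }
        }
        m = m_new;
    }
    // Final normalization by the softmax denominator.
    for (int d = 0; d < D; ++d) {
        o[d] /= l;
    }
    return o;
}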

* Add inplace softmax

* Move rms_norm to split row approach

* Update debug for supports_op

* clean up debug statements

* neg f16xf32xip builds and runs, haven't actually run a model that uses the neg kernel yet though

* neg passes backend test

* unary operators pass ggml tests

* rms_norm double declaration bug atoned

* abides by editor-config

* removed vestigial files

* fixed autoconfig

* All operators (including xielu) working

* removed unnecessary checking if node->src[1] exists for unary operators

* responded and dealt with PR comments

* implemented REPL_Template support and removed bug in unary operators kernel

* formatted embed wgsl and ggml-webgpu.cpp

* Faster tensors (#8)

Add fast matrix and matrix/vector multiplication.

* Use map for shader replacements instead of pair of strings

* Wasm (#9)

* webgpu : fix build on emscripten

* more debugging stuff

* test-backend-ops: force single thread on wasm

* fix single-thread case for init_tensor_uniform

* use jspi

* add pthread

* test: remember to set n_thread for cpu backend

* Add buffer label and enable dawn-specific toggles to turn off some checks

* Intermediate state

* Fast working f16/f32 vec4

* Working float fast mul mat

* Clean up naming of mul_mat to match logical model, start work on q mul_mat

* Setup for subgroup matrix mat mul

* Basic working subgroup matrix

* Working subgroup matrix tiling

* Handle weirder sg matrix sizes (but still % sg matrix size)

* Working start to gemv

* working f16 accumulation with shared memory staging

* Print out available subgroup matrix configurations

* Vectorize dst stores for sg matrix shader

* Gemv working scalar

* Minor set_rows optimization (#4)

* updated optimization, fixed errors

* non vectorized version now dispatches one thread per element

* Simplify

* Change logic for set_rows pipelines

---------

Co-authored-by: Neha Abbas <nehaabbas@macbookpro.lan>
Co-authored-by: Neha Abbas <nehaabbas@ReeseLevines-MacBook-Pro.local>
Co-authored-by: Reese Levine <reeselevine1@gmail.com>

* Comment on dawn toggles

* Working subgroup matrix code for (semi)generic sizes

* Remove some comments

* Cleanup code

* Update dawn version and move to portable subgroup size

* Try to fix new dawn release

* Update subgroup size comment

* Only check for subgroup matrix configs if they are supported

* Add toggles for subgroup matrix/f16 support on nvidia+vulkan

* Make row/col naming consistent

* Refactor shared memory loading

* Move sg matrix stores to correct file

* Working q4_0

* Formatting

* Work with emscripten builds

* Fix test-backend-ops emscripten for f16/quantized types

* Use emscripten memory64 to support get_memory

* Add build flags and try ci

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>

* Remove extra whitespace

* Move wasm single-thread logic out of test-backend-ops for cpu backend

* Disable multiple threads for emscripten single-thread builds in ggml_graph_plan

* Refactored pipelines and workgroup calculations (#10)

* refactored pipelines

* refactored workgroup calculation

* removed commented out block of prior maps

* Clean up ceiling division pattern

---------

Co-authored-by: Neha Abbas <nehaabbas@eduroam-169-233-141-223.ucsc.edu>
Co-authored-by: Reese Levine <reeselevine1@gmail.com>

* Start work on flash attention

* Shader structure set up (many bugs still)

* debugging

* Working first test

* Working with head grouping, head sizes to 128, logit softcap, mask/sinks enabled, f32

* Generalize softmax to work with multiple subgroups, f16 accumulation, mask shared memory tiling

* Start work on integrating pre-wgsl

* Separate structs/initial shader compilation library into separate files

* Work on compilation choices for flashattention

* Work on subgroup matrix/tile size portability

* subgroup size agnostic online softmax

* Cleanups, quantization types

* more cleanup

* fix wasm build

* Refactor flashattention to increase parallelism, use direct loads for KV in some cases

* Checkpoint

* formatting
github-actions bot added the ggml label (changes relating to the ggml tensor library for machine learning) Jan 5, 2026
@jeffbolznv
Collaborator

Hi, I only skimmed it and the basic structure looks reasonable, but I have no experience tuning for Apple/Metal. For NVIDIA, the current matmul size of 8x8x8 isn't supported and the tile/workgroup size is probably smaller than you'd want.

@reeselevine
Collaborator Author

Thanks Jeff! Yeah, the #defines in the shader code itself are just defaults; they can be overridden when processing the shader, depending on the device's supported subgroup matrix sizes. Finding the right values and building a selection process into the code is definitely one of the next steps for this shader!
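
As a rough illustration of what that selection step could look like, here is a hypothetical sketch; the config struct and field names are made up for illustration and are not the actual WebGPU/Dawn API.

// Hypothetical list of subgroup matrix shapes reported by the device.
#include <cstdint>
#include <vector>

struct sg_matrix_config {
    uint32_t M, N, K; // supported subgroup matrix multiply dimensions
};

struct fa_overrides {
    uint32_t sg_mat_m = 8, sg_mat_n = 8, sg_mat_k = 8; // current shader defaults
};

// Pick the largest supported shape that evenly divides the shader's Q/KV tile
// sizes, falling back to the defaults if nothing matches, so that a device
// without 8x8x8 support could get, say, a 16x16x16 configuration instead.
fa_overrides pick_fa_overrides(const std::vector<sg_matrix_config> & configs,
                               uint32_t q_tile, uint32_t kv_tile) {
    fa_overrides best;
    uint32_t best_area = 0;
    for (const auto & c : configs) {
        if (q_tile % c.M == 0 && kv_tile % c.N == 0 && c.M * c.N > best_area) {
            best.sg_mat_m = c.M;
            best.sg_mat_n = c.N;
            best.sg_mat_k = c.K;
            best_area = c.M * c.N;
        }
    }
    return best;
}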

@reeselevine
Collaborator Author

Also, since the GitHub-hosted runners don't support the features (subgroup matrices in particular) that are necessary to actually run the WebGPU FlashAttention implementation, maybe it would be useful to add a ggml-ci node to run the tests too, if one can support the necessary features? I'm not sure exactly how to do this: I see where the ggml-ci nodes are defined in build.yml, but I don't see how the environment is set up to make sure things run (specifically, installing the Dawn WebGPU runtime on the machine, as I do right now for the GitHub-hosted runners).

@ggerganov
Member

ggerganov commented Jan 6, 2026

For the CI, we can add a workflow that executes on the Mac self-hosted runner, such as this. It's a Mac M4 Mini machine - would it have the necessary features?

but I don't see how the environment is set up to make sure things run (specifically installing the Dawn WebGPU runtime on the machine, like I do right now for the Github-hosted runners).

I can pre-install Dawn and the workflow will just assume it is installed at some path? A better option is to add the same Dawn installation step to the new self-hosted workflow and then pass the necessary path(s) to the ci/run.sh script via new GG_BUILD_... vars. Let me know if this makes sense.

reeselevine requested a review from CISC as a code owner January 6, 2026 19:48
github-actions bot added the devops label (improvements to build systems and github actions) Jan 6, 2026
@reeselevine
Collaborator Author

@ggerganov I added a ggml-ci workflow, and it looks like everything is passing now. The Mac M4 Mini does look like it supports the necessary features, and the full CI actually helped me fix a precision issue in the flash attention code.

@ggerganov
Member

ggerganov commented Jan 8, 2026

Alright. If you want, you can also add a webgpu workflow on the self-hosted CUDA runner:

runs-on: [self-hosted, Linux, X64, NVIDIA]
steps:

It uses a Tesla T4 GPU. We can also do it later in another PR.

@ggerganov
Member

Even then, performance is not great right now (less than 50% of the speed of the equivalent Metal kernels).

Is this valid for both batch size = 1 and batch size > 1? In the Metal backend, we have 2 separate kernels for flash attention: one for small batch sizes (a.k.a. the vec kernel) and one for larger batch sizes, whereas here you have the same kernel for both cases.

@reeselevine
Collaborator Author

Alright. If you want, you can also add a webgpu workflow on the self-hosted CUDA runner:

runs-on: [self-hosted, Linux, X64, NVIDIA]
steps:

It uses a Tesla T4 GPU. We can also do it later in another PR.

Sounds good, I'll plan on adding this once I get a chance to run and test things on an NVIDIA GPU. In particular, I'd like to extend the code to support more tensor core sizes.

Even then, performance is not great right now (less than 50% of the speed of the equivalent Metal kernels).

Is this valid for both batch size = 1 and batch size > 1? In the Metal backend, we have 2 separate kernels for flash attention: one for small batch sizes (a.k.a. the vec kernel) and one for larger batch sizes, whereas here you have the same kernel for both cases.

I was mostly testing with batch size > 1. I agree we'll eventually want a separate kernel for the vec path for increased parallelism. It's looking like flash attention in general is a good stress test of WebGPU's capabilities and efficiency; we'll keep working on it.

reeselevine merged commit 15bff84 into ggml-org:master Jan 8, 2026
73 of 76 checks passed
gary149 pushed a commit to gary149/llama-agent that referenced this pull request Jan 13, 2026
* FlashAttention (#13)

* Update to account for default kv cache padding

* formatting shader

* Add workflow for ggml-ci webgpu

* Try passing absolute path to dawn in ggml-ci

* Avoid error on device destruction, add todos for proper cleanup

* Fix unused warning

* Forgot one parameter unused

* Move some flashattn computation to f32 for correctness
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
