
Memcpy kernel for flash attention#29

Merged
suquark merged 16 commits into main from memcpy4flashattn
Apr 11, 2023

Conversation

@suquark
Contributor

@suquark suquark commented Apr 6, 2023

Memcpy kernel for flash attention

```
num_tokens: 64, num_heads: 40, head_size: 128, block_size: 8, num_blocks: 1024, dtype: torch.float16
[Latency] gather_cached_kv: 0.008 ms
[Throughput] gather_cached_kv: 156.479 GB/s
num_tokens: 128, num_heads: 40, head_size: 128, block_size: 8, num_blocks: 1024, dtype: torch.float16
[Latency] gather_cached_kv: 0.011 ms
[Throughput] gather_cached_kv: 216.171 GB/s
num_tokens: 256, num_heads: 40, head_size: 128, block_size: 8, num_blocks: 1024, dtype: torch.float16
[Latency] gather_cached_kv: 0.032 ms
[Throughput] gather_cached_kv: 152.631 GB/s
num_tokens: 512, num_heads: 40, head_size: 128, block_size: 8, num_blocks: 1024, dtype: torch.float16
[Latency] gather_cached_kv: 0.057 ms
[Throughput] gather_cached_kv: 172.325 GB/s
num_tokens: 1024, num_heads: 40, head_size: 128, block_size: 8, num_blocks: 1024, dtype: torch.float16
[Latency] gather_cached_kv: 0.104 ms
[Throughput] gather_cached_kv: 187.537 GB/s
num_tokens: 2048, num_heads: 40, head_size: 128, block_size: 8, num_blocks: 1024, dtype: torch.float16
[Latency] gather_cached_kv: 0.204 ms
[Throughput] gather_cached_kv: 191.603 GB/s
```

The performance is pretty good, considering the memory layout is not ideal (the theoretical peak memory bandwidth of an A100-40GB is about 1.6 TB/s).
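As a sanity check on the reported numbers, a rough bandwidth estimate can be recomputed from the run parameters. The formula below is my assumption about what the benchmark counts (K and V copied once each, in fp16), not code from this PR; the reported figures land in the same ballpark, with the gap plausibly due to the printed latencies being rounded.

```python
# Back-of-the-envelope throughput check (assumed accounting, not the PR's code):
# bytes moved = num_tokens * num_heads * head_size * 2 (K and V) * 2 (fp16 bytes)

def copied_bytes(num_tokens, num_heads=40, head_size=128, kv_tensors=2, dtype_bytes=2):
    """Total bytes gathered for one call, under the assumptions above."""
    return num_tokens * num_heads * head_size * kv_tensors * dtype_bytes

def throughput_gbps(num_tokens, latency_ms):
    """Effective throughput in GB/s given a measured latency in milliseconds."""
    return copied_bytes(num_tokens) / (latency_ms * 1e-3) / 1e9

# The num_tokens=2048 row: ~41.9 MB in 0.204 ms comes out to roughly 206 GB/s,
# close to the reported 191.6 GB/s (the printed latency is rounded to 3 digits).
print(copied_bytes(2048), round(throughput_gbps(2048, 0.204), 1))
```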

Results for the unoptimized kernel:

```
num_tokens: 64, num_heads: 40, head_size: 128, block_size: 8, num_blocks: 1024, dtype: torch.float16
[Latency] gather_cached_kv: 0.010 ms
[Throughput] gather_cached_kv: 125.891 GB/s
num_tokens: 128, num_heads: 40, head_size: 128, block_size: 8, num_blocks: 1024, dtype: torch.float16
[Latency] gather_cached_kv: 0.015 ms
[Throughput] gather_cached_kv: 160.678 GB/s
num_tokens: 256, num_heads: 40, head_size: 128, block_size: 8, num_blocks: 1024, dtype: torch.float16
[Latency] gather_cached_kv: 0.032 ms
[Throughput] gather_cached_kv: 150.732 GB/s
num_tokens: 512, num_heads: 40, head_size: 128, block_size: 8, num_blocks: 1024, dtype: torch.float16
[Latency] gather_cached_kv: 0.060 ms
[Throughput] gather_cached_kv: 162.482 GB/s
num_tokens: 1024, num_heads: 40, head_size: 128, block_size: 8, num_blocks: 1024, dtype: torch.float16
[Latency] gather_cached_kv: 0.108 ms
[Throughput] gather_cached_kv: 180.763 GB/s
num_tokens: 2048, num_heads: 40, head_size: 128, block_size: 8, num_blocks: 1024, dtype: torch.float16
[Latency] gather_cached_kv: 0.206 ms
[Throughput] gather_cached_kv: 189.757 GB/s
```

The optimized kernel works much better for smaller numbers of tokens (up to ~20% speedup).
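For readers unfamiliar with what the kernel does: `gather_cached_kv` copies the K/V vectors of the current tokens out of the paged block cache into a contiguous tensor that flash attention can consume. The sketch below is my simplified NumPy reference for those semantics, not the PR's CUDA code; the real cache layout is more involved (the key cache is additionally blocked along `head_size`), and the slot-to-block mapping shown here is an assumption for illustration.

```python
import numpy as np

def gather_cached_kv_ref(cache, slot_mapping):
    """Simplified reference gather (assumed layout, not vLLM's actual one).

    cache:        [num_blocks, block_size, num_heads, head_size]
    slot_mapping: [num_tokens] flat slot index, block_idx * block_size + offset
    returns:      [num_tokens, num_heads, head_size], contiguous for attention
    """
    num_blocks, block_size, num_heads, head_size = cache.shape
    # View the paged cache as one flat array of slots, then gather token rows.
    flat = cache.reshape(num_blocks * block_size, num_heads, head_size)
    return flat[slot_mapping]

rng = np.random.default_rng(0)
# Shapes mirror the benchmark config: 1024 blocks of 8 slots, 40 heads, 128 dims.
cache = rng.standard_normal((1024, 8, 40, 128)).astype(np.float16)
slots = rng.choice(1024 * 8, size=64, replace=False)
out = gather_cached_kv_ref(cache, slots)
print(out.shape)  # (64, 40, 128)
```

The CUDA kernel has to do this gather with scattered reads from the block cache, which is why the memory layout limits how close it can get to peak bandwidth.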

@suquark suquark changed the title Memcpy for flashattn Memcpy kernel for flashattn Apr 6, 2023
@suquark suquark changed the title Memcpy kernel for flashattn Memcpy kernel for flash attention Apr 6, 2023
@suquark
Contributor Author

suquark commented Apr 6, 2023

Implementation is done; it still needs testing (will do it on Thursday).

The memory-saving strategy is orthogonal to this kernel, so I will not include it in this PR.

@suquark suquark requested a review from WoosukKwon April 6, 2023 08:57
@suquark suquark force-pushed the memcpy4flashattn branch from 678bb06 to 07e9891 Compare April 8, 2023 20:40
optimize with shared memory

better number of threads

update test

temp disable test

update
@suquark suquark force-pushed the memcpy4flashattn branch from 07e9891 to e21845e Compare April 8, 2023 20:45
@WoosukKwon
Collaborator

Hey @suquark thanks for the PR! I have a quick question: have you also measured the performance diff between the two kernels before and after the optimization?

@suquark suquark closed this Apr 11, 2023
@WoosukKwon WoosukKwon reopened this Apr 11, 2023
@suquark
Contributor Author

suquark commented Apr 11, 2023

See the PR description above for the performance comparison with the optimized kernel.

Collaborator

@WoosukKwon WoosukKwon left a comment

LGTM.
