[PyTorch] Enable head dim 256 for FA4 #2932

Merged
sudhakarsingh27 merged 9 commits into NVIDIA:main from yaox12:xiny/headdim256_fa
May 27, 2026
Conversation

@yaox12

@yaox12 yaox12 commented Apr 27, 2026

Member

Description

Requires FA4 version 4.0.0b11.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@yaox12 yaox12 marked this pull request as draft April 27, 2026 09:31
@yaox12 yaox12 force-pushed the xiny/headdim256_fa branch from bdcc02e to 3b3f7d0 Compare April 27, 2026 09:31
@greptile-apps

greptile-apps Bot commented Apr 27, 2026

Contributor

Greptile Summary

This PR enables head_dim=256 support for FlashAttention 4 on SM100/SM103 GPUs by delegating head-dimension validation to FA4's own _validate_head_dims function instead of maintaining a parallel static guard in TE, and bumps the required FA4 version to 4.0.0b11.

  • backends.py: _validate_head_dims is imported alongside flash_attn_func/flash_attn_varlen_func in a single grouped import; if absent in an older FA4 install, an uncaught ImportError crashes the entire module load (previously flagged).
  • utils.py: Replaces the static per-arch head-dim check with a live call to FA4's validator; adds an SM100 cross-attention fallback for hd256 shapes; the MLA misalignment workaround is preserved as independent if checks; v4_installation_steps is correctly updated to 4.0.0b11.
  • test_attention.py: Adds test_dpa_fa4_hdim256 with an explicit SM100/SM103 skipif guard, and removes stale cuDNN version checks from all FA4 tests.

Confidence Score: 4/5

The core logic in utils.py is sound, but the grouped import in backends.py will crash the entire TE module load for any user who has FA4 installed at a version older than 4.0.0b11.

The import in backends.py bundles _validate_head_dims into the same grouped block as the two core FA4 functions. Any FA4 install older than 4.0.0b11 that lacks this symbol triggers an unhandled ImportError at module load time, making TE unusable for those users.

transformer_engine/pytorch/attention/dot_product_attention/backends.py — the grouped FA4 import is the critical path that warrants a second look before merging.
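
The failure mode described above (one missing private symbol taking down a whole grouped import) is commonly avoided by resolving the optional symbol on its own and falling back to `None`. The following is a minimal sketch of that pattern, not the code in this PR: `optional_symbol` is a hypothetical helper, and it is demonstrated on the standard-library `math` module so that it runs without FA4 installed.

```python
import importlib
import math
from typing import Any, Optional


def optional_symbol(module_name: str, symbol: str) -> Optional[Any]:
    """Return `module_name.symbol`, or None when the module or symbol is missing.

    Hypothetical helper for illustration; it is not part of TE or FA4.
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # The whole package is absent: treat the optional feature as unavailable.
        return None
    # An older release may ship the module but not this particular symbol.
    return getattr(module, symbol, None)


# Demonstrated on the standard library so the sketch runs anywhere.
present = optional_symbol("math", "sqrt")
missing_symbol = optional_symbol("math", "_no_such_validator")
missing_module = optional_symbol("no_such_module_for_demo", "anything")
```

A caller would then compare the result against `None` before invoking it, which is the shape of the `v4_validate_head_dims is not None` condition in the flowchart.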

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/attention/dot_product_attention/backends.py | Adds _validate_head_dims to the same grouped import as flash_attn_func/flash_attn_varlen_func; an ImportError on older FA4 (pre-4.0.0b11) crashes the entire module load rather than gracefully falling back. |
| transformer_engine/pytorch/attention/dot_product_attention/utils.py | Replaces static head-dim guard with a live call to FA4's _validate_head_dims; adds SM100 cross-attention fallback for hd256; MLA workaround restructured as independent if checks; v4_installation_steps updated to 4.0.0b11. |
| tests/pytorch/attention/test_attention.py | Adds dedicated test_dpa_fa4_hdim256 with explicit SM100/SM103 skip guard; removes stale cuDNN version checks from all FA4 tests. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[get_attention_backend called] --> B{use_flash_attention_4 and v4_is_installed and v4_validate_head_dims is not None?}
    B -- No --> Z[Skip FA4 head-dim check]
    B -- Yes --> C[Compute _fa4_alignment]
    C --> D[Call v4_validate_head_dims]
    D -- AssertionError --> E[Disable FA4]
    D -- OK --> F{SM100 AND hd256 AND seqlen_q != seqlen_kv?}
    F -- Yes --> G[Disable FA4 cross-attn hd256]
    F -- No --> H{Training AND MLA AND SM100?}
    H -- Yes --> I[gcd misalignment check]
    I -- Misaligned --> J[Disable FA4 MLA bwd]
    I -- OK --> K[FA4 enabled]
    H -- No --> K
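
The decision flow in the flowchart can be sketched as a plain Python function. This is an illustration under stated assumptions, not the code in `utils.py`: the function name, the argument list, and the two callables are hypothetical; `use_fa4` stands in for "use_flash_attention_4 and v4_is_installed"; the real signature of FA4's validator is not shown in this thread, so a stand-in taking two head dims is used; and the gcd-based misalignment check is left as a caller-supplied predicate because its details are not given here.

```python
from typing import Callable, Optional, Tuple


def fa4_head_dim_decision(
    use_fa4: bool,
    head_dim_qk: int,
    head_dim_v: int,
    seqlen_q: int,
    seqlen_kv: int,
    *,
    is_sm100: bool,
    is_training: bool,
    validate_head_dims: Optional[Callable[[int, int], None]],
    mla_bwd_misaligned: Callable[[int, int], bool],
) -> Tuple[bool, str]:
    """Walk the flowchart and return (use_fa4, reason)."""
    if not use_fa4 or validate_head_dims is None:
        # Node Z: the head-dim check is skipped and the flag is left unchanged.
        return use_fa4, "head-dim check skipped"
    try:
        validate_head_dims(head_dim_qk, head_dim_v)
    except AssertionError:
        return False, "head dims rejected by validator"
    # Assumption: "hd256" means both head dims equal 256.
    is_hd256 = head_dim_qk == 256 and head_dim_v == 256
    if is_sm100 and is_hd256 and seqlen_q != seqlen_kv:
        return False, "hd256 cross-attention on SM100"
    # Assumption: MLA is detected as differing QK and V head dims.
    is_mla = head_dim_qk != head_dim_v
    if is_training and is_mla and is_sm100 and mla_bwd_misaligned(head_dim_qk, head_dim_v):
        return False, "MLA backward misaligned"
    return True, "FA4 enabled"


def accept_all(head_dim_qk: int, head_dim_v: int) -> None:
    """Stand-in validator that accepts every shape."""


def reject_all(head_dim_qk: int, head_dim_v: int) -> None:
    """Stand-in validator that rejects every shape."""
    raise AssertionError("unsupported head dims")
```

Returning a reason string alongside the flag is a design choice of this sketch; it makes each exit of the flowchart easy to tell apart in logs or tests.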

Reviews (8): Last reviewed commit: "Merge branch 'main' into xiny/headdim256..."

Comment thread tests/pytorch/attention/test_attention.py Outdated
Signed-off-by: Xin Yao <xiny@nvidia.com>
@yaox12 yaox12 force-pushed the xiny/headdim256_fa branch from 3b3f7d0 to 9a93156 Compare May 6, 2026 02:44
Signed-off-by: Xin Yao <xiny@nvidia.com>
@yaox12 yaox12 force-pushed the xiny/headdim256_fa branch from ae74e44 to 8aa5242 Compare May 6, 2026 02:55
@yaox12

yaox12 commented May 6, 2026

Member Author

/te-ci pytorch L3

@yaox12 yaox12 marked this pull request as ready for review May 6, 2026 02:59
@yaox12

yaox12 commented May 6, 2026

Member Author

@vcherepanov-nv @KshitijLakhani Please review.

@KshitijLakhani KshitijLakhani requested a review from mk-61 May 8, 2026 06:34
Comment thread tests/pytorch/attention/test_attention.py Outdated
Comment thread tests/pytorch/attention/test_attention.py
# dV TMEM load atoms. When (tile_hdimv // 2) % dK_reduce_ncol != 0, dV reads are
# misaligned. The dedicated (256, 256) kernel uses its own tmem layout so it's
# not affected. See: flash_attn/cute/flash_bwd_sm100.py, line ~262 and ~3890.
if (

Member

Should this still be checked when FlashAttentionUtils.v4_validate_head_dims == None?

Member Author

I double-checked, and this is a bug in FA4. The kernels produce wrong results on these shapes, but v4_validate_head_dims allows them, so we have to filter them out manually.
I raised an issue with FA4: Dao-AILab/flash-attention#2552
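
The misalignment condition quoted in the code comment at the top of this thread, `(tile_hdimv // 2) % dK_reduce_ncol != 0`, can be written as a small predicate. This is only a sketch: how TE or FA4 derives `tile_hdimv` and `dK_reduce_ncol` for a given shape is not shown in this thread, so both are plain inputs here, the function name is hypothetical, and the numbers in the examples are illustrative rather than real kernel parameters.

```python
def dv_reads_misaligned(tile_hdimv: int, dk_reduce_ncol: int) -> bool:
    """True when (tile_hdimv // 2) is not a multiple of dk_reduce_ncol."""
    if dk_reduce_ncol <= 0:
        raise ValueError("dk_reduce_ncol must be positive")
    return (tile_hdimv // 2) % dk_reduce_ncol != 0


# Illustrative values only; they are not taken from the FA4 kernels.
aligned_case = dv_reads_misaligned(128, 32)     # 64 % 32 == 0
misaligned_case = dv_reads_misaligned(192, 64)  # 96 % 64 == 32
```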

@vcherepanov-nv

Collaborator

LGTM

yaox12 added 2 commits May 10, 2026 22:28
@yaox12

yaox12 commented May 11, 2026

Member Author

/te-ci pytorch L3

yaox12 added 2 commits May 12, 2026 10:30
Signed-off-by: Xin Yao <xiny@nvidia.com>
@yaox12

yaox12 commented May 12, 2026

Member Author

/te-ci pytorch L3

@yaox12 yaox12 requested a review from cyanguwa as a code owner May 13, 2026 10:42
@yaox12

yaox12 commented May 15, 2026

Member Author

/te-ci pytorch L3

@sudhakarsingh27 sudhakarsingh27 left a comment

Member

LGTM, pending CI

@sudhakarsingh27 sudhakarsingh27 self-requested a review May 15, 2026 20:52
@yaox12

yaox12 commented May 18, 2026

Member Author

The B200 test failed with a 1-element mismatch. It should be unrelated to this PR, because I saw similar errors in other pipelines.

@sudhakarsingh27

sudhakarsingh27 commented May 18, 2026

Member

Need to manually run L1 tests, triggering now
Doesn't look like it's needed

Signed-off-by: Xin Yao <xiny@nvidia.com>
@yaox12

yaox12 commented May 24, 2026

Member Author

/te-ci pytorch L3

@sudhakarsingh27 sudhakarsingh27 left a comment

Member

LGTM

@sudhakarsingh27 sudhakarsingh27 merged commit 5f1eaff into NVIDIA:main May 27, 2026
25 of 27 checks passed
KshitijLakhani pushed a commit that referenced this pull request May 28, 2026
* enable head dim 256 for FA4

Signed-off-by: Xin Yao <xiny@nvidia.com>

* update CI, fix lint, resolve comments

Signed-off-by: Xin Yao <xiny@nvidia.com>

* resolve comments

Signed-off-by: Xin Yao <xiny@nvidia.com>

* update filter

Signed-off-by: Xin Yao <xiny@nvidia.com>

---------

Signed-off-by: Xin Yao <xiny@nvidia.com>
Baibaifan pushed a commit to Baibaifan/TransformerEngine that referenced this pull request Jun 1, 2026
* enable head dim 256 for FA4

Signed-off-by: Xin Yao <xiny@nvidia.com>

* update CI, fix lint, resolve comments

Signed-off-by: Xin Yao <xiny@nvidia.com>

* resolve comments

Signed-off-by: Xin Yao <xiny@nvidia.com>

* update filter

Signed-off-by: Xin Yao <xiny@nvidia.com>

---------

Signed-off-by: Xin Yao <xiny@nvidia.com>
Signed-off-by: yangfan.bai <yangfan.bai@shopee.com>
3 participants