
Added test file for FA3 implementation of SDPA #172671

Closed
howardzhang-cv wants to merge 5 commits into gh/howardzhang-cv/8/base from gh/howardzhang-cv/8/head

Conversation


@howardzhang-cv howardzhang-cv commented Jan 16, 2026

Stack from ghstack (oldest at bottom):

Summary: Added test/nn/attention/test_fa3.py which includes:

- flash3 versus math implementation checks (fwd and bwd)
- flash3 kernel called checks
- flash3 compiled mode metadata
- flash3 compiled mode versus math implementation checks (fwd and bwd)
- preserves query layout checks
- flash3 fp8 forward (w/ and w/out descale)
- flash3 fp8 backward warning check
- flash3 fp8 forward compiled

Test Plan: python test/nn/attention/test_fa3.py
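The flash3-versus-math checks compare kernel output against a reference implementation within tolerances. A minimal, torch-free sketch of that comparison pattern (helper name and tolerances are hypothetical, not the PR's actual code):

```python
def assert_close(actual, expected, rtol=2.0, atol=1e-3):
    """Check elementwise |actual - expected| <= atol + rtol * |expected|,
    the shape of comparison a flash-vs-math test performs on outputs."""
    for a, e in zip(actual, expected):
        tol = atol + rtol * abs(e)
        if abs(a - e) > tol:
            raise AssertionError(f"{a} != {e} within tol={tol}")

# e.g. comparing a fast kernel's output against the reference "math" path
reference = [0.100, 0.250, -0.500]
fast_path = [0.101, 0.249, -0.499]
assert_close(fast_path, reference)  # passes: differences are tiny
```

In the real tests the tensors come from `scaled_dot_product_attention` under different backends; the tolerance formula above mirrors the usual `rtol`/`atol` convention.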


pytorch-bot Bot commented Jan 16, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/172671

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit b56ea73 with merge base cee5acf:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

howardzhang-cv added a commit that referenced this pull request Jan 16, 2026
Summary: Added test/nn/attention/test_fa3.py which includes:

- flash3 versus math implementation checks (fwd and bwd)
- flash3 kernel called checks
- flash3 compiled mode metadata
- flash3 compiled mode versus math implementation checks (fwd and bwd)
- preserves query layout checks
- flash3 fp8 forward (w/ and w/out descale)
- flash3 fp8 backward warning check
- flash3 fp8 forward compiled

Test Plan: python test/nn/attention/test_fa3.py


ghstack-source-id: d00876e
Pull-Request: #172671
@pytorch-bot added the `topic: not user facing` label on Jan 16, 2026
@howardzhang-cv added the `release notes: nn` label on Jan 16, 2026
Review comment thread on test/nn/attention/test_fa3.py (outdated):
@drisspg drisspg left a comment


Looks like a lot of cargo-culting from the other test file, which is good. Maybe see if there are pieces that are easily reusable between this, fa4, and even test_transformers, which has the bulk of the sdpa tests, and make it a little more DRY if possible.

howardzhang-cv added a commit that referenced this pull request Jan 21, 2026
Summary: Added test/nn/attention/test_fa3.py which includes:

- flash3 versus math implementation checks (fwd and bwd)
- flash3 kernel called checks
- flash3 compiled mode metadata
- flash3 compiled mode versus math implementation checks (fwd and bwd)
- preserves query layout checks
- flash3 fp8 forward (w/ and w/out descale)
- flash3 fp8 backward warning check
- flash3 fp8 forward compiled

Test Plan: python test/nn/attention/test_fa3.py

ghstack-source-id: f4c6057
Pull-Request: #172671

drisspg commented Jan 21, 2026

@claude can you review this one, please? Look for any cargo cult patterns that shouldn't apply.


claude Bot commented Jan 21, 2026

Claude finished @drisspg's task in 1m 17s


PR Review: Cargo Cult Patterns in test_fa3.py

  • Read the test file and understand the structure
  • Identify cargo cult patterns that don't apply to FA3
  • Check for code that could be shared/made DRY
  • Provide detailed review feedback

I've compared test_fa3.py with test_fa4.py and analyzed the code for cargo cult patterns. Here are my findings:

🔴 Major DRY Violations - Identical Code Between test_fa3.py and test_fa4.py

The following code is nearly 100% identical between both files and should be extracted to a common utility:

  1. SdpaShape namedtuple (line 28)

    • Exact duplicate in both files
  2. cuda_kernel_profiler context manager (lines 44-54)

    • Nearly identical, with minor differences:
      • FA3 version includes ProfilerActivity.CPU in addition to CUDA
      • FA4 filters by evt.device_type == torch.autograd.DeviceType.CUDA
    • The CPU activity inclusion in FA3 may be cargo cult
  3. flash_vs_math function (lines 57-123)

    • Exactly identical - 67 lines of duplicated code
  4. DummyHandle class (lines 126-132)

    • Identical helper class for testing
  5. _assert_flash_matches_math method (lines 144-209)

    • Identical except for skipUnless decorator
  6. test_flash_attention_matches_math (lines 227-237)

    • Nearly identical, but FA4 has an extra test_backward = head_dim == 128 and dtype == torch.float16 condition due to a known issue. FA3 always tests backward.
  7. test_*_kernel_called (lines 241-281)

    • Similar structure but different kernel patterns
  8. test_multiple_activate (lines 284-321)

    • Identical except for the final activate_flash_attention_impl("FA3/FA4") calls
  9. test_compiled_sdpa_*_metadata (lines 325-349)

    • Identical
  10. test_compiled_sdpa_*_matches_math (lines 353-401)

    • Identical
  11. test_compiled_sdpa_*_backward_matches_math (lines 405-484)

    • Nearly identical. FA3 adds min_atol handling (lines 467, 473-476) which FA4 lacks
  12. test_attention_preserves_query_layout (lines 487-522)

    • Identical
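For context, the cuda_kernel_profiler pattern from item 2 boils down to: run code under a profiler, then check whether any captured kernel name matches a pattern. A torch-free sketch of the same shape, with hypothetical names and simulated events (the real version wraps torch.profiler and inspects CUDA events):

```python
from contextlib import contextmanager

@contextmanager
def kernel_profiler(kernel_pattern, events):
    """Yield a result dict; on exit, record whether any captured event
    name contains kernel_pattern (mirrors the cuda_kernel_profiler shape)."""
    result = {"found": False, "kernel_names": []}
    try:
        yield result
    finally:
        names = [e["name"] for e in events]
        result["kernel_names"] = names
        result["found"] = any(kernel_pattern in n for n in names)

# simulated profiler events; the real test gets these from torch.profiler
events = [{"name": "flash_fwd_kernel"}, {"name": "elementwise_kernel"}]
with kernel_profiler("flash_fwd", events) as prof:
    pass  # the real test runs the attention op here
assert prof["found"]
```

The advantage of this shape is that the assertion lives outside the context manager, so the same helper serves both "kernel was called" and "kernel was not called" tests.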

🟡 Potential Cargo Cult Patterns

  1. SM90 check in _fa3_dependencies_available (line 35):

    if major != 9:  # FA3 requires Hopper (SM90)

    FA4 checks major not in (9, 10) (supports both Hopper and Blackwell). Verify if FA3 truly only supports SM90 or if this is overly restrictive.

  2. Hardcoded rtol values (multiple locations):

    • rtol=2 appears throughout both files. If this tolerance is correct for one implementation, it may not be optimal for another.
  3. min_atol in backward test (lines 467, 473-476 in FA3):

    min_atol = 0.05 if dtype == torch.bfloat16 else 0.01

    FA3 adds this but FA4 doesn't have it. Either FA4 should also have this, or it's unnecessary cargo cult.
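The min_atol logic in item 3 amounts to flooring a scaled tolerance at a dtype-dependent minimum, since bf16's coarser mantissa makes gradients noisier. A hypothetical reconstruction (the scaling rule here is illustrative, not the file's exact code):

```python
def effective_atol(expected_max, dtype_is_bf16, rtol=2.0):
    """Scale atol with the magnitude of the reference output, but never
    let it drop below a dtype-dependent floor (bf16 is noisier than fp16)."""
    min_atol = 0.05 if dtype_is_bf16 else 0.01
    scaled = rtol * expected_max * 1e-3  # hypothetical scaling rule
    return max(scaled, min_atol)

print(effective_atol(0.5, dtype_is_bf16=True))    # floor applies: 0.05
print(effective_atol(100.0, dtype_is_bf16=False))  # scaled term dominates
```

Whether FA4 also needs this floor is exactly the open question raised above; the helper just makes the two-branch structure explicit.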

🟢 FA3-Specific Code That Looks Correct

The FP8 tests (lines 524-782) appear to be FA3-specific functionality and are appropriately not duplicated from FA4.

📋 Recommendations

Per Driss's review comment, I'd suggest:

  1. Create a common test utilities module at test/nn/attention/_sdpa_test_utils.py:

    • Move SdpaShape, cuda_kernel_profiler, flash_vs_math, DummyHandle
    • Consider creating a base test class SDPATestMixin with shared test methods
  2. Parameterize the implementation-specific parts:

    • Dependency check function
    • Kernel pattern strings
    • Activation function name ("FA3" vs "FA4")
  3. Example refactor structure:

    # test/nn/attention/_sdpa_test_utils.py
    SdpaShape = namedtuple("Sdpa_Shape", ["batch", "num_heads", "seq_len", "head_dim"])
    
    @contextmanager
    def cuda_kernel_profiler(kernel_pattern): ...
    
    def flash_vs_math(test_case, q, k, v, is_causal=False, rtol=2): ...
    
    class DummyHandle: ...
    # test/nn/attention/test_fa3.py
    from _sdpa_test_utils import SdpaShape, cuda_kernel_profiler, flash_vs_math, DummyHandle

This would reduce the ~350 lines of duplicated code significantly.
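The SDPATestMixin idea from recommendation 1 could be sketched like this (all names hypothetical; the real suite would put the torch-specific comparisons in the shared test bodies and keep only the hooks per-implementation):

```python
import unittest

class SDPATestMixin:
    """Shared test bodies; concrete classes supply implementation hooks."""
    impl_name = None       # e.g. "FA3" or "FA4"
    kernel_pattern = None  # substring expected in profiled kernel names

    def test_hooks_defined(self):
        # shared test: every concrete suite must configure its hooks
        self.assertIsNotNone(self.impl_name)
        self.assertIsNotNone(self.kernel_pattern)

class TestFA3(SDPATestMixin, unittest.TestCase):
    impl_name = "FA3"
    kernel_pattern = "flash_fwd"
```

Because the mixin does not itself inherit from unittest.TestCase, only the concrete classes are collected, so each shared test runs once per implementation with that implementation's hooks.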

howardzhang-cv added a commit that referenced this pull request Jan 21, 2026
Summary: Added test/nn/attention/test_fa3.py which includes:

- flash3 versus math implementation checks (fwd and bwd)
- flash3 kernel called checks
- flash3 compiled mode metadata
- flash3 compiled mode versus math implementation checks (fwd and bwd)
- preserves query layout checks
- flash3 fp8 forward (w/ and w/out descale)
- flash3 fp8 backward warning check
- flash3 fp8 forward compiled

Added test/nn/attention/_fa_test_common.py with common tests between test_fa3.py and test_fa4.py

Test Plan: python test/nn/attention/test_fa3.py

ghstack-source-id: 18fdeac
Pull-Request: #172671
@howardzhang-cv
Contributor Author

Looks like a lot of cargo-culting from the other test file, which is good. Maybe see if there are pieces that are easily reusable between this, fa4, and even test_transformers, which has the bulk of the sdpa tests, and make it a little more DRY if possible.

Makes sense, I used a lot of the same tests as fa4, so I put it in a common file and referenced it in both test files. I didn't want to mess with the test_transformers file too much since that's less experimental

suncapitalllc007-star pushed a commit to suncapitalllc007-star/pytorch that referenced this pull request Jan 25, 2026
Summary: Added test/nn/attention/test_fa3.py which includes:

- flash3 versus math implementation checks (fwd and bwd)
- flash3 kernel called checks
- flash3 compiled mode metadata
- flash3 compiled mode versus math implementation checks (fwd and bwd)
- preserves query layout checks
- flash3 fp8 forward (w/ and w/out descale)
- flash3 fp8 backward warning check
- flash3 fp8 forward compiled

Added test/nn/attention/_fa_test_common.py with common tests between test_fa3.py and test_fa4.py

Test Plan: python test/nn/attention/test_fa3.py

ghstack-source-id: aaeb46a
Pull-Request: pytorch/pytorch#172671
Review comment thread on test/nn/attention/_fa_test_common.py (outdated):
@howardzhang-cv
Contributor Author

@pytorchbot merge

@pytorch-bot added the `ciflow/trunk` label (trigger trunk jobs on your pull request) on Jan 26, 2026
@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

riccardofelluga pushed a commit to riccardofelluga/pytorch that referenced this pull request Jan 27, 2026
Summary: Added test/nn/attention/test_fa3.py which includes:

- flash3 versus math implementation checks (fwd and bwd)
- flash3 kernel called checks
- flash3 compiled mode metadata
- flash3 compiled mode versus math implementation checks (fwd and bwd)
- preserves query layout checks
- flash3 fp8 forward (w/ and w/out descale)
- flash3 fp8 backward warning check
- flash3 fp8 forward compiled

Test Plan: python test/nn/attention/test_fa3.py
Pull Request resolved: pytorch#172671
Approved by: https://github.com/drisspg
pytorchmergebot pushed a commit that referenced this pull request Jan 27, 2026
Summary: Added benchmark file for FA3 SDPA

- Compares FA3 to FA2 (fp16, bf16)
- Compares FA3 bf16 to FA3 fp8

Test Plan: python benchmarks/transformer/sdpa_fa3.py
Pull Request resolved: #173026
Approved by: https://github.com/drisspg
ghstack dependencies: #172671
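The follow-up benchmark above compares backends by timing repeated calls. A minimal torch-free timing harness in the same spirit (a real SDPA benchmark would also synchronize CUDA around the timer; the workloads below are stand-ins, not the actual attention calls):

```python
import time

def bench(fn, iters=100, warmup=10):
    """Return the median wall-clock time per call of fn, in seconds.
    Warmup iterations are discarded so one-time setup costs don't skew
    the measurement; the median resists outlier iterations."""
    for _ in range(warmup):
        fn()
    times = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn()
        times.append(time.perf_counter() - t0)
    times.sort()
    return times[len(times) // 2]

# stand-ins for the FA2/FA3 calls the benchmark would actually compare
fa2_t = bench(lambda: sum(range(1000)))
fa3_t = bench(lambda: sum(range(500)))
print(f"speedup: {fa2_t / fa3_t:.2f}x")
```

For GPU kernels, median-of-many with explicit synchronization is the usual choice, since asynchronous launches otherwise make per-call timings meaningless.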
@github-actions github-actions Bot deleted the gh/howardzhang-cv/8/head branch February 27, 2026 02:23

Labels

ciflow/trunk · Merged · release notes: nn · topic: not user facing
