Added test file for FA3 implementation of SDPA #172671
howardzhang-cv wants to merge 5 commits into gh/howardzhang-cv/8/base
Conversation
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/172671
✅ No Failures as of commit b56ea73 with merge base cee5acf. Note: links to docs will display an error until the docs builds have completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Summary: Added test/nn/attention/test_fa3.py which includes:
- flash3 versus math implementation checks (fwd and bwd)
- flash3 kernel called checks
- flash3 compiled mode metadata
- flash3 compiled mode versus math implementation checks (fwd and bwd)
- preserves query layout checks
- flash3 fp8 forward (w/ and w/out descale)
- flash3 fp8 backward warning check
- flash3 fp8 forward compiled

Test Plan: python test/nn/attention/test_fa3.py

ghstack-source-id: d00876e
Pull-Request: #172671
drisspg left a comment
Looks like a lot of this is carried over from the other test file, which is good. Maybe see if there are pieces that are easily reusable between this file, fa4, and even test_transformers (which has the bulk of the sdpa tests) and make it a little more DRY if possible.
@claude can you review this one? Please look for any cargo cult patterns that shouldn't apply.
Claude finished @drisspg's task in 1m 17s

PR Review: Cargo Cult Patterns in test_fa3.py
I've compared test_fa3.py against test_fa4.py.

🔴 Major DRY Violations: Identical Code Between test_fa3.py and test_fa4.py

The following code is nearly 100% identical between both files and should be extracted to a common utility:
🟡 Potential Cargo Cult Patterns
🟢 FA3-Specific Code That Looks Correct

The FP8 tests (lines 524-782) appear to be FA3-specific functionality and are appropriately not duplicated from FA4.

📋 Recommendations

Per Driss's review comment, I'd suggest:
This would reduce the ~350 lines of duplicated code significantly.
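One way to structure the recommended extraction is a mixin that holds the backend-agnostic checks, which each backend-specific test file subclasses. This is only a sketch of the pattern (the class and method names here are hypothetical, and trivial stand-in functions replace the real attention kernels); it is not the PR's actual `_fa_test_common.py`:

```python
import unittest


class CommonSDPATests:
    """Checks shared by both backend test files (hypothetical sketch).

    Deliberately does NOT inherit from unittest.TestCase, so the test
    runner never collects the mixin itself -- only the subclasses run.
    """

    backend_name = None  # set by each subclass

    def run_attention(self, q):
        raise NotImplementedError  # subclass supplies the kernel under test

    def test_preserves_query_layout(self):
        # Shared "preserves query layout" check: output matches input shape.
        q = [1.0, 2.0, 3.0]
        out = self.run_attention(q)
        self.assertEqual(len(out), len(q))


class TestFA3(CommonSDPATests, unittest.TestCase):
    backend_name = "flash3"

    def run_attention(self, q):
        return [x * 1.0 for x in q]  # stand-in for the flash3 kernel


class TestFA4(CommonSDPATests, unittest.TestCase):
    backend_name = "flash4"

    def run_attention(self, q):
        return list(q)  # stand-in for the flash4 kernel


if __name__ == "__main__":
    unittest.main()
```

The key design point is that the mixin comes first in the subclass bases and carries no `TestCase` of its own, so the shared assertions live in exactly one place.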
Summary: Added test/nn/attention/test_fa3.py which includes:
- flash3 versus math implementation checks (fwd and bwd)
- flash3 kernel called checks
- flash3 compiled mode metadata
- flash3 compiled mode versus math implementation checks (fwd and bwd)
- preserves query layout checks
- flash3 fp8 forward (w/ and w/out descale)
- flash3 fp8 backward warning check
- flash3 fp8 forward compiled

Added test/nn/attention/_fa_test_common.py with common tests between test_fa3.py and test_fa4.py

Test Plan: python test/nn/attention/test_fa3.py

ghstack-source-id: 18fdeac
Pull-Request: #172671
Makes sense. I used a lot of the same tests as fa4, so I put them in a common file and referenced it in both test files. I didn't want to mess with the test_transformers file too much since that's less experimental.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Summary: Added test/nn/attention/test_fa3.py which includes:
- flash3 versus math implementation checks (fwd and bwd)
- flash3 kernel called checks
- flash3 compiled mode metadata
- flash3 compiled mode versus math implementation checks (fwd and bwd)
- preserves query layout checks
- flash3 fp8 forward (w/ and w/out descale)
- flash3 fp8 backward warning check
- flash3 fp8 forward compiled

Test Plan: python test/nn/attention/test_fa3.py

Pull Request resolved: pytorch#172671
Approved by: https://github.com/drisspg
Summary: Added benchmark file for FA3 SDPA
- Compares FA3 to FA2 (fp16, bf16)
- Compares FA3 bf16 to FA3 fp8

Test Plan: python benchmarks/transformer/sdpa_fa3.py

Pull Request resolved: #173026
Approved by: https://github.com/drisspg
ghstack dependencies: #172671
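For kernel comparisons like FA3-vs-FA2, the benchmark's core loop is a timing harness that runs each candidate repeatedly and reports a robust statistic. Below is a generic, CPU-only sketch of that pattern using the standard library (the stand-in workloads and names are hypothetical; the real sdpa_fa3.py benchmarks CUDA attention kernels, which additionally require device synchronization around timing):

```python
import timeit
from statistics import median


def bench(fn, iters=50):
    """Median wall-clock time per call of fn, in microseconds.

    Median rather than mean, so one-off scheduler hiccups don't skew
    the comparison between backends.
    """
    times = timeit.repeat(fn, number=1, repeat=iters)
    return median(times) * 1e6


# Hypothetical stand-ins for the two kernels under comparison.
def fa2_like():
    sum(i * i for i in range(2000))


def fa3_like():
    sum(i * i for i in range(1000))


if __name__ == "__main__":
    for name, fn in [("fa2", fa2_like), ("fa3", fa3_like)]:
        print(f"{name}: {bench(fn):.1f} us/iter")
```

For GPU kernels the same shape applies, but each timed call must be bracketed by `torch.cuda.synchronize()` (or use CUDA events) so the measurement covers the kernel, not just the launch.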
Stack from ghstack (oldest at bottom):
Summary: Added test/nn/attention/test_fa3.py which includes:
- flash3 versus math implementation checks (fwd and bwd)
- flash3 kernel called checks
- flash3 compiled mode metadata
- flash3 compiled mode versus math implementation checks (fwd and bwd)
- preserves query layout checks
- flash3 fp8 forward (w/ and w/out descale)
- flash3 fp8 backward warning check
- flash3 fp8 forward compiled

Test Plan: python test/nn/attention/test_fa3.py
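The "flash3 versus math implementation" checks compare the fused kernel's output against PyTorch's unfused math backend, which computes softmax(QKᵀ/√d)·V directly. As a rough illustration of what that reference computes, here is a minimal pure-Python single-head version (a sketch only; the actual tests call scaled_dot_product_attention under different backend contexts and compare tensors with tolerance):

```python
import math


def math_sdpa(q, k, v):
    """Reference 'math' attention for one head: softmax(q @ k.T / sqrt(d)) @ v.

    q: list of query rows, k: list of key rows (same dim d as q),
    v: list of value rows (one per key). Returns one output row per query.
    """
    d = len(q[0])
    # Scaled dot-product scores, one row of scores per query.
    scores = [
        [sum(qi * ki for qi, ki in zip(qr, kr)) / math.sqrt(d) for kr in k]
        for qr in q
    ]
    out = []
    for row in scores:
        # Numerically stable softmax: subtract the row max before exp.
        m = max(row)
        exps = [math.exp(s - m) for s in row]
        z = sum(exps)
        probs = [e / z for e in exps]
        # Output row is the probability-weighted mix of value rows.
        out.append(
            [sum(p * vr[j] for p, vr in zip(probs, v)) for j in range(len(v[0]))]
        )
    return out
```

A fused flash kernel must agree with this reference (within dtype-dependent tolerance) in both forward and, via its gradients, backward, which is exactly what the fwd/bwd comparison tests assert.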