[FlexAttention] Enforce Q,K,V memory layouts for fp8 flex attention to avoid perf degradation #153357
danielvegamyhre wants to merge 7 commits into main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/153357
Note: Links to docs will display an error until the docs builds have been completed. ⏳ 1 Pending, 2 Unrelated Failures. As of commit af8fe6a with merge base 27e9d9b:
FLAKY - The following job failed but was likely due to flakiness present on trunk.
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
cc @drisspg for review. I debated between validation vs. enforcement for this fix; there are pros/cons to both, but I felt enforcement would be a more seamless UX.
drisspg
left a comment
I am wishy washy on whether we want to do this under the hood a lil magically...
That being said, this fix shouldn't be needed on sm100 and newer, which have TT support. Can you make sure to limit this change to sm90 devices?
@drisspg I am as well. We could just add validation and require the user to do the transpose themselves, so they are explicitly aware of the layout of their tensors and there are no surprises? It does add some friction to using the API, though.
Good to know - updated.
torch/nn/attention/flex_attention.py
Outdated
is this going to work for amd?
I updated it to also check torch.cuda.is_available() to avoid an error on AMD, but no this memory layout enforcement will not be applied on any AMD hardware. I can run follow up tests to check perf on AMD if we have access to any AMD hardware that supports fp8 gemms?
torch/nn/attention/flex_attention.py
Outdated
```python
torch.cuda.is_available()
and torch.cuda.get_device_capability("cuda") >= (10, 0)
```

Suggested change:

```python
torch.version.cuda
and torch.cuda.get_device_capability("cuda") >= (10, 0)
```
I ran some tests on the AMD machine that I'm on:

```python
>>> torch.cuda.is_available()
True
>>> torch.version.cuda
>>> torch.version.cuda is None
True
>>> torch.version.hip
'6.3.42131-fa1d09cbd'
>>> torch.cuda.get_device_capability("cuda")
(9, 4)
```

So I think you need to check torch.version.hip to be able to reliably tell if you're on an AMD machine.
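To make the vendor check above concrete: on ROCm builds of PyTorch, `torch.version.cuda` is `None` while `torch.version.hip` is set, and the reverse holds on CUDA builds. A minimal sketch of the detection logic (the `gpu_vendor` helper is hypothetical, not code from the PR; the CUDA version string is illustrative):

```python
def gpu_vendor(version_cuda, version_hip):
    """Classify the GPU backend from the torch.version fields.

    On NVIDIA builds torch.version.cuda is a version string and
    torch.version.hip is None; on AMD/ROCm builds it is the other
    way around; on CPU-only builds both are None.
    """
    if version_cuda is not None:
        return "nvidia"
    if version_hip is not None:
        return "amd"
    return "cpu"

# Values as observed on the AMD machine above:
print(gpu_vendor(None, "6.3.42131-fa1d09cbd"))  # amd
# A hypothetical NVIDIA build:
print(gpu_vendor("12.4", None))                 # nvidia
```

This is why `torch.cuda.is_available()` alone is not enough: it returns `True` on ROCm as well, since the HIP backend is exposed through the `torch.cuda` namespace.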
Thanks for testing this. Seems like we could just check that torch.version.cuda is not None to ensure we're on NVIDIA? Updated the code with this check.
davidberard98
left a comment
lgtm as long as @drisspg is fine with "forcing explicit transpose" instead of "erroring/warning"
test/inductor/test_flex_attention.py
Outdated
```python
l_block_mask_full_q_num_blocks = L_block_mask_full_q_num_blocks
l_block_mask_full_q_indices = L_block_mask_full_q_indices

get_device_capability = torch.cuda.get_device_capability('cuda'); get_device_capability = None
```
I am surprised this shows up in the graph module
Should probably be marked as a constant, yea. although it will get traced out in tracing so only affects pre grad.
Weird, it shows up in the graph module when I use:

```python
is_sm_100 = torch.cuda.get_device_capability("cuda") == (10, 0)
```

However, it doesn't show up when I do this:

```python
should_enforce_mem_layout = (
    gemm_precision in fp8_dtypes
    and torch.version.cuda is not None
    and torch.cuda.get_device_capability("cuda") >= (8, 9)
    and torch.cuda.get_device_capability("cuda") < (10, 0)
)
```

Using the latter approach now.
I'm fine with enforcing for now. Your perf benchmark above includes the transpose, right?
Yep, it includes the transpose. Pre-transposed it was slightly faster (~294us IIRC).
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
@pytorchbot revert -m "Might have introduced regressions in rocm testing for main: https://github.com/pytorch/pytorch/actions/runs/15035410497/job/42257000513 feel free to re-merge if this was a mistake" -c nosignal
@pytorchbot successfully started a revert job. Check the current status here.
@danielvegamyhre your PR has been successfully reverted.
…ention to avoid perf degradation (#153357)" This reverts commit 881a598. Reverted #153357 on behalf of https://github.com/jeanschmidt due to: "Might have introduced regressions in rocm testing for main: https://github.com/pytorch/pytorch/actions/runs/15035410497/job/42257000513. Feel free to re-merge if this was a mistake."
Updated test to handle the different graph module produced for AMD. |
@pytorchbot merge -f "test failures unrelated, see comments" |
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Fixes #147336
Context
NCU analysis of the fp8 flex attention perf issue in #147336 showed an unexpected increase in shared memory access bank conflicts when loading the V tensor from HBM to SRAM.
Bringing this to the attention of Triton developer @davidberard98, he identified the memory layout of the tensor in HBM as the cause of non-pipelined loads into SRAM, and hence the slowdown.
To summarize:
In flex attention, when performing the FP8 GEMM `softmax_scores @ V`, the right operand V must be in column-major memory layout. However, the `tl.load` of V blocks from HBM to SRAM cannot be pipelined if the V tensor isn't already column-major in HBM, leading to substantial performance degradation. This is because Triton does not perform async copies with the `cp.async` PTX instruction if the number of contiguous bytes is less than 4 (see here). i.e., when loading 4 bytes of contiguous data from a tensor stored row-major in HBM, we have to perform 4 separate non-contiguous writes to SRAM to place those bytes in their new locations in the col-major layout in SRAM. Thus the load is not a candidate for pipelining with `cp.async`, and instead just moves data to registers and then performs a series of single-byte stores.
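To make the layout condition concrete: "column-major" for the trailing (S, D) matrix of V means the sequence dimension is the one with stride 1. A minimal stride-based check (a hypothetical helper for illustration, not code from this PR):

```python
def is_col_major_2d(shape, strides):
    """True if the trailing 2-D matrix described by (shape, strides)
    is column-major: walking down a column (the row index) is
    contiguous, and moving across columns jumps by the row count."""
    rows = shape[-2]
    s_row, s_col = strides[-2], strides[-1]
    return s_row == 1 and s_col == rows

# A 4x8 matrix stored row-major vs. column-major:
assert not is_col_major_2d((4, 8), (8, 1))  # row-major: rows contiguous
assert is_col_major_2d((4, 8), (1, 4))      # column-major: columns contiguous
```

When V fails this check, each 4-byte chunk of a column is scattered across the row-major storage, which is exactly the sub-4-contiguous-bytes case where Triton falls back from `cp.async` to register moves and single-byte stores.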
Fix summary
For fp8 flex attention on NVIDIA sm89 through sm99 devices, enforce a column-major memory layout for V under the hood (transposing if necessary) before the GEMM, so the `tl.load` of V blocks can be pipelined with `cp.async`. The enforcement is skipped on sm100 and newer, which have TT support, and on AMD/ROCm (detected via `torch.version.cuda is None`).
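The enforcement idea can be sketched as follows. This is a paraphrase of the approach discussed in the review thread, not the PR's exact code; `maybe_make_col_major` is a hypothetical name, and the real change additionally gates on fp8 dtypes and device capability:

```python
import torch

def maybe_make_col_major(v: torch.Tensor) -> torch.Tensor:
    """Return v with its trailing (S, D) matrix in column-major layout,
    copying only if it is not already laid out that way."""
    if v.stride(-2) == 1:
        # Sequence dim already contiguous: nothing to do.
        return v
    # transpose -> contiguous -> transpose back: the logical shape and
    # values are unchanged, but the underlying storage becomes
    # column-major in the trailing (S, D) matrix.
    return v.transpose(-2, -1).contiguous().transpose(-2, -1)

v = torch.randn(2, 4, 128, 64)   # (B, H, S, D), row-major by default
v_cm = maybe_make_col_major(v)
assert v_cm.shape == v.shape     # logical shape unchanged
assert v_cm.stride(-2) == 1      # storage now column-major in (S, D)
assert torch.equal(v_cm, v)      # values unchanged
```

Because the check is stride-based, users who pre-transpose their V tensor pay nothing extra; everyone else pays one copy, which the benchmark below shows is still a large net win for fp8.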
Benchmarks
Rerunning the repro, we see the fp8 runtime reduced from 120% of the bf16 runtime to 76% of it.
Before fix:
After fix:
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov