
[FlexAttention] Enforce Q,K,V memory layouts for fp8 flex attention to avoid perf degradation #153357

Closed
danielvegamyhre wants to merge 7 commits into main from validate-flex

Conversation

@danielvegamyhre
Contributor

@danielvegamyhre danielvegamyhre commented May 12, 2025

Fixes #147336

Context

NCU analysis of the fp8 flex attention perf issue in #147336 showed an unexpected increase in shared memory access bank conflicts when loading the V tensor from HBM to SRAM.

I brought this to the attention of Triton developer @davidberard98, who identified the memory layout of the tensor in HBM as the cause of non-pipelined loads into SRAM, and hence the slowdown.

To summarize:

In flex attention, when performing the FP8 GEMM softmax_scores @ V, the right operand V must be in a column-major memory layout. However, the tl.load of V blocks from HBM to SRAM cannot be pipelined if the V tensor isn't already column-major in HBM, leading to substantial performance degradation.

This is because Triton does not perform async copies with the cp.async PTX instruction if the number of contiguous bytes is less than 4 (see here).

That is, when loading 4 bytes of contiguous data from a tensor stored row-major in HBM, we must perform 4 separate non-contiguous writes to place those bytes in their column-major locations in SRAM. The load is therefore not a candidate for pipelining with cp.async; the data instead moves to registers and is written out as a series of single-byte stores.

Fix summary

  • To fix this, we enforce memory layouts for Q, K, and V in FlexAttention when fp8 is used, ensuring each tensor exists in HBM in the layout needed for pipelined loads into SRAM ahead of the FP8 GEMMs.
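The enforcement can be sketched roughly as follows (a minimal illustration, not the PR's actual implementation; `enforce_fp8_layouts` is a hypothetical helper name): Q and K stay row-major for the Q @ K^T GEMM, while V is materialized column-major so the softmax_scores @ V loads can be pipelined.

```python
import torch

def enforce_fp8_layouts(q, k, v):
    """Sketch (illustrative only): put Q, K, V into the HBM layouts that
    allow Triton to pipeline loads into SRAM with cp.async for fp8 GEMMs."""
    # Q and K feed Q @ K^T; row-major (contiguous) works for them.
    q = q.contiguous()
    k = k.contiguous()
    # softmax_scores @ V needs V column-major: transpose the last two dims,
    # materialize contiguously, then transpose back. The logical shape is
    # unchanged, but the underlying storage is now column-major.
    v = v.transpose(-2, -1).contiguous().transpose(-2, -1)
    return q, k, v
```

The transpose-contiguous-transpose idiom rewrites only the storage order; the kernel observes the change purely through V's strides.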

Benchmarks

Rerunning the repro, fp8 runtime is reduced from 120% of the bf16 runtime to 76%.

Before fix:

(flex) [danvm@devgpu007.eag6 ~/ml-perf-tools/flex_attention (main)]$ rm -rf /tmp/torchinductor_${USER}; python profile_flex.py --bf16 --fp8 
2025-05-11 19:07:33,402 - flex_bench - INFO - Running benchmark: bf16
2025-05-11 19:07:35,885 - flex_bench - INFO - bf16: 424.87228804347734 us
2025-05-11 19:07:35,893 - flex_bench - INFO - Running benchmark: fp8e4m3
2025-05-11 19:07:37,319 - flex_bench - INFO - fp8e4m3: 515.714000000001 us

After fix:

(flex) [danvm@devgpu007.eag6 ~/ml-perf-tools/flex_attention (main)]$ rm -rf /tmp/torchinductor_${USER}; python profile_flex.py --bf16 --fp8 
2025-05-11 17:34:38,223 - flex_bench - INFO - Running benchmark: bf16
2025-05-11 17:34:41,157 - flex_bench - INFO - bf16: 423.4662032967036 us
2025-05-11 17:34:41,167 - flex_bench - INFO - Running benchmark: fp8e4m3
2025-05-11 17:34:42,917 - flex_bench - INFO - fp8e4m3: 326.3694803493453 us

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

@pytorch-bot

pytorch-bot bot commented May 12, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/153357

Note: Links to docs will display an error until the docs builds have been completed.

⏳ 1 Pending, 2 Unrelated Failures

As of commit af8fe6a with merge base 27e9d9b:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@danielvegamyhre danielvegamyhre changed the title [flex attention] Enforce Q,K,V memory layouts for fp8 flex attention to avoid perf degradation [FlexAttention] Enforce Q,K,V memory layouts for fp8 flex attention to avoid perf degradation May 12, 2025
@danielvegamyhre danielvegamyhre requested a review from drisspg May 12, 2025 02:06
@danielvegamyhre
Contributor Author

danielvegamyhre commented May 12, 2025

cc @drisspg for review. I debated between validation and enforcement for this fix; there are pros and cons to both, but enforcement felt like a more seamless UX.

@albanD albanD requested review from ngimel and removed request for albanD May 12, 2025 14:25
Contributor

@drisspg drisspg left a comment


I'm wishy-washy on whether we want to do this under the hood a little magically...

That being said, this fix shouldn't be needed on sm100 and newer, which have TT support. Can you make sure to limit this change to sm90 devices?

@danielvegamyhre
Contributor Author

I'm wishy-washy on whether we want to do this under the hood a little magically...

@drisspg I am as well, we could just add validation and require the user to do the transpose themselves, so they are explicitly aware of the layout of their tensors and there's no surprises? It does add some friction to using the API, though.

That being said, this fix shouldn't be needed on sm100 and newer, which have TT support. Can you make sure to limit this change to sm90 devices?

Good to know - updated.

Contributor


Is this going to work for AMD?

Contributor Author

@danielvegamyhre danielvegamyhre May 13, 2025


I updated it to also check torch.cuda.is_available() to avoid an error on AMD, but no, this memory layout enforcement will not be applied on any AMD hardware. I can run follow-up tests to check perf on AMD if we have access to any AMD hardware that supports fp8 GEMMs.

Comment on lines +1191 to +1192
torch.cuda.is_available()
and torch.cuda.get_device_capability("cuda") >= (10, 0)
Contributor


Suggested change
-   torch.cuda.is_available()
+   torch.version.cuda
    and torch.cuda.get_device_capability("cuda") >= (10, 0)

I ran some tests on the AMD machine that I'm on:

>>> torch.cuda.is_available()
True
>>> torch.version.cuda
>>> torch.version.cuda is None
True
>>> torch.version.hip
'6.3.42131-fa1d09cbd'
>>> torch.cuda.get_device_capability("cuda")
(9, 4)

So I think you need to check torch.version.hip to reliably tell whether you're on an AMD machine.

Contributor Author

@danielvegamyhre danielvegamyhre May 13, 2025


Thanks for testing this. It seems like we could just check that torch.version.cuda is not None to ensure we're on NVIDIA? Updated the code with this check.

Contributor

@davidberard98 davidberard98 left a comment


lgtm as long as @drisspg is fine with "forcing explicit transpose" instead of "erroring/warning"

l_block_mask_full_q_num_blocks = L_block_mask_full_q_num_blocks
l_block_mask_full_q_indices = L_block_mask_full_q_indices

get_device_capability = torch.cuda.get_device_capability('cuda'); get_device_capability = None
Contributor


I am surprised this shows up in the graph module

Contributor Author


Same

Contributor


Should probably be marked as a constant, yeah, although it will get traced out in tracing, so it only affects pre-grad.

Contributor Author


Weird, it shows up in the graph module when I use:

is_sm_100 = torch.cuda.get_device_capability("cuda") == (10, 0)

However, it doesn't show up when I do this:

should_enforce_mem_layout = (
        gemm_precision in fp8_dtypes
        and torch.version.cuda is not None
        and torch.cuda.get_device_capability("cuda") >= (8, 9)
        and torch.cuda.get_device_capability("cuda") < (10, 0)
    )

Using the latter approach now.

@drisspg
Contributor

drisspg commented May 14, 2025

I'm fine with enforcing for now. Your perf benchmark above includes the transpose, right?

@danielvegamyhre
Contributor Author

I'm fine with enforcing for now. Your perf benchmark above includes the transpose, right?

Yep, it includes the transpose. Pre-transposed it was slightly faster (~294us IIRC).

@danielvegamyhre
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label May 14, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@jeanschmidt
Contributor

@pytorchbot revert -m "Might have introduced regressions in rocm testing for main: https://github.com/pytorch/pytorch/actions/runs/15035410497/job/42257000513 feel free to re-merge if this was a mistake" -c nosignal

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot
Collaborator

@danielvegamyhre your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request May 15, 2025
…ention to avoid perf degradation (#153357)"

This reverts commit 881a598.

Reverted #153357 on behalf of https://github.com/jeanschmidt due to Might have introduced regressions in rocm testing for main: https://github.com/pytorch/pytorch/actions/runs/15035410497/job/42257000513 feel free to re-merge if this was a mistake ([comment](#153357 (comment)))
@pytorchmergebot pytorchmergebot added Reverted ci-no-td Do not run TD on this PR labels May 15, 2025
@danielvegamyhre
Contributor Author

Updated test to handle the different graph module produced for AMD.

@davidberard98 davidberard98 added the ciflow/rocm Trigger "default" config CI on ROCm label May 15, 2025
@danielvegamyhre
Contributor Author

danielvegamyhre commented May 15, 2025

  • The failing inductor test is unrelated; the error looks like a transient build/infra issue:
[ RUN      ] AotInductorTest.BasicPackageLoaderTestCpu
unknown file: Failure
C++ exception with description "Error in dlopen: /usr/lib/x86_64-linux-gnu/libstdc++.so.6: version `GLIBCXX_3.4.30' not found (required by /tmp/j0sT4j/data/aotinductor/model/ctl2eqddy7dqff6vahjctdrmngtsn77afchsa3gojr2xfxyyysfb.wrapper.so)
  • The rocm test failure is unrelated; the error is a transient test infra issue:
Run # copy test results back to the mounted workspace, needed sudo, resulting permissions were correct
Error: No such container: 
  • The sm75 PR time benchmark failure might be related, since this memory layout transformation is applied to arch < sm100 but should really only apply to sm89 <= arch < sm100. Updating this change to not apply below sm89.

@danielvegamyhre
Contributor Author

@pytorchbot merge -f "test failures unrelated, see comments"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@github-actions github-actions bot deleted the validate-flex branch June 19, 2025 02:20

Labels

ci-no-td, ciflow/inductor, ciflow/rocm, ciflow/trunk, Merged, module: inductor, Reverted, topic: not user facing


Development

Successfully merging this pull request may close these issues.

Investigate FlexAttention performance degradation on low precision inputs

7 participants