
Enable XPU path for FlexAttention #143553

Closed

liangan1 wants to merge 133 commits into pytorch:main from liangan1:liangan1/flex_attention

Conversation

@liangan1
Contributor

@liangan1 liangan1 commented Dec 19, 2024

RFC: #153024

Motivation

  1. Attention is the critical performance bottleneck in current LLM models, and FlexAttention is a good choice to cover the broad range of attention variants in transformers-series models. With FlexAttention, it is easy for us to enable paged attention and fused SDPA in the transformers repo on the XPU device. Besides, it also provides a candidate path for processing attention in LLM ecosystem libraries, e.g., vLLM and SGLang, on the XPU device (a minimal usage sketch follows this list).
  2. FlexAttention is a good starting point for maturing the Intel Triton-based GEMM kernels. FlexAttention provides both a flexattention kernel and a flexdecoding kernel to cover compute-bound and memory-bound GEMM computation, and different shapes also need to be supported to serve LLM inference, e.g., head_dim = 64, 96, 128, 256.
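To make the motivation concrete, here is a minimal FlexAttention sketch. The shapes, the relative-bias score_mod, and the XPU/CPU device fallback are illustrative assumptions rather than code from this PR; only the public torch.nn.attention.flex_attention API is relied on.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Illustrative shapes; "xpu" is used when an Intel GPU build of PyTorch is available.
device = "xpu" if torch.xpu.is_available() else "cpu"
B, H, S, D = 2, 8, 1024, 64
q, k, v = (torch.randn(B, H, S, D, device=device) for _ in range(3))

# One of the "broad variants": a relative-position bias expressed as a score_mod,
# instead of a dedicated fused kernel per attention flavor.
def rel_bias(score, b, h, q_idx, kv_idx):
    return score + (q_idx - kv_idx)

# torch.compile lowers this through Inductor's Triton-based FlexAttention kernel,
# which is the path this PR enables for the XPU device.
out = torch.compile(flex_attention)(q, k, v, score_mod=rel_bias)
print(out.shape)  # torch.Size([2, 8, 1024, 64])
```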

What does this PR do?

  1. Enable the XPU device type for the FlexAttention kernel and UTs to ensure all important UTs pass on the XPU device.
  2. For E2E model inference, ensure that LLM inference with FlexAttention is functional (see the block-mask sketch after this list).
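As an example of the E2E functionality in point 2, the sketch below runs FlexAttention with a causal BlockMask, the pattern a decoder-only LLM uses at inference time. The shapes and device handling are assumptions for illustration; only the public flex_attention/create_block_mask API is assumed.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

device = "xpu" if torch.xpu.is_available() else "cpu"
B, H, S, D = 1, 4, 512, 64

# Causal masking expressed as a mask_mod; create_block_mask precomputes the sparse
# block structure that the FlexAttention kernel can skip over at runtime.
def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=S, KV_LEN=S, device=device)

q, k, v = (torch.randn(B, H, S, D, device=device) for _ in range(3))
out = torch.compile(flex_attention)(q, k, v, block_mask=block_mask)
```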

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @Lucaskabela @yf225 @ColinPeppler @desertfire

@pytorch-bot

pytorch-bot bot commented Dec 19, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/143553

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 1 New Failure, 4 Unrelated Failures

As of commit 29dbb36 with merge base d153af7:

NEW FAILURE - The following job has failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@linux-foundation-easycla

linux-foundation-easycla bot commented Dec 19, 2024

@liangan1 liangan1 marked this pull request as draft December 19, 2024 04:39
@EikanWang EikanWang added the topic: not user facing, triaged, and ciflow/xpu (Run XPU CI tasks) labels Dec 24, 2024
@pytorch-bot

pytorch-bot bot commented Dec 24, 2024

To add the ciflow label ciflow/xpu please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@pytorch-bot pytorch-bot bot removed the ciflow/xpu Run XPU CI tasks label Dec 24, 2024
@EikanWang EikanWang self-requested a review December 24, 2024 02:14
@EikanWang EikanWang added the ciflow/xpu Run XPU CI tasks label Dec 24, 2024
@pytorch-bot

pytorch-bot bot commented Dec 24, 2024

To add the ciflow label ciflow/xpu please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@pytorch-bot pytorch-bot bot removed the ciflow/xpu Run XPU CI tasks label Dec 24, 2024
@liangan1
Contributor Author

@pytorchbot rebase

@pytorch-bot

pytorch-bot bot commented Feb 10, 2025

You don't have permissions to rebase this PR since you are a first time contributor. If you think this is a mistake, please contact PyTorch Dev Infra.

enable flex attention TMA flag on xpu by default
@jianan-gu
Contributor

@jianan-gu, can you help take a look at the following CPU-related UT failures, which are irrelevant to the XPU?
'test/inductor/test_flex_attention.py::TestFlexAttentionCPU::test_flex_attention_stride_ordering_mode_paged_attention_permute_order3_shape0_cpu', 'test/inductor/test_flex_attention.py::TestFlexAttentionCPU::test_flex_attention_stride_ordering_mode_paged_attention_permute_order3_shape1_cpu'

Logs: https://hud.pytorch.org/pytorch/pytorch/pull/143553?sha=8fe1251b2a91d249b553fdf78e089de8f9145f46

Hi, @liangan1 @hoshibara

Thanks for mentioning this and for the refinements to the FlexAttention-related UTs on common devices in this PR.
Yes, we also hit this issue when doing similar refinements in #159835.
We have double-checked this UT locally but were not able to reproduce the failure, and we will take some more time to follow up on a final fix. Since this is also irrelevant to the XPU, we suggest skipping this UT in this PR (a sketch of such a skip follows below).
Thanks.

cc @Valentine233
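A minimal sketch of the suggested skip, assuming a plain unittest-style decorator; the decorator placement, class body, and reason string are illustrative, and the real test lives in test/inductor/test_flex_attention.py, so the PR may express the skip differently.

```python
import unittest

class TestFlexAttentionCPU(unittest.TestCase):
    # Illustrative only: skip the unrelated CPU paged-attention stride-ordering case
    # until the follow-up fix lands; the actual test body is elided here.
    @unittest.skip("CPU paged-attention stride-ordering failure is unrelated to XPU; tracked separately")
    def test_flex_attention_stride_ordering_mode_paged_attention_permute_order3_shape0_cpu(self):
        ...
```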

@EikanWang
Collaborator

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 job has failed; the first few are: trunk / linux-jammy-rocm-py3.10 / build

Details for Dev Infra team (raised by workflow job)

# USE TMA = false by default
cur_kernel_options.setdefault("USE_TMA", False)

if torch.xpu.is_available():
Contributor

The tensor descriptor does not support all Q, K, V memory layouts; Q, K, and V have to be contiguous on the last dim to use the tensor descriptor.

Use can_use_tma in the condition, as in if torch.xpu.is_available() and can_use_tma(query, key, value):.
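A minimal sketch of the suggested gating is below. The body of can_use_tma is an assumption (last-dim contiguity of Q/K/V) and select_tma_option is a made-up wrapper for illustration; neither is the actual Inductor code.

```python
import torch

def can_use_tma(query, key, value):
    # Assumed check: tensor descriptors (TMA) require the last dim of Q/K/V to be contiguous.
    return all(t.stride(-1) == 1 for t in (query, key, value))

def select_tma_option(cur_kernel_options, query, key, value):
    # Mirrors the reviewed snippet: default USE_TMA to False, then turn it on only
    # when an XPU is present and the Q/K/V layouts allow a tensor descriptor.
    cur_kernel_options.setdefault("USE_TMA", False)
    if torch.xpu.is_available() and can_use_tma(query, key, value):
        cur_kernel_options["USE_TMA"] = True
    return cur_kernel_options
```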

Contributor

Done

@hoshibara
Contributor

@pytorchbot label ciflow/xpu

@hoshibara
Contributor

@pytorchbot merge

@pytorch-bot

pytorch-bot bot commented Aug 28, 2025

Pull workflow has not been scheduled for the PR yet. This could be because the author doesn't have permission to run it, or because skip-checks keywords were added to the PR/commits; aborting merge. Please get/give approval for the workflows and/or remove skip-ci decorators before the next merge attempt. If you think this is a mistake, please contact PyTorch Dev Infra.

@hoshibara
Contributor

@pytorchbot label ciflow/trunk

@pytorch-bot

pytorch-bot bot commented Aug 28, 2025

To add these label(s) (ciflow/trunk) to the PR, please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@hoshibara
Contributor

Hi @EikanWang,
The TMA code has been reverted. The TMA-related UT shows that the USE_TMA flag is interpreted correctly on XPU.

@EikanWang
Collaborator

@pytorchbot merge -i

@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 4 checks: xpu / linux-jammy-xpu-n-py3.10 / test (default, 1, 8, linux.idc.xpu), xpu / linux-jammy-xpu-n-py3.10 / test (default, 5, 8, linux.idc.xpu), xpu / linux-jammy-xpu-n-py3.10 / test (default, 3, 8, linux.idc.xpu), xpu / linux-jammy-xpu-n-py3.10 / test (default, 6, 8, linux.idc.xpu)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here


if not _has_sufficient_memory(_device, size_bytes):
# TODO: Memory availability checks for Intel GPU
if device != "xpu" and not _has_sufficient_memory(_device, size_bytes):
Collaborator

@guangyey guangyey Sep 2, 2025

This changed the logic of largeTensorTest. It disabled largeTensorTest on the XPU device, which results in the failure of python test/dynamo/test_aot_autograd_cache.py AOTAutogradCacheTests.test_autograd_inductor_guards_device_xpu_float16_requires_grad_True.

Contributor

In e6ae1ed, we attempted to complete the sufficient-memory check for XPU, but it caused some previously skipped cases to fail.
This needs a new PR to fix (a sketch of such a check is below).
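A sketch of what a sufficient-memory check for XPU could look like, assuming total_memory is exposed via torch.xpu.get_device_properties; the helper name and the headroom factor are illustrative assumptions, not the code from e6ae1ed.

```python
import torch

def _has_sufficient_xpu_memory(device, size_bytes):
    # Illustrative only: compare the requested allocation against the device's total
    # memory with some headroom, analogous to the CUDA path in _has_sufficient_memory.
    props = torch.xpu.get_device_properties(device)
    return size_bytes < props.total_memory * 0.9
```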

Collaborator

@hoshibara Please fix those failures ASAP.

Contributor

Raised #162034 to fix this case.

Contributor

Guangye's PR #161988 will fix this issue.

Collaborator

@hoshibara Thanks. PR landed.


Labels

ciflow/trunk: Trigger trunk jobs on your pull request
ciflow/xpu: Run XPU CI tasks
keep-going: Don't stop on first failure, keep running tests until the end
Merged
module: dynamo
module: inductor
open source
topic: not user facing (topic category)
triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Projects

Archived in project
