
[Intel GPU] Enable backward for SDPA XPU [WIP]#156272

Closed
LuFinch wants to merge 14 commits into pytorch:main from LuFinch:lfq/sdpa_traning

Conversation

@LuFinch
Contributor

@LuFinch LuFinch commented Jun 18, 2025

@pytorch-bot

pytorch-bot bot commented Jun 18, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/156272

Note: Links to docs will display an error until the docs builds have been completed.

❌ 6 New Failures, 1 Unrelated Failure

As of commit acbac05 with merge base 908c5cc:

NEW FAILURES - The following jobs have failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added module: cpu CPU specific problem (e.g., perf, algorithm) module: mkldnn Related to Intel IDEEP or oneDNN (a.k.a. mkldnn) integration labels Jun 18, 2025
@LuFinch LuFinch changed the title [Intel GPU] Enable training for SDPA XPU [Intel GPU] Enable training for SDPA XPU [WIP] Jun 18, 2025
@LuFinch
Contributor Author

LuFinch commented Jun 18, 2025

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Jun 18, 2025
@github-actions
Contributor

Attention! native_functions.yaml was changed

If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs, one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.


Caused by:

@LuFinch LuFinch force-pushed the lfq/sdpa_traning branch from 7014f89 to 232387d Compare June 18, 2025 11:03
@guangyey guangyey moved this to In Progress in PyTorch Intel Jun 19, 2025
@LuFinch LuFinch force-pushed the lfq/sdpa_traning branch 7 times, most recently from 40bafeb to 435e14a Compare June 24, 2025 06:27
@LuFinch LuFinch changed the title [Intel GPU] Enable training for SDPA XPU [WIP] [Intel GPU] Enable backward for SDPA XPU [WIP] Jun 25, 2025
@LuFinch LuFinch force-pushed the lfq/sdpa_traning branch from 672a28f to 4d05632 Compare July 16, 2025 06:32
@LuFinch LuFinch force-pushed the lfq/sdpa_traning branch from 4d05632 to 6de83d4 Compare July 24, 2025 01:56
@LuFinch LuFinch force-pushed the lfq/sdpa_traning branch from be64cfc to c8d7c3b Compare July 30, 2025 03:33
@pytorch-bot pytorch-bot bot removed the ciflow/xpu Run XPU CI tasks label Aug 12, 2025
Collaborator

@EikanWang EikanWang left a comment

@LuFinch, the newly added parameter of the overrideable SDPA backward is not backward compatible.

tags: nondeterministic_seeded

- func: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
- func: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, bool compute_log_sumexp=False, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
Collaborator

@LuFinch, why do we need to add compute_log_sumexp? It breaks ABI backward compatibility.

Contributor Author

Ideally, we could check the input tensors' attributes, e.g. compute_logsumexp = query.requires_grad() || key.requires_grad() || value.requires_grad(), to decide whether to compute logsumexp, and this check works in eager mode.

However, in torch.compile mode, input tensors query/key/value that have requires_grad() == True at the beginning may end up with requires_grad() == False inside the op after aot_autograd in some models. Hence a bool flag is needed to tell the op to compute logsumexp. I am not an expert on aot_autograd and am not sure why it behaves this way, but the cuDNN and efficient-attention backends also have this parameter; I guess they hit the same issue, otherwise they could have moved this check into the op.

case SDPBackend::cudnn_attention: {
  bool compute_logsumexp = should_compute_logsumexp(query_, key, value);
  auto out_lse_softmax = at::_scaled_dot_product_cudnn_attention(
      query_, key, value, attn_mask, compute_logsumexp, dropout_p,
      is_causal, false /*return_debug_mask*/, scale);

case SDPBackend::efficient_attention: {
  bool compute_logsumexp = should_compute_logsumexp(query_, key, value);
  if (attn_mask.has_value()) {
    attn_mask.value() = preprocess_mask(attn_mask.value(), query_, key, value);
  }
  auto out_and_lse = at::_scaled_dot_product_efficient_attention(
      query_, key, value, attn_mask, compute_logsumexp, dropout_p, is_causal, scale);
  return std::get<0>(out_and_lse);
}

Collaborator

@LuFinch, why do we need to add compute_log_sumexp? It breaks ABI backward compatibility.

By default compute_log_sumexp=False, so this should not break API-level BC, right?

Contributor Author

Eikan recommended moving this argument to the last position; then it will not break BC.

Collaborator

Cool!

Contributor Author

Moved this argument to the last position.

(attn_mask.has_value() && attn_mask.value().requires_grad()));
}

bool check_grad(sdp::sdp_params const& params, bool debug) {
Collaborator

According to the implementation details, it returns True when

  • Grad mode is not enabled
  • All input tensors do not require a gradient
  • Not Group Query Attention, and the attention mask does not require a gradient

@LuFinch , is my understanding correct? If so, I would suggest refining the name of check_grad a little bit. Something could be like is_onednn_attention_backward_supported to illustrate your idea clearly.

Collaborator

This function should be used to check the grad requirements of inputs to determine whether they are suitable for supporting overrideable SDPA on XPU in the future.

Contributor Author

As Guangye said, it is used to determine whether to use overrideable SDPA. If it returns True, the overrideable SDPA path can be used.

In oneDNN v3.9, SDPA training forward and backward don't support GQA and don't output a gradient for attn_mask.

Hence this function means:

  • If grad mode is not enabled, we can use overrideable SDPA to run the oneDNN SDPA inference forward graph.
  • If grad mode is enabled but none of q/k/v needs grad, we can use overrideable SDPA to run the oneDNN SDPA inference forward graph.
  • If we need to compute grads, the attention is not GQA, and attn_mask doesn't require grad, then we can use overrideable SDPA to run the oneDNN SDPA training forward graph.
  • Otherwise, it should fall back to the MATH backend.
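For illustration, the decision logic above can be sketched as standalone code; the struct and function names below are simplified stand-ins of mine, not the real sdp::sdp_params or check_grad signature:

```cpp
// Simplified stand-in types; the real code reads these flags from
// at::GradMode and the tensors in sdp::sdp_params.
struct SdpaGradParams {
  bool grad_mode_enabled;
  bool qkv_requires_grad;        // any of q/k/v requires grad
  bool is_gqa;                   // grouped-query attention
  bool attn_mask_requires_grad;
};

// Returns true when the overrideable (oneDNN) SDPA path can be used.
bool can_use_overrideable_sdpa(const SdpaGradParams& p) {
  // Inference-only cases: no grad mode, or no input needs grad.
  if (!p.grad_mode_enabled || !p.qkv_requires_grad) return true;
  // Training case: oneDNN v3.9 supports neither GQA nor attn_mask gradients.
  return !p.is_gqa && !p.attn_mask_requires_grad;
}
```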

auto k_num_heads = params.key.sym_size(-3);
auto v_num_heads = params.value.sym_size(-3);
bool is_gqa = q_num_heads != k_num_heads || q_num_heads != v_num_heads;
if (debug && is_gqa)
Collaborator

Since it is already GQA here, why does this function return false?

Contributor Author

In oneDNN v3.9, SDPA training forward and backward don't support GQA and don't compute a gradient for attn_mask.


bool attn_mask_needs_grad =
params.attn_mask.has_value() && params.attn_mask.value().requires_grad();
if (debug && attn_mask_needs_grad) {
Collaborator

ditto

Contributor Author

In oneDNN v3.9, SDPA training forward and backward don't support GQA and don't compute a gradient for attn_mask.

auto grad_attn_bias = attn_bias_opt.has_value()
? at::empty_like(attn_bias_opt.value())
: at::Tensor();
at::native::onednn::gpu_float_sdpa_backward(
Collaborator

gpu_float_sdpa_backward has been defined — does the name mean the backward function only supports float?

Collaborator

I have the same question.

Contributor Author

It supports FP32/FP16/BF16. I copied this function name directly from the SDPA inference path. We could rename it.

Contributor Author

Renamed it to sdpa_backward.

grad_out.dim() == 4 && out.dim() == 4 &&
grad_out.size(0) == out.size(0) && grad_out.size(1) == out.size(1) &&
grad_out.size(2) == out.size(2) && grad_out.size(3) == out.size(3),
"scaled_dot_product_fused_attention_overrideable_backward_xpu: grad_out and out should have the same shape of {(B), H, T, K}");
Collaborator

What's the meaning of (B)?

Contributor Author

Copied from the forward code — it just means batch size. I have removed the brackets.

is_causal, logical_params);
auto i = logical_params.get_input();
auto o = logical_params.get_output();
auto compiled_partition = partition_.compile(i, o, eng);
Collaborator

This variable shadows a similar declaration at line 972. It's fine but not good. I recommend renaming this to avoid name shadowing.

Contributor Author

Thanks for your patient review. Renamed it.

if (is_causal) {
neg_inf = at::full(
{},
-INFINITY,
Collaborator

at::numeric_limits<>::lower_bound is better.

Contributor Author

This file already uses -std::numeric_limits&lt;float&gt;::infinity(), so I replaced -INFINITY with it as well.
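As a standalone illustration of why the replacement is behavior-preserving (the constant names below are mine, not from the PR):

```cpp
#include <cmath>
#include <limits>

// Both spellings denote the same IEEE-754 negative infinity; the <limits>
// form is explicitly typed, while INFINITY is a C macro from <cmath>.
constexpr float kNegInfMacro = -INFINITY;
constexpr float kNegInfLimits = -std::numeric_limits<float>::infinity();
```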

Comment on lines +1024 to +1036
inputs.reserve(l_inputs.size());
inputs.emplace_back(l_inputs[i++], eng, grad_out.data_ptr());
inputs.emplace_back(l_inputs[i++], eng, query.data_ptr());
inputs.emplace_back(l_inputs[i++], eng, key.data_ptr());
inputs.emplace_back(l_inputs[i++], eng, value.data_ptr());
inputs.emplace_back(l_inputs[i++], eng, out.data_ptr());
inputs.emplace_back(l_inputs[i++], eng, logsumexp.data_ptr());
inputs.emplace_back(l_inputs[i++], eng, softmax_scale.data_ptr());
if (neg_inf.has_value()) {
  inputs.emplace_back(l_inputs[i++], eng, neg_inf->data_ptr());
}
if (attn_mask.has_value()) {
  inputs.emplace_back(l_inputs[i++], eng, attn_mask->data_ptr());
Collaborator

Use a macro to reduce the duplicated code, such as:

#define ADD_INPUT(variable) \
  inputs.emplace_back(l_inputs[i++], eng, variable.data_ptr())

ADD_INPUT(grad_out);
ADD_INPUT(query);
...
#undef ADD_INPUT

Contributor Author

Cool. Done.

partition& find_or_create_backward_graph_partition(
bool is_causal,
const SDPABackwardLogicalParams& params) {
thread_local static PartitionCache cache;
Collaborator

Suggested change
thread_local static PartitionCache cache;
thread_local PartitionCache cache;

Contributor Author

Done.
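A small standalone sketch of why dropping `static` is safe (the function name here is illustrative):

```cpp
// At block scope, `thread_local` already implies static storage duration,
// so `thread_local static PartitionCache cache;` and
// `thread_local PartitionCache cache;` are equivalent: one instance per
// thread, initialized once per thread, persisting across calls.
int& per_thread_counter() {
  thread_local int n = 0;  // same as `thread_local static int n = 0;`
  return n;
}
```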

std::bitset<32> patternID;
if (dtype == data_type::f32) {
// bit 3 corresponds to float32 dtype
patternID.set(3, 1);
Collaborator

This is fine, but I recommend using a name like kBitFloat32 instead of the hardcoded number 3. Also, kBitFloat32 could be shared with find_or_create_graph_partition.

Contributor Author

Good idea. Done.
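A minimal standalone sketch of the naming suggestion; the constant name follows the review comment, while the helper function is hypothetical and not from the PR:

```cpp
#include <bitset>
#include <cstddef>

// Named bit position instead of a hardcoded 3; this constant could be
// shared between the forward and backward partition-cache lookups.
constexpr std::size_t kBitFloat32 = 3;  // bit 3 marks a float32 pattern

std::bitset<32> make_pattern_id(bool is_f32) {
  std::bitset<32> patternID;
  if (is_f32) {
    patternID.set(kBitFloat32);  // replaces patternID.set(3, 1)
  }
  return patternID;
}
```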

at::Tensor reshaped_attention = attention_;
at::Tensor reshaped_logsumexp = logsumexp_.unsqueeze(-1);
at::Tensor reshaped_attn_mask = attn_mask_.value_or(at::Tensor());
if (at::native::onednn::is_broadcast(reshaped_query)) {
Collaborator

With this code change, SDPA will not support broadcast anymore. Is this BC-breaking? Is there any impact on old scripts?

Contributor Author

After offline discussion, I added back broadcast support for Q/K/V. However, the output tensors attention and logsumexp are allocated by us and should not be broadcast, so I did not add broadcast support for those two.

Collaborator

@guangyey guangyey left a comment

LGTM.

@guangyey guangyey added the ciflow/xpu Run XPU CI tasks label Aug 13, 2025
@LuFinch
Contributor Author

LuFinch commented Aug 13, 2025

@guangyey Thanks for your review. Eikan recommends splitting this PR into two parts: one for the oneDNN code only and another for the PyTorch API change. This will make it easier for the community to review. I will open new PRs.

@LuFinch LuFinch closed this Aug 20, 2025
@github-project-automation github-project-automation bot moved this from In Progress to Done in PyTorch Intel Aug 20, 2025
pytorchmergebot pushed a commit that referenced this pull request Sep 8, 2025

This PR is the first split PR of #156272, only contains the OneDNN code. Please help review.

Pending on OneDNN v3.9 commit update. Don't merge.

Pull Request resolved: #161058
Approved by: https://github.com/guangyey, https://github.com/EikanWang
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
mansiag05 pushed a commit to mansiag05/pytorch that referenced this pull request Sep 22, 2025
cleonard530 pushed a commit to cleonard530/pytorch that referenced this pull request Sep 22, 2025
dsashidh pushed a commit to dsashidh/pytorch that referenced this pull request Sep 26, 2025
pytorchmergebot pushed a commit that referenced this pull request Oct 31, 2025
@pytorch pytorch deleted a comment from pytorch-bot bot Oct 31, 2025
@pytorch pytorch deleted a comment from pytorchmergebot Oct 31, 2025
@pytorch pytorch deleted a comment from pytorchmergebot Oct 31, 2025
BoyuanFeng pushed a commit that referenced this pull request Oct 31, 2025
etaf pushed a commit to etaf/pytorch-inductor-xpu that referenced this pull request Nov 4, 2025
@LuFinch LuFinch deleted the lfq/sdpa_traning branch January 28, 2026 05:33

Labels

ciflow/xpu Run XPU CI tasks module: cpu CPU specific problem (e.g., perf, algorithm) module: mkldnn Related to Intel IDEEP or oneDNN (a.k.a. mkldnn) integration module: xpu Intel XPU related issues open source release notes: inductor (aoti) topic: not user facing topic category

Projects

Archived in project

Development

Successfully merging this pull request may close these issues.

6 participants