[Intel GPU] Enable backward for SDPA XPU [WIP]#156272
LuFinch wants to merge 14 commits into pytorch:main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/156272
Note: Links to docs will display an error until the docs builds have been completed.
❌ 6 New Failures, 1 Unrelated Failure as of commit acbac05 with merge base 908c5cc.
NEW FAILURES - The following jobs have failed:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label "topic: not user facing"
Attention! native_functions.yaml was changed. If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs, one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.
Caused by:
Force-push history:
- 7014f89 to 232387d
- 40bafeb to 435e14a
- 672a28f to 4d05632
- 4d05632 to 6de83d4
- be64cfc to c8d7c3b
  tags: nondeterministic_seeded

- func: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
- func: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, bool compute_log_sumexp=False, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
@LuFinch, why do we need to add compute_log_sumexp? It breaks the ABI backward compatibility.

Ideally, we could check the input tensors' attributes, e.g. compute_logsumexp = query.requires_grad() || key.requires_grad() || value.requires_grad(), to decide whether to compute logsumexp, and this check works in eager mode.
However, in torch.compile mode, input tensors whose requires_grad() is True at the beginning can end up with requires_grad() == False inside the op after aot_autograd in some models. Hence the op needs a bool flag to indicate that it should compute logsumexp. I am not an expert in aot_autograd and am not sure why it behaves like this, but the cudnn and efficient attention backends also have this parameter. I guess they hit the same issue; otherwise they would have been able to move this check into the op.
pytorch/aten/src/ATen/native/transformers/attention.cpp
Lines 739 to 742 in 78d7f0c
pytorch/aten/src/ATen/native/transformers/attention.cpp
Lines 761 to 769 in 78d7f0c
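For reference, a minimal sketch of the eager-mode check described above; the helper name needs_logsumexp is hypothetical and this is not the actual attention.cpp code:

```cpp
#include <ATen/ATen.h>
#include <ATen/core/grad_mode.h>

// Hypothetical helper illustrating the eager-mode heuristic: logsumexp only
// has to be materialized when autograd may later run a backward pass over
// any of the SDPA inputs.
bool needs_logsumexp(const at::Tensor& query,
                     const at::Tensor& key,
                     const at::Tensor& value) {
  const bool any_input_requires_grad =
      query.requires_grad() || key.requires_grad() || value.requires_grad();
  return at::GradMode::is_enabled() && any_input_requires_grad;
}
```

In eager mode this works because the autograd metadata is still attached to the inputs; after aot_autograd the information may be lost, which is why a dedicated compute_log_sumexp flag is being added, as cudnn and efficient attention already do.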
> @LuFinch, why do we need to add compute_log_sumexp? It breaks the ABI backward compatibility.

By default compute_log_sumexp=False, so this should not break API-level BC, right?
Eikan recommended moving this argument to the last position, so it will not break BC.

I have already moved this argument to the last position.
      (attn_mask.has_value() && attn_mask.value().requires_grad()));
}

bool check_grad(sdp::sdp_params const& params, bool debug) {
According to the implementation details, it returns true when any of the following holds:
- Grad mode is not enabled
- None of the input tensors requires a gradient
- It is not grouped-query attention and the attention mask does not require a gradient

@LuFinch, is my understanding correct? If so, I would suggest refining the name of check_grad a bit; something like is_onednn_attention_backward_supported would illustrate your intent more clearly.
This function checks the grad requirements of the inputs to determine whether they are suitable for the overrideable SDPA on XPU.
As guangyey said, it is used to determine whether to use the overrideable SDPA. If it returns true, the overrideable SDPA can be used.
In oneDNN v3.9, SDPA training forward and backward do not support GQA and do not produce a gradient for attn_mask.
Hence this function means the following (see the sketch after this list):
- If grad mode is not enabled, we can use the overrideable SDPA to run the oneDNN SDPA inference forward graph.
- If grad mode is enabled but none of q/k/v needs a gradient, we can still use the overrideable SDPA to run the oneDNN SDPA inference forward graph.
- If we need to compute gradients, it is not GQA, and attn_mask does not require a gradient, then we can use the overrideable SDPA to run the oneDNN SDPA training forward graph.
- Otherwise, it should fall back to the MATH backend.
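A minimal sketch of that selection logic (the function name is hypothetical and this mirrors the description above rather than the exact PR code; the sdp_utils_cpp.h include path is assumed):

```cpp
#include <ATen/ATen.h>
#include <ATen/core/grad_mode.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>

// Returns true when the overrideable (oneDNN) SDPA backend can be used:
// either no backward pass is needed, or the backward pass is one that
// oneDNN v3.9 supports (no GQA, no gradient for attn_mask).
bool can_use_overrideable_sdpa(const sdp::sdp_params& params) {
  if (!at::GradMode::is_enabled()) {
    return true;  // inference-only forward graph
  }
  const bool inputs_require_grad = params.query.requires_grad() ||
      params.key.requires_grad() || params.value.requires_grad();
  if (!inputs_require_grad) {
    return true;  // grad mode is on, but nothing to differentiate
  }
  // Training path: oneDNN v3.9 SDPA backward supports neither GQA nor
  // gradients for the attention mask, so those cases fall back to MATH.
  const bool is_gqa = params.query.sym_size(-3) != params.key.sym_size(-3) ||
      params.query.sym_size(-3) != params.value.sym_size(-3);
  const bool attn_mask_requires_grad = params.attn_mask.has_value() &&
      params.attn_mask.value().requires_grad();
  return !is_gqa && !attn_mask_requires_grad;
}
```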
auto k_num_heads = params.key.sym_size(-3);
auto v_num_heads = params.value.sym_size(-3);
bool is_gqa = q_num_heads != k_num_heads || q_num_heads != v_num_heads;
if (debug && is_gqa)
Since this is GQA, why does the function return false here?

In oneDNN v3.9, SDPA training forward and backward do not support GQA and do not compute a gradient for attn_mask.
bool attn_mask_needs_grad =
    params.attn_mask.has_value() && params.attn_mask.value().requires_grad();
if (debug && attn_mask_needs_grad) {
In oneDNN v3.9, SDPA training forward and backward do not support GQA and do not compute a gradient for attn_mask.
auto grad_attn_bias = attn_bias_opt.has_value()
    ? at::empty_like(attn_bias_opt.value())
    : at::Tensor();
at::native::onednn::gpu_float_sdpa_backward(
gpu_float_sdpa_backward has been defined? Does the name mean the backward function only supports float?

I have the same question.

It supports FP32/FP16/BF16. I copied this function name directly from the SDPA inference path. We could rename it.

Renamed it to sdpa_backward.
    grad_out.dim() == 4 && out.dim() == 4 &&
        grad_out.size(0) == out.size(0) && grad_out.size(1) == out.size(1) &&
        grad_out.size(2) == out.size(2) && grad_out.size(3) == out.size(3),
    "scaled_dot_product_fused_attention_overrideable_backward_xpu: grad_out and out should have the same shape of {(B), H, T, K}");
What is the meaning of (B)?

It was copied from the forward code; it just means batch size. I have already removed the brackets.
    is_causal, logical_params);
auto i = logical_params.get_input();
auto o = logical_params.get_output();
auto compiled_partition = partition_.compile(i, o, eng);
This variable shadows a similar declaration at line 972. It works, but it is not ideal. I recommend renaming it to avoid name shadowing.

Thanks for your patient review. Renamed it.
if (is_causal) {
  neg_inf = at::full(
      {},
      -INFINITY,
at::numeric_limits<>::lower_bound is better.

This file already uses -std::numeric_limits<float>::infinity(), so I replaced -INFINITY with -std::numeric_limits<float>::infinity() as well.
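For illustration, a minimal sketch of the change being discussed (the helper name is hypothetical; it only shows the scalar fill value swapped for the numeric_limits form, matching the at::full call in the snippet above):

```cpp
#include <limits>
#include <ATen/ATen.h>

// -INFINITY and -std::numeric_limits<float>::infinity() produce the same
// value; the latter is preferred here only for consistency with the rest of
// the file.
at::Tensor make_neg_inf(const at::TensorOptions& options) {
  return at::full({}, -std::numeric_limits<float>::infinity(), options);
}
```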
inputs.reserve(l_inputs.size());
inputs.emplace_back(l_inputs[i++], eng, grad_out.data_ptr());
inputs.emplace_back(l_inputs[i++], eng, query.data_ptr());
inputs.emplace_back(l_inputs[i++], eng, key.data_ptr());
inputs.emplace_back(l_inputs[i++], eng, value.data_ptr());
inputs.emplace_back(l_inputs[i++], eng, out.data_ptr());
inputs.emplace_back(l_inputs[i++], eng, logsumexp.data_ptr());
inputs.emplace_back(l_inputs[i++], eng, softmax_scale.data_ptr());
if (neg_inf.has_value()) {
  inputs.emplace_back(l_inputs[i++], eng, neg_inf->data_ptr());
}
if (attn_mask.has_value()) {
  inputs.emplace_back(l_inputs[i++], eng, attn_mask->data_ptr());
Use a macro to reduce the duplicated code, such as:

#define ADD_INPUT(variable) \
  inputs.emplace_back(l_inputs[i++], eng, variable.data_ptr())
ADD_INPUT(grad_out);
ADD_INPUT(query);
...
#undef ADD_INPUT

partition& find_or_create_backward_graph_partition(
    bool is_causal,
    const SDPABackwardLogicalParams& params) {
  thread_local static PartitionCache cache;
Suggested change:
- thread_local static PartitionCache cache;
+ thread_local PartitionCache cache;
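As a side note on this suggestion: at function scope, thread_local already implies static storage duration, so dropping the explicit static does not change behavior. A small standalone sketch:

```cpp
#include <iostream>

int& per_thread_counter() {
  // Equivalent to `thread_local static int n = 0;`: block-scope thread_local
  // variables always have static storage duration within their thread.
  thread_local int n = 0;
  return n;
}

int main() {
  per_thread_counter() += 2;
  std::cout << per_thread_counter() << "\n";  // prints 2 on this thread
  return 0;
}
```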
std::bitset<32> patternID;
if (dtype == data_type::f32) {
  // bit 3 corresponds to float32 dtype
  patternID.set(3, 1);
This is fine, but I recommend using a named constant like kBitFloat32 instead of the hardcoded number 3. Another point: kBitFloat32 could then be shared with find_or_create_graph_partition.
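A minimal sketch of that suggestion (the constant name and the helper are illustrative only; the actual bit layout is defined by the PR):

```cpp
#include <bitset>
#include <cstddef>

// Shared constant (e.g. in a common header) used by both the forward and the
// backward partition-cache key builders, instead of a hardcoded bit index.
constexpr std::size_t kBitFloat32 = 3;  // bit 3 of the pattern ID marks f32

std::bitset<32> make_pattern_id(bool is_f32) {
  std::bitset<32> patternID;
  if (is_f32) {
    patternID.set(kBitFloat32, /*value=*/true);
  }
  return patternID;
}
```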
at::Tensor reshaped_attention = attention_;
at::Tensor reshaped_logsumexp = logsumexp_.unsqueeze(-1);
at::Tensor reshaped_attn_mask = attn_mask_.value_or(at::Tensor());
if (at::native::onednn::is_broadcast(reshaped_query)) {
With this code change, SDPA will no longer support broadcast. Is this BC-breaking? Is there any impact on old scripts?

After offline discussion, I added back broadcast support for Q/K/V. However, the output tensors attention and logsumexp are allocated by us and should not be broadcasted, so I did not add broadcast handling for those two tensors.
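A rough sketch of the idea described above (the helper is hypothetical and only illustrates the intent; the PR's actual broadcast handling lives in the oneDNN integration code):

```cpp
#include <ATen/ATen.h>

// Inputs (Q/K/V, attn_mask) may arrive as broadcasted views with zero strides;
// materialize them into dense buffers before handing them to the oneDNN graph.
// Outputs such as attention and logsumexp are allocated densely by us, so they
// never need this treatment.
at::Tensor materialize_if_broadcast(const at::Tensor& t) {
  for (const auto stride : t.strides()) {
    if (stride == 0) {
      return t.contiguous();  // copy the broadcasted view into a dense tensor
    }
  }
  return t;
}
```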
@guangyey Thanks for your review. Eikan recommended splitting this PR into two parts: one for the oneDNN code only and another for the PyTorch API change. This will make it easier for the community to review. I will open new PRs.
…) This PR is the first split PR of #156272 and only contains the oneDNN code. Please help review. Pending on oneDNN v3.9 commit update. Don't merge. Pull Request resolved: #161058. Approved by: https://github.com/guangyey, https://github.com/EikanWang
…PU OVERRIDEABLE Backend (#162454) This is the second PR split from #156272 Pull Request resolved: #162454 Approved by: https://github.com/guangyey, https://github.com/EikanWang, https://github.com/drisspg
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jerryzh168 @gujinghui @PenghuiCheng @jianyuh @min-jean-cho @yanbing-j @Guobing-Chen @Xia-Weiwen @snadampal @EikanWang @fengyuan14 @guangyey