[ROCm][CK] Enable variable-length attention support for CK SDPA backend#172246
chinmaydk99 wants to merge 3 commits into pytorch:main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/172246
Note: links to docs will display an error until the docs builds have completed.
❌ 1 New Failure, 1 Unrelated Failure as of commit cf43b84 with merge base cd17970.
NEW FAILURE: the following job has failed.
BROKEN TRUNK: the following job failed but was present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed 46520fb to 5e2a3d7
Force-pushed 5e2a3d7 to f28bff7
Force-pushed fef93c7 to 1d8f734
Force-pushed 4b4689f to 625fef3
test/test_varlen_attention.py (Outdated)

        TEST_WITH_ROCM, "varlen attention w/ sliding window not supported on ROCm"
    )
    @parametrize("dtype", [torch.bfloat16, torch.float16])
    def test_custom_op_compliance(self, device, dtype):
I notice this test doesn't have the "sdpa_backend" parametrization like the other ones? is there a reason for that?
Had temporarily removed the sdpa backend parametrization since we were awaiting the AOTriton PR to merge. Will add it here as well as in the next test.
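The backend parametrization being discussed could look roughly like this, a minimal sketch in plain Python with no torch dependency; `available_sdpa_backends` and `expand_parametrization` are illustrative names, not the actual test-suite helpers:

```python
def available_sdpa_backends(ck_available):
    # CK is only added to the test matrix when the build supports it,
    # mirroring the ["aotriton", "ck"] vs ["aotriton"] choice in the PR.
    return ["aotriton", "ck"] if ck_available else ["aotriton"]

def expand_parametrization(test_name, backends, dtypes):
    # Mirrors what @parametrize does: one concrete test per combination
    # of backend and dtype.
    return [f"{test_name}_{b}_{d}" for b in backends for d in dtypes]
```

With both backends and two dtypes, each base test expands into four concrete tests, which is why dropping the backend axis from one test stood out in review.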
Force-pushed 6054e49 to ea61d0e
Force-pushed ea61d0e to 6a7430e
Force-pushed 6a7430e to 89618a1
…sion

The internal fbsource build uses a pinned CK version that does not have these struct fields, which were added for a newer CK in pytorch#172246:
- block_scale_seqstart_{q,k}_ptr
- nhead_stride_{q,k,v}_descale
- batch_stride_{q,k,v}_descale
- block_scale_size_{q,kv}

These extra fields cause type-mismatch build errors:
- nullptr passed for ck_tile::index_t (int) fields
- int-to-float and float-to-int narrowing from shifted field positions

Remove the 10 extra fields from the fmha_fwd_args initialization in both mha_fwd_ck.hip and mha_varlen_fwd_ck.hip. Forward fix for pytorch#172246.
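The field-position problem described in the commit message can be illustrated with a toy example. These structs are illustrative stand-ins, not the real CK `fmha_fwd_args`; only the mechanism (positional aggregate initialization against two struct layouts) matches:

```cpp
#include <cassert>

struct ArgsOld {                    // pinned CK: no descale field
    int seqlen_q;
    float scale_s;
};

struct ArgsNew {                    // newer CK: extra field before scale_s
    int seqlen_q;
    int nhead_stride_q_descale;
    float scale_s;
};

// Positional aggregate init written for the new layout compiles cleanly...
ArgsNew make_new_args() { return ArgsNew{128, 0, 0.125f}; }

// ...but the same initializer list {128, 0, 0.125f} against ArgsOld would
// have too many entries, and the shifted positions would try to stuff the
// float into an int slot: exactly the narrowing errors the commit
// describes. The old struct needs its own, shorter list:
ArgsOld make_old_args() { return ArgsOld{128, 0.125f}; }
```

This is why the forward fix simply removes the 10 newer fields rather than reordering them: positional initialization leaves no safe way to share one initializer list across both CK versions.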
@pytorchbot revert -m "sorry, failing builds with older HIP versions, see below" -c ghirst cc @malfet

❌ 🤖 pytorchbot command failed: Try

@pytorchbot revert -m "sorry, failing builds with older HIP versions, see below" -c ghfirst

@pytorchbot successfully started a revert job. Check the current status here.

Reverting PR 172246 failed. Reason: Command … Details for Dev Infra team: raised by workflow job.

@pytorchbot revert -m 'sorry, failing builds with older HIP versions, see below' -c ghfirst

@pytorchbot successfully started a revert job. Check the current status here.
…PA backend (#172246)"

This reverts commit e1b7388. Reverted #172246 on behalf of https://github.com/huydhn due to "sorry, failing builds with older HIP versions, see below" ([comment](#172246 (comment)))
@chinmaydk99 your PR has been successfully reverted.
This PR was reopened (likely due to being reverted), so your approval was removed. Please request another review.
@huydhn @wdvr

FYI @liangel-02

I'm not sure who the local expert on ROCm from Meta is that could help take a look. @malfet do you have a POC in mind that I can reach out to?
Sorry, no I do not. Is there versioning in CK one can use to guard this newer API from an older one?
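One common pattern for the guard malfet asks about is a compile-time version check. This is a sketch under an assumption: `CK_TILE_VERSION_MAJOR` is a hypothetical macro, and whether CK actually exposes anything like it is exactly the open question; if it does not, a PyTorch-side define set by the build system from the pinned submodule commit would serve the same purpose:

```cpp
#include <cassert>

// Hypothetical version macro; CK may not provide this. An undefined
// identifier in an #if expression evaluates to 0, so the guard safely
// falls back to the old API when the macro is absent.
#if defined(CK_TILE_VERSION_MAJOR) && (CK_TILE_VERSION_MAJOR >= 7)
  #define HAS_CK_DESCALE_FIELDS 1
#else
  #define HAS_CK_DESCALE_FIELDS 0
#endif

constexpr bool has_descale_fields() {
    // Call sites could branch on this to initialize the newer
    // fmha_fwd_args fields only when the pinned CK provides them.
    return HAS_CK_DESCALE_FIELDS == 1;
}
```

The drawback, visible in this PR's revert, is that without such a macro every consumer pinned to an older CK breaks the moment the argument struct grows.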
Summary

Enables variable-length (varlen) attention support for the Composable Kernel (CK) SDPA backend on ROCm.

Changes

Forward pass (`mha_varlen_fwd_ck.hip`):
- LSE changed from 3D `{batch_size, num_heads, max_seqlen_q}` to 2D `{num_heads, total_q}` to match the CK group-mode expectation
- `nhead_stride_lse`: uses `stride(0)` instead of `stride(1)` for the 2D layout
- `batch_stride_lse`: set to `0` (no batch dimension in group-mode LSE)
- `min_seqlen_q`: changed from `-1` to `1` (a valid minimum for kernel dispatch)

Backward pass (`mha_varlen_bwd_ck.hip`):
- `if (is_dropout)` guard added to avoid a dtype mismatch when dropout is disabled

Test infrastructure (`test/test_varlen_attention.py`):
- added `sdpa_backend` parametrization to test both `aotriton` and `ck` backends on ROCm
- `["aotriton", "ck"]` when CK is available, `["aotriton"]` otherwise
- uses `preferred_rocm_fa_library()` to switch backends per test

Platform detection:
- added `PLATFORM_SUPPORTS_CK_SDPA` in `torch/testing/_internal/common_cuda.py`
- added `_is_ck_sdpa_available()` in `torch/csrc/Module.cpp` for a runtime CK availability check

Fake tensor implementation (`torch/nn/attention/varlen.py`):
- updated `_varlen_attn_fake` to use the standard `[num_heads, total_q]` logsumexp format (matches aotriton/CUDA)

Dependencies: CK submodule bump
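The packed-LSE bookkeeping above can be sketched in plain Python (no torch; the function name is illustrative): the last entry of the cumulative sequence-length array gives `total_q`, and a row-major `[num_heads, total_q]` buffer yields the strides the forward pass now hands to CK.

```python
def varlen_lse_layout(num_heads, cu_seqlens_q):
    # cu_seqlens_q is the cumulative prefix of per-sequence lengths, so
    # its last entry is total_q across the packed batch.
    total_q = cu_seqlens_q[-1]
    shape = (num_heads, total_q)
    # Row-major 2D layout: the head stride is stride(0) = total_q, and
    # there is no batch dimension in group mode, hence batch stride 0.
    nhead_stride_lse = total_q
    batch_stride_lse = 0
    return shape, nhead_stride_lse, batch_stride_lse
```

For example, a packed batch of two sequences of lengths 3 and 4 with 4 heads gives an LSE of shape `(4, 7)`, head stride 7, and batch stride 0.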
cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @jataylo @hongxiayang @naromero77amd @pragupta @jerrymannil @xinyazhang