
[ROCm][CK] Enable variable-length attention support for CK SDPA backend#172246

Open
chinmaydk99 wants to merge 3 commits into pytorch:main from chinmaydk99:ck-varlen-attn

Conversation

@chinmaydk99
Contributor

@chinmaydk99 chinmaydk99 commented Jan 12, 2026

Summary

Enables variable-length (varlen) attention support for the Composable Kernel (CK) SDPA backend on ROCm.

Changes

Forward pass (mha_varlen_fwd_ck.hip)

  • Fixed LSE tensor allocation: changed from 3D {batch_size, num_heads, max_seqlen_q} to 2D {num_heads, total_q} to match the CK group-mode expectation
  • Fixed nhead_stride_lse: use stride(0) instead of stride(1) for the 2D layout
  • Fixed batch_stride_lse: set to 0 (the group-mode LSE has no batch dimension)
  • Fixed min_seqlen_q: changed from -1 to 1 (a valid minimum for kernel dispatch)

Backward pass (mha_varlen_bwd_ck.hip)

  • Fixed philox seed/offset access: guarded with if (is_dropout) to avoid dtype mismatch when dropout is disabled

Test infrastructure (test/test_varlen_attention.py)

  • Added sdpa_backend parametrization to test both aotriton and ck backends on ROCm
  • Backend selection: ["aotriton", "ck"] when CK is available, ["aotriton"] otherwise
  • Uses preferred_rocm_fa_library() to switch backends per test

Platform detection

  • Added PLATFORM_SUPPORTS_CK_SDPA in torch/testing/_internal/common_cuda.py
  • Added _is_ck_sdpa_available() in torch/csrc/Module.cpp for runtime CK availability check

Fake tensor implementation (torch/nn/attention/varlen.py)

  • Updated _varlen_attn_fake to use standard [num_heads, total_q] logsumexp format (matches aotriton/CUDA)

Dependencies: CK submodule bump

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @jataylo @hongxiayang @naromero77amd @pragupta @jerrymannil @xinyazhang

@pytorch-bot

pytorch-bot bot commented Jan 12, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/172246

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Unrelated Failure

As of commit cf43b84 with merge base cd17970:

NEW FAILURE - The following job has failed:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the module: rocm AMD GPU support for Pytorch label Jan 12, 2026
@chinmaydk99 chinmaydk99 force-pushed the ck-varlen-attn branch 3 times, most recently from 46520fb to 5e2a3d7 Compare January 13, 2026 22:29
@pytorch-bot pytorch-bot bot added the release notes: releng release notes category label Jan 21, 2026
@chinmaydk99 chinmaydk99 force-pushed the ck-varlen-attn branch 2 times, most recently from fef93c7 to 1d8f734 Compare January 21, 2026 03:12
@chinmaydk99 chinmaydk99 force-pushed the ck-varlen-attn branch 3 times, most recently from 4b4689f to 625fef3 Compare February 3, 2026 23:07
TEST_WITH_ROCM, "varlen attention w/ sliding window not supported on ROCm"
)
@parametrize("dtype", [torch.bfloat16, torch.float16])
def test_custom_op_compliance(self, device, dtype):
Contributor

I notice this test doesn't have the "sdpa_backend" parametrization like the other ones. Is there a reason for that?

Contributor Author

@chinmaydk99 chinmaydk99 Feb 4, 2026


Had temporarily removed the sdpa backend parametrization since we were awaiting the AOTriton PR to merge. Will add it here as well as in the next test.

@chinmaydk99 chinmaydk99 force-pushed the ck-varlen-attn branch 2 times, most recently from 6054e49 to ea61d0e Compare February 10, 2026 18:03
@jeffdaily jeffdaily added the ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 label Feb 17, 2026
@pytorch-bot pytorch-bot bot removed the ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 label Feb 22, 2026
@jithunnair-amd jithunnair-amd added ciflow/rocm-mi355 Trigger "default" config CI on ROCm MI355 runners ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 ciflow/rocm-mi200 Trigger "default" config CI on ROCm MI200 labels Feb 22, 2026
@pytorch-bot pytorch-bot bot removed ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 ciflow/rocm-mi355 Trigger "default" config CI on ROCm MI355 runners ciflow/rocm-mi200 Trigger "default" config CI on ROCm MI200 labels Feb 22, 2026
@jithunnair-amd jithunnair-amd added ciflow/rocm-mi355 Trigger "default" config CI on ROCm MI355 runners ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 ciflow/rocm-mi200 Trigger "default" config CI on ROCm MI200 labels Feb 22, 2026
wdvr added a commit to wdvr/pytorch that referenced this pull request Mar 9, 2026
…sion

The internal fbsource build uses a pinned CK version that lacks these struct fields, which were added for a newer CK in pytorch#172246:
- block_scale_seqstart_{q,k}_ptr
- nhead_stride_{q,k,v}_descale
- batch_stride_{q,k,v}_descale
- block_scale_size_{q,kv}

These extra fields cause type mismatch build errors:
- nullptr for ck_tile::index_t (int) fields
- int-to-float and float-to-int narrowing from shifted field positions

Remove the 10 extra fields from fmha_fwd_args initialization in both
mha_fwd_ck.hip and mha_varlen_fwd_ck.hip.

Forward fix for pytorch#172246.
@wdvr
Contributor

wdvr commented Mar 9, 2026

@pytorchbot revert -m "sorry, failing builds with older HIP versions, see below" -c ghirst

Stderr:
fbcode/caffe2/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_fwd_ck.hip:116:26: error: cannot initialize a member subobject of type 'ck_tile::index_t' (aka 'int') with an rvalue of type 'std::nullptr_t'
  116 |                          nullptr, // block_scale_seqstart_k_ptr
      |                          ^~~~~~~
fbcode/caffe2/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_fwd_ck.hip:117:26: error: cannot initialize a member subobject of type 'ck_tile::index_t' (aka 'int') with an rvalue of type 'std::nullptr_t'
  117 |                          nullptr, // sink_ptr
      |                          ^~~~~~~
fbcode/caffe2/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_fwd_ck.hip:124:26: error: non-constant-expression cannot be narrowed from type 'int' to 'float' in initializer list [-Wc++11-narrowing]
  124 |                          h,                                 // nhead
      |                          ^
fbcode/caffe2/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_fwd_ck.hip:124:26: note: insert an explicit cast to silence this issue
  124 |                          h,                                 // nhead
      |                          ^
      |                          static_cast<float>( )
fbcode/caffe2/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_fwd_ck.hip:125:26: error: non-constant-expression cannot be narrowed from type 'int' to 'float' in initializer list [-Wc++11-narrowing]
  125 |                          h_k,                               // nhead_k
      |                          ^~~
fbcode/caffe2/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_fwd_ck.hip:125:26: note: insert an explicit cast to silence this issue
  125 |                          h_k,                               // nhead_k
      |                          ^~~
      |                          static_cast<float>( )
fbcode/caffe2/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_fwd_ck.hip:126:26: error: type 'float' cannot be narrowed to 'ck_tile::index_t' (aka 'int') in initializer list [-Wc++11-narrowing]
  126 |                          softmax_scale,                     // scale_s
      |                          ^~~~~~~~~~~~~
fbcode/caffe2/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_fwd_ck.hip:126:26: note: insert an explicit cast to silence this issue
  126 |                          softmax_scale,                     // scale_s
      |                          ^~~~~~~~~~~~~
      |                          static_cast<index_t>( )
fbcode/caffe2/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_fwd_ck.hip:127:26: error: type 'float' cannot be narrowed to 'ck_tile::index_t' (aka 'int') in initializer list [-Wc++11-narrowing]
  127 |                          0.0f,                              // logits_soft_cap
      |                          ^~~~
fbcode/caffe2/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_fwd_ck.hip:127:26: note: insert an explicit cast to silence this issue
  127 |                          0.0f,                              // logits_soft_cap
      |                          ^~~~
      |                          static_cast<index_t>( )
fbcode/caffe2/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_fwd_ck.hip:153:26: error: no viable conversion from 'int' to 'std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void *, const void *>>' (aka 'variant<pair<unsigned long, unsigned long>, pair<const void *, const void *>>')
  153 |                          0, // batch_stride_v_descale
      |                          ^
fbcode/third-party-buck/platform010/build/libgcc/include/c++/trunk/variant:1386:7: note: candidate constructor not viable: no known conversion from 'int' to 'const variant<pair<unsigned long, unsigned long>, pair<const void *, const void *>> &' for 1st argument
 1386 |       variant(const variant& __rhs) = default;
      |       ^       ~~~~~~~~~~~~~~~~~~~~
fbcode/third-party-buck/platform010/build/libgcc/include/c++/trunk/variant:1387:7: note: candidate constructor not viable: no known conversion from 'int' to 'variant<pair<unsigned long, unsigned long>, pair<const void *, const void *>> &&' for 1st argument
 1387 |       variant(variant&&) = default;
      |       ^       ~~~~~~~~~
fbcode/third-party-buck/platform010/build/libgcc/include/c++/trunk/variant:1399:2: note: candidate template ignored: requirement '18446744073709551615UL < sizeof...(_Types)' was not satisfied [with _Tp = int, $1 = enable_if_t<sizeof...(_Types) != 0>, $2 = enable_if_t<__not_in_place_tag<int>>]
 1399 |         variant(_Tp&& __t)
      |         ^
fbcode/third-party-buck/platform010/build/libgcc/include/c++/trunk/variant:1409:2: note: explicit constructor is not a candidate
 1409 |         variant(in_place_type_t<_Tp>, _Args&&... __args)
      |         ^
fbcode/third-party-buck/platform010/build/libgcc/include/c++/trunk/variant:1429:2: note: explicit constructor is not a candidate
 1429 |         variant(in_place_index_t<_Np>, _Args&&... __args)
      |         ^
7 errors generated when compiling for gfx942.

cc @malfet

@pytorch-bot

pytorch-bot bot commented Mar 9, 2026

❌ 🤖 pytorchbot command failed:

@pytorchbot revert: error: argument -c/--classification: invalid choice: 'ghirst' (choose from 'nosignal', 'ignoredsignal', 'landrace', 'weird', 'ghfirst', 'autorevert')

usage: @pytorchbot revert -m MESSAGE -c
                          {nosignal,ignoredsignal,landrace,weird,ghfirst,autorevert}

Try @pytorchbot --help for more info.

@wdvr
Contributor

wdvr commented Mar 9, 2026

@pytorchbot revert -m "sorry, failing builds with older HIP versions, see below" -c ghfirst

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot
Collaborator

Reverting PR 172246 failed

Reason: Command git -C /home/runner/work/pytorch/pytorch revert --no-edit e1b7388686956c7a7ec5b04e8b393aad67d3e8fb returned non-zero exit code 1

Auto-merging test/test_varlen_attention.py
CONFLICT (content): Merge conflict in test/test_varlen_attention.py
Auto-merging torch/_C/__init__.pyi.in
Auto-merging torch/backends/cuda/__init__.py
Auto-merging torch/nn/attention/varlen.py
error: could not revert e1b73886869... [ROCm][CK] Enable variable-length attention support for CK SDPA backend (#172246)
hint: After resolving the conflicts, mark them with
hint: "git add/rm <pathspec>", then run
hint: "git revert --continue".
hint: You can instead skip this commit with "git revert --skip".
hint: To abort and get back to the state before "git revert",
hint: run "git revert --abort".
hint: Disable this message with "git config set advice.mergeConflict false"
Details for Dev Infra team (raised by workflow job)

@huydhn
Contributor

huydhn commented Mar 10, 2026

@pytorchbot revert -m 'sorry, failing builds with older HIP versions, see below' -c ghfirst

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request Mar 10, 2026
…PA backend (#172246)"

This reverts commit e1b7388.

Reverted #172246 on behalf of https://github.com/huydhn due to sorry, failing builds with older HIP versions, see below ([comment](#172246 (comment)))
@pytorchmergebot
Collaborator

@chinmaydk99 your PR has been successfully reverted.

@pytorchmergebot pytorchmergebot added Reverted ci-no-td Do not run TD on this PR labels Mar 10, 2026
@pytorch-bot pytorch-bot bot dismissed jeffdaily’s stale review March 10, 2026 19:13

This PR was reopened (likely due to being reverted), so your approval was removed. Please request another review.

@alugorey
Contributor

@huydhn @wdvr
These failures are due to a mismatch in CK versions, which we are updating in this PR. The version of CK we were using in 6.4 differs greatly from the latest version we are attempting to update to; the API fundamentally changes. If internal 6.4 builds are pinned to an old CK, and that is running in your CI and gating this from being merged, I'm not sure how we are ever supposed to update CK in the latest version of PyTorch. Please advise on what we should do here.

@alugorey
Contributor

> @huydhn @wdvr These failures are due to a mismatch in CK versions which we are updating in this PR. The version of CK we were using in 6.4 differs greatly against the latest and greatest version we are attempting to update to. The api fundamentally changes. If internal 6.4 builds are pinned at an old CK, and that is running in your CI and gating this from being merged, I'm not sure how we are ever supposed to be able to update CK in the latest version of pytorch. please advise on what we should do here.

FYI @liangel-02

@huydhn
Contributor

huydhn commented Mar 10, 2026

I'm not sure who the local ROCm expert from Meta is who could help take a look. @malfet do you have a POC in mind that I can reach out to?

@malfet
Contributor

malfet commented Mar 18, 2026

> I'm not sure who is the local expert on ROCm from Meta that could help take a look. @malfet Do you have a poc in mind that I can reach out?

Sorry, no I do not. Is there a versioning in CK one can use to guard this newer API from an old one?


Labels

ci-no-td, ciflow/inductor-rocm-mi300, ciflow/rocm-mi300, ciflow/trunk, Merged, module: rocm, open source, release notes: releng, Reverted


9 participants