
Update SDPA flash attention API#4065

Merged
Priya2698 merged 11 commits into main from pm/sdpa_update_API
Mar 17, 2025

Conversation

@Priya2698 (Collaborator) commented Mar 12, 2025

pytorch/pytorch#146372 changes the flash attention API.

// [Note] BC breaking change to flash seed/offset
// Previously: Used separate tensors for philox_seed and philox_offset, sometimes on CPU, sometimes on CUDA
// FlashAttention change: Now uses a single uint64_t[2] tensor on device containing both seed and offset
// Implementation: Renamed "seed" → "rng_state" (contains both seed+offset) and "offset" → "_unused"

In the nvfuser API, SdpaFwdOp now returns an rng_state and SdpaBwdOp expects an rng_state; _unused is ignored.

In the nvfuser API, depending on the torch version, philox_seed is now uint64_t[2] and philox_offset is an empty uint64_t tensor on device. I chose to keep the same semantics as PyTorch since it avoids packing/unpacking PyTorch outputs during testing, which would be needed if I split PyTorch's rng_state output into philox_seed and philox_offset. The latter approach would keep the dimensions of both tensors the same across versions, but since the device and dtype change anyway, the tensors would still need to be created distinctly for the two cases. Keeping PyTorch's semantics also makes it easier to switch to the single-parameter version in the future if desired.
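The single-buffer layout described above can be sketched in plain Python (helper names here are hypothetical, for illustration only): the new convention stores both the Philox seed and offset in one uint64_t[2] buffer, and the old two-tensor form corresponds to unpacking it.

```python
import struct

# Hypothetical helpers illustrating the new layout: both the Philox seed
# and offset live in a single uint64_t[2] buffer (modeled here as raw bytes).
def pack_rng_state(seed: int, offset: int) -> bytes:
    # Two little-endian unsigned 64-bit integers, matching uint64_t[2].
    return struct.pack("<2Q", seed, offset)

def unpack_rng_state(rng_state: bytes) -> tuple[int, int]:
    # Splitting rng_state back into the old (philox_seed, philox_offset) pair.
    seed, offset = struct.unpack("<2Q", rng_state)
    return seed, offset
```

This is the packing/unpacking step the chosen design avoids in tests: by mirroring PyTorch's semantics, nvfuser can pass PyTorch's rng_state output through unchanged.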

@github-actions bot commented Mar 12, 2025

Review updated until commit 0a126c0

Description

  • Updated SDPA flash attention API to handle rng_state changes.

  • Introduced createSdpaRngTvs and createSdpaRngTensors for rng state management.

  • Adjusted TensorView and Tensor definitions based on torch version.

  • Updated tests to accommodate new rng state handling.


Changes walkthrough 📝

Relevant files
Enhancement
14 files
transformer.cpp
Updated SDPA rng state creation                                                   
+1/-2     
logical_domain_map.cpp
Added checks for philox_seed in domain mapping                     
+7/-1     
composite.cpp
Updated SDPA forward and backward operations for rng_state
+15/-6   
fusion_definition.cpp
Added support for DataType::UInt64                                             
+2/-0     
python_bindings.cpp
Added DataType::UInt64 to Python bindings                               
+6/-1     
multidevice_transformer.cpp
Updated SDPA rng state creation in tests                                 
+1/-2     
test_allocation_order_inference.cpp
Updated SDPA rng state creation in tests                                 
+1/-2     
test_multidevice_transformer.cpp
Updated SDPA rng state creation in tests                                 
+1/-2     
test_sdpa_node.cpp
Updated SDPA mapping checks to exclude rng_state                 
+35/-37 
utils.cpp
Added functions for creating SDPA rng states                         
+35/-0   
test_multidevice.py
Updated SDPA rng state handling in Python tests                   
+8/-7     
test_sdpa.py
Updated SDPA rng state handling in Python tests                   
+2/-13   
utils.py
Added functions for defining and creating SDPA rng states
+31/-1   
utils.h
Added declarations for SDPA rng state functions                   
+6/-0     
Documentation
1 file
internal_nodes.h
Updated comments to reflect rng_state changes                       
+15/-5   
Configuration changes
1 file
version.txt
Bumped version number                                                                       
+1/-1     

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Possible Issue

The code checks the torch version to determine the shape and dtype of philox_seed and philox_offset. However, it does not handle the case where the torch version is exactly 2.7.0. This could lead to unexpected behavior if the API changes in a way that is not backward compatible with the version check.

  TensorView* philox_seed = nullptr;
  TensorView* philox_offset = nullptr;
#if NVF_TORCH_VERSION_NO_LESS(2, 7, 0)
  // API changes in torch 2.7.0
  // The torch API returns philox_seed -> rng_state (uint64_t[2])
  // and philox_offset -> _unused (empty tensor)
  philox_seed = TensorViewBuilder()
                    .shape(std::vector<int64_t>{2})
                    .dtype(DataType::UInt64)
                    .build();
  philox_offset = TensorViewBuilder().dtype(DataType::UInt64).build();
#else
  // Scalar tensors of int64_t dtype.
  philox_seed = TensorViewBuilder().dtype(DataType::Int).build();
  philox_offset = TensorViewBuilder().dtype(DataType::Int).build();
  philox_seed->setCpuScalar(true);
  philox_offset->setCpuScalar(true);
#endif
Possible Issue

The function create_sdpa_rng_tensors uses torch.testing.make_tensor to generate random tensors for philox_seed and philox_offset. However, the shape and dtype of these tensors are determined by the torch version. If the torch version is 2.7.0, the shape of philox_seed is set to (2,), but the shape of philox_offset is set to (). This could lead to inconsistencies if the torch API changes in a way that is not backward compatible with the version check.

ref_philox_seed, ref_philox_offset = create_sdpa_rng_tensors()
_assert_shape_dtype(sdpa_seed, ref_philox_seed.shape, ref_philox_seed.dtype)
_assert_shape_dtype(sdpa_offset, ref_philox_offset.shape, ref_philox_offset.dtype)
Possible Issue

The function define_sdpa_rng_state uses the UPDATED_SDPA flag to determine the shape and dtype of philox_seed and philox_offset. However, it does not handle the case where the torch version is exactly 2.7.0. This could lead to unexpected behavior if the API changes in a way that is not backward compatible with the version check.

def define_sdpa_rng_state(fd: FusionDefinition) -> tuple[Tensor, Tensor]:
    dtype = DataType.UInt64 if UPDATED_SDPA else DataType.Int
    is_cpu = not UPDATED_SDPA
    philox_shape = [2] if UPDATED_SDPA else []
    philox_seed = fd.define_tensor(
        shape=philox_shape,
        dtype=dtype,
        is_cpu=is_cpu,
    )
    philox_offset = fd.define_tensor(
        shape=[],
        dtype=dtype,
        is_cpu=is_cpu,
    )
    return philox_seed, philox_offset

@Priya2698 (Collaborator, Author):

!test

@wujingyue (Collaborator) left a comment:


LGTM -- thanks for fixing the breakage!

As discussed offline, there's the alternative to

  1. change Sdpa*Op::evaluate without changing the fusion IR
  2. change the two TensorViews' (seed and offset) dtype from int64 to uint64
  3. change their devices from CPU to GPU.

That would make this PR much smaller and also minimize the impact on Thunder.

@Priya2698 (Collaborator, Author):

Another way is to keep philox_seed and philox_offset separate, but they will still need to be on device and of dtype uint64. This will still require changes in tests when creating these as TensorViews. Additional changes will be required when using outputs from ATen to test nvfuser (unpacking the rng_state for SDPA backward). Some changes will be avoided since the number of arguments stays the same.

@jacobhinkle (Collaborator) left a comment:


IIUC this PR changes us fully from the two-parameter version to the one-parameter version. Will this still work with older PyTorch or does this mean we are bumping our required torch version?

@Priya2698 (Collaborator, Author) commented Mar 13, 2025

The latest stable release, 2.6.0, was released in January 2025 and does not have this change. Thunder requires nvfuser to support both the nightly and the latest stable release.
The only way I see is to use NVF_TORCH_VERSION directives across tests.

For supporting both versions, it would also be simpler to keep philox_seed and philox_offset separate but with the changed device and dtype (uint64_t instead of int64_t). This change will be required everywhere we create seed/offset variables or use the torch API to generate reference inputs/outputs (it can probably be a helper function).

Is there any other approach that we could try for supporting two API variants?
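The helper-function idea could look roughly like this (a sketch with hypothetical names, not the PR's actual code): one function owns the per-version decision about shape, dtype, and placement, so every test site constructs seed/offset tensors consistently.

```python
# Hypothetical helper: return (shape, dtype, is_cpu) specs for the
# philox_seed and philox_offset tensors under each API variant.
def sdpa_rng_tensor_specs(updated_sdpa: bool):
    if updated_sdpa:
        # torch >= 2.7: uint64[2] rng_state on device, plus an empty
        # on-device uint64 tensor for the unused second output.
        return ((2,), "uint64", False), ((), "uint64", False)
    # Older torch: scalar int64 CPU tensors for seed and offset.
    return ((), "int64", True), ((), "int64", True)
```

Each test would then build its TensorViews or reference tensors from these specs instead of repeating the version check inline.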

@Priya2698 Priya2698 force-pushed the pm/sdpa_update_API branch from 686ddd9 to 480ae2e Compare March 14, 2025 05:29
@Priya2698 (Collaborator, Author):

What is the best way to determine the appropriate torch versions? The latest stable release is 2.6.0 from January. The current version.txt file lists 2.8.0.

@Priya2698 (Collaborator, Author):

!test

@Priya2698 (Collaborator, Author):

!test

@Priya2698 (Collaborator, Author) commented Mar 14, 2025

Except for Thunder tests (which are expected), the errors seem unrelated.

return inner_fn


UPDATED_SDPA = torch.__version__ > "2.6.0"
@Priya2698 (Collaborator, Author) commented Mar 14, 2025:


This will be updated to str(torch.__version__) > "2.7.0" or to use packaging.version.

The changes to SDPA were made between versions 2.7.0 and 2.8.0 based on version.txt in PyTorch. So I am using 2.7.0.
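As a side note, plain string comparison of version numbers is fragile once a component reaches double digits; packaging.version handles this robustly, and a minimal stdlib sketch of the same idea looks like:

```python
# String comparison of version numbers is lexicographic and therefore
# fragile: "2.10.0" compares as *smaller* than "2.7.0" ("1" < "7").
assert not ("2.10.0" > "2.7.0")

# Comparing numeric release tuples avoids the pitfall. packaging.version
# does this robustly; a minimal stdlib sketch (illustrative only):
def release_tuple(v: str) -> tuple[int, ...]:
    # Keep only the leading numeric dotted components ("2.7.0a0" -> (2, 7)).
    parts = []
    for p in v.split("+")[0].split("."):
        if not p.isdigit():
            break
        parts.append(int(p))
    return tuple(parts)

assert release_tuple("2.10.0") > release_tuple("2.7.0")
```

With packaging installed, the equivalent check would be `version.parse("2.10.0") > version.parse("2.7.0")`.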

@Priya2698 (Collaborator, Author):

!test

@Priya2698 Priya2698 requested a review from wujingyue March 14, 2025 22:25
Priya2698 and others added 6 commits March 14, 2025 16:53
Co-authored-by: Jingyue Wu <wujingyue@gmail.com>
Co-authored-by: Jingyue Wu <wujingyue@gmail.com>
Co-authored-by: Jingyue Wu <wujingyue@gmail.com>
@Priya2698 (Collaborator, Author):

!test

@Priya2698 (Collaborator, Author):

Not sure how to set up a CI pipeline for torch 2.6.0 stable release. I have verified the affected tests locally.

// The torch API returns philox_seed -> rng_state (uint64_t[2])
// and philox_offset -> _unused (empty tensor)
philox_seed = TensorViewBuilder()
.shape(std::vector<int64_t>{2})
A collaborator commented:


nitpick: do we need to be explicit with std::vector<int64_t>{2}? I think we usually just do shape({2}).

@Priya2698 (Collaborator, Author) replied:


Yes. Surprisingly, I was getting an ambiguous construct compile error. I did not dig into it more.

@Priya2698 (Collaborator, Author):

!build

@Priya2698 Priya2698 merged commit edefcad into main Mar 17, 2025
15 of 16 checks passed
@Priya2698 Priya2698 deleted the pm/sdpa_update_API branch March 17, 2025 21:51