
[RFC] Scaled Dot Product Attention API Changes #110681


Updated SDPA API

Authors: @drisspg

Summary

To make it easier for users to manage the complexity of handling various bias formats, we would like to expose the ability to pass AttnBias-derived classes to SDPA and have it dispatch to the most optimal kernel. This mechanism also enables users to extend the behavior of SDPA with their own AttnBias flavors.

Motivation

Managing various bias types and their specializations in SDPA can become intricate for users. By enabling them to pass in custom bias types, we aim to streamline this process. Furthermore, this change would empower features like the score_mod API. Such features have shown considerable performance improvements when used in conjunction with torch.compile and SDPA, especially for specific mask types.

This would also aid in enabling PyTorch to upgrade its FlashAttention kernel to the newest implementation. There are currently some limitations regarding BC concerns; see #108108.
User motivation for this change can be found in #110144.

Proposed Implementation:

The foundational changes to the C++ core are outlined in this PR: #110399.
A prototype implementation can be found in this GitHub repository: drisspg/transformer_nuggets#5. This prototype would be incorporated into the existing top-level nn.functional version of SDPA.

Currently, the interface for invoking the scaled_dot_product_attention looks like:

scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> Tensor
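For reference, here is a minimal, torch-free sketch of what this signature computes for 2-D inputs: softmax(Q @ K^T * scale + attn_mask) @ V. The helper name sdpa_reference is illustrative, and dropout/is_causal are elided.

```python
import math

def sdpa_reference(query, key, value, attn_mask=None, scale=None):
    """Torch-free sketch of softmax(Q @ K^T * scale + attn_mask) @ V
    for 2-D inputs given as lists of rows (dropout and is_causal elided)."""
    head_dim = len(query[0])
    scale = scale if scale is not None else 1.0 / math.sqrt(head_dim)
    output = []
    for i, q in enumerate(query):
        # Scaled dot product of this query row against every key row.
        scores = [scale * sum(qe * ke for qe, ke in zip(q, k)) for k in key]
        if attn_mask is not None:  # additive bias; -inf fully masks a position
            scores = [s + m for s, m in zip(scores, attn_mask[i])]
        # Numerically stable softmax over the key dimension.
        peak = max(scores)
        exps = [math.exp(s - peak) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Attention-weighted sum of value rows.
        output.append([
            sum(w * v[j] for w, v in zip(weights, value))
            for j in range(len(value[0]))
        ])
    return output
```

An additive mask of -inf above the diagonal reproduces is_causal=True, which is exactly the kind of structure the bias classes below aim to capture without materializing a tensor.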

These changes would update the signature of SDPA so that, instead of accepting an attn_mask: Tensor, it accepts a class derived from:

from abc import ABC, abstractmethod
from typing import Optional

import torch


class AttnBias(ABC):
    """Abstract base class for attention biases"""

    @abstractmethod
    def materialize(self, device: Optional[torch.device] = None) -> torch.Tensor:
        raise NotImplementedError("This is an abstract base class")

    @abstractmethod
    def needs_materialization(self) -> bool:
        raise NotImplementedError("This is an abstract base class")

    @staticmethod
    @abstractmethod
    def dispatch(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_bias: "AttnBias",
        causal: bool,
        scale: Optional[float],
        dropout_p: float,
    ) -> torch.Tensor:
        raise NotImplementedError("This is an abstract base class")
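As a torch-free sketch, a hypothetical non-materializing subclass under this interface might look as follows. LowerRightCausalBias and the "fused kernel" stand-in are illustrative, not part of the proposal, and the Tensor/device types are elided.

```python
from abc import ABC, abstractmethod

class AttnBias(ABC):
    """Mirror of the proposed interface, with Tensor/device types elided."""
    @abstractmethod
    def materialize(self, device=None):
        ...

    @abstractmethod
    def needs_materialization(self) -> bool:
        ...

    @staticmethod
    @abstractmethod
    def dispatch(query, key, value, attn_bias, causal, scale, dropout_p):
        ...

class LowerRightCausalBias(AttnBias):
    """Hypothetical causal variant aligned to the bottom-right corner.
    It is never materialized; dispatch routes straight to a fused kernel."""
    def materialize(self, device=None):
        raise RuntimeError("handled entirely in dispatch")

    def needs_materialization(self) -> bool:
        return False

    @staticmethod
    def dispatch(query, key, value, attn_bias, causal, scale, dropout_p):
        # Stand-in for invoking a fused kernel with causal semantics.
        return ("fused_causal_kernel", scale, dropout_p)
```

Because needs_materialization() returns False, SDPA would never build an O(seq^2) mask tensor for this bias; the structure is communicated to the kernel instead.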

To ensure that existing functionality remains uninterrupted, we would also introduce popular derived classes, such as:

class TensorBias(AttnBias):
    """A bias that is a tensor"""

    def __init__(self, bias: torch.Tensor):
        self.bias = bias

    def materialize(self, device: Optional[torch.device] = None) -> torch.Tensor:
        return self.bias

    def needs_materialization(self) -> bool:
        return True

    @staticmethod
    def dispatch(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_bias: "TensorBias",
        causal: bool,
        scale: Optional[float],
        dropout_p: float,
    ) -> torch.Tensor:
        raise NotImplementedError(
            "TensorBias requires materialization, so this should never be called!"
        )

    def __repr__(self) -> str:
        return f"TensorBias(bias={self.bias})"

CausalBias Variants

An example that would utilize all features of the AttnBias class can be found here.

Changes to torch.nn.functional.scaled_dot_product_attention()

A quick summary along with implementation can be found below:

  • If the passed attn_mask is a tensor, the old method of computation is followed, along with a warning.
  • If no attn_mask is provided, the function computes without any mask.
  • If an attn_mask derived from AttnBias is provided and requires materialization (like TensorBias), the function materializes it and then computes.
  • If an attn_mask derived from AttnBias doesn't need materialization, it uses the dispatch method of that mask class to handle the computation.

from warnings import warn

from torch.nn.functional import scaled_dot_product_attention


def sdpa_prototype(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[AttnBias] = None,
    causal: bool = False,
    scale: Optional[float] = None,
    dropout_p: float = 0.0,
):
    assert attn_mask is None or isinstance(attn_mask, (AttnBias, torch.Tensor))

    if isinstance(attn_mask, torch.Tensor):
        warn("Passing a tensor as an attn_mask is deprecated. Please use TensorBias instead.")
        return scaled_dot_product_attention(
            query, key, value, attn_mask, dropout_p, causal, scale=scale
        )

    if attn_mask is None or attn_mask.needs_materialization():
        return scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attn_mask.materialize(query.device) if attn_mask else None,
            dropout_p=dropout_p,
            is_causal=causal,
            scale=scale,
        )

    # After this point, all AttnBias subclasses are required to have defined their own dispatch logic
    return attn_mask.dispatch(query, key, value, attn_mask, causal, scale, dropout_p)
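To make the routing above concrete, here is a torch-free sketch that records which branch each mask type takes. TensorBias mirrors the class defined earlier (minus the torch.Tensor type), FusedBias is a hypothetical non-materializing bias, and the returned strings stand in for real kernel calls.

```python
class TensorBias:
    """Mirror of the TensorBias above, minus the torch.Tensor type."""
    def __init__(self, bias):
        self.bias = bias

    def needs_materialization(self):
        return True

    def materialize(self, device=None):
        return self.bias

class FusedBias:
    """Hypothetical bias that owns its own dispatch logic."""
    def needs_materialization(self):
        return False

    @staticmethod
    def dispatch(query, key, value, attn_bias, causal, scale, dropout_p):
        return "custom dispatch"

def route(attn_mask):
    # Mirrors sdpa_prototype's branch order (raw-tensor branch elided).
    if attn_mask is None:
        return "fallback kernel, no mask"
    if attn_mask.needs_materialization():
        attn_mask.materialize()  # would be passed to the existing kernel
        return "fallback kernel, materialized mask"
    return attn_mask.dispatch(None, None, None, attn_mask, False, None, 0.0)
```

The key property is that only the last branch leaves the existing code path, so biases that do not opt in to custom dispatch behave exactly as today.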

Metrics

  • Performance: Measure the performance improvement for popular additive biases like ALiBi.
  • Kernel Dispatch Overhead: Quantify the overhead introduced due to the new dispatch mechanism. Ideally, this overhead should be negligible.
  • API Adoption Rate: Track how many developers switch to using the new API over a specified time frame. A higher adoption rate signifies a positive reception.
  • User Feedback: Collect qualitative feedback from developers. This can provide insights into the usability and potential areas of improvement.

Drawbacks

Please consider:

  • This is not immediately user-breaking, but we would begin a deprecation cycle for passing in regular tensors.
  • I think this would meaningfully improve the UX, but it would also introduce some complexity.
  • This would add a non-trivial amount of new code to an existing, HIGHLY exercised code path, and with that could introduce new bugs. However, part of these changes enables much more of the logic to live entirely in Python, which users find much more understandable and easier to debug than C++.

Alternatives

  • Introducing Specialized Functions:

    • One approach is to introduce specialized functions similar to scaled_dot_product_attention. An example could be sdpa_causal_lower_right().
    • However, this could lead to challenges in maintainability and potentially complicate the user experience.
  • Expanding Existing Function with More Arguments:

    • We could enhance the primary SDPA function by adding more arguments to accommodate the desired features.
    • This could result in a bloated function signature reminiscent of nn.mha, which could overwhelm users with a plethora of keyword arguments.
  • Tensor Subclasses

    • This is somewhat analogous to tensor_subclass. We could define __torch_function__ for the individual classes in question and overload the logic for calling into SDPA. Part of me feels this is more machinery than needed, and I have concerns about how it would be handled by torch.compile. That being said, I do think this is the most attractive alternative to the proposed implementation.
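For comparison, the __torch_function__ alternative would intercept the call at the Python level rather than changing SDPA's signature. Below is a torch-free sketch of the interception pattern: sdpa is a stub standing in for torch.nn.functional.scaled_dot_product_attention, and CausalMask is hypothetical; only the protocol shape matches torch's.

```python
def sdpa(query, key, value, attn_mask=None):
    # Stub for scaled_dot_product_attention: like torch, it first checks
    # the arguments for a __torch_function__ override and delegates to it.
    for arg in (query, key, value, attn_mask):
        if hasattr(type(arg), "__torch_function__"):
            return type(arg).__torch_function__(
                sdpa, (type(arg),), (query, key, value), {"attn_mask": attn_mask}
            )
    return "default kernel"

class CausalMask:
    """Hypothetical mask type that overloads SDPA via the protocol."""
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if func is sdpa:
            return "causal kernel"  # stand-in for a specialized kernel call
        return NotImplemented
```

The upside is that no signature change is needed; the downside, as noted above, is the extra machinery and the open torch.compile questions.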

Prior Art

This implementation is heavily borrowed from and inspired by Xformers, which has a notion of an AttentionBias. I think that the community's reception of this technique is generally positive, and users enjoy this abstraction.

How we teach this

  • What names and terminology work best for these concepts and why? How is this idea best presented?
  • Would the acceptance of this proposal mean the PyTorch documentation must be re-organized or altered? Yes
  • How should this feature be taught to existing PyTorch users? I think the hardest thing to teach would be the lambda mask but I think that its design/implementation can be independent of this proposal and would likely be a feature on its own that this API change would enable.

Unresolved Questions

  • Design Clarifications Through the RFC Process:

    • One area of potential improvement is the coexistence of both the causal flag and the AttnBias. This might inadvertently reveal too much about the kernel's implementation to users.
    • While there is an inclination towards retaining this due to its potential utility with the score mod API, it warrants further discussion and consensus.
  • Design Resolutions Before Stabilization:

    • Feedback is sought on the envisioning of user-defined biases. Specifically, how they might be structured and if the proposed API is versatile enough to encompass them.

Addendum

A related issue to AttnBias variant can be found here: #110702

Which tensor subclass?

When attempting to pass a dispatch tensor subclass as an attn_bias to the torch.scaled_dot_product_attention() function, we would likely encounter a compilation issue (prior to Brian's work). Besides this, there is also the question of how AttnBias subclasses would override SDPA's behavior. The primary reason is the composite nature of torch.scaled_dot_product_attention(): when calling the CompositeImplicit torch.nn.scaled_dot_product_attention(), the function will internally call one of sdpa_math, sdpa_mem_eff, or sdpa_flash. So which op would the subclass match on?

A solution to the above problem would be to detect the subclass and route to a dummy aten op that could be overloaded. That being said, I think this adds complexity and overhead to a performance-critical path, and I am not sure it provides more value than the main implementation listed above.

__torch_function__:
Currently this would likely not work with torch.compile; however, there is active work to change this.
I wrote up a tensor_subclass version here. It is essentially the same as the one listed above (it uses the same machinery) but does not update SDPA, instead using __torch_function__ to override dispatch.

Indeed, these are not torch.compile-able today.

cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki @machellazos @cpuhrsch
