Updated SDPA API
Authors:
Summary
In order for users to more easily manage the complexity of handling various bias formats, we would like to expose the ability to pass AttnBias-derived classes to SDPA and have it dispatch to the most optimal kernel. This mechanism also enables users to extend the behavior of SDPA with their own AttnBias flavors.
Motivation
Managing various bias types and their specializations in SDPA can become intricate for users. By enabling them to pass in custom bias types, we aim to streamline this process. Furthermore, this change would empower features like the score_mod API. Such features have shown considerable performance improvements when used in conjunction with torch.compile and SDPA, especially for specific mask types.
This would also aid in enabling PyTorch to upgrade its FlashAttention kernel to the newest implementation. There are currently some limitations regarding BC concerns; see #108108.
User motivation for this change can be found in #110144.
Proposed Implementation:
The foundational changes to the C++ core are outlined in this PR: #110399.
A prototype implementation can be found in this GitHub repository: drisspg/transformer_nuggets#5. This prototype implementation would be incorporated into the existing top-level nn.functional version of SDPA.
Currently, the interface for invoking the scaled_dot_product_attention looks like:
```
scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> Tensor
```
These changes would update the signature of SDPA so that, instead of accepting an attn_mask: Tensor, it accepts a class derived from:
```python
class AttnBias(ABC):
    """Abstract base class for attention biases"""

    @abstractmethod
    def materialize(self, device: Optional[torch.device] = None) -> torch.Tensor:
        raise NotImplementedError("This is an abstract base class")

    @abstractmethod
    def needs_materialization(self) -> bool:
        raise NotImplementedError("This is an abstract base class")

    @staticmethod
    @abstractmethod
    def dispatch(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_bias: "AttnBias",
        causal: bool,
        scale: Optional[float],
        dropout_p: float,
    ) -> torch.Tensor:
        raise NotImplementedError("This is an abstract base class")
```

To ensure that existing functionality remains uninterrupted, we would also introduce popular derived classes, such as:
```python
class TensorBias(AttnBias):
    """A bias that is a tensor"""

    def __init__(self, bias: torch.Tensor):
        self.bias = bias

    def materialize(self, device: Optional[torch.device] = None) -> torch.Tensor:
        return self.bias

    def needs_materialization(self) -> bool:
        return True

    @staticmethod
    def dispatch(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_bias: "TensorBias",
        causal: bool,
        scale: Optional[float],
        dropout_p: float,
    ) -> torch.Tensor:
        raise NotImplementedError(
            "TensorBias requires materialization, so this should never be called!"
        )

    def __repr__(self) -> str:
        return f"TensorBias(bias={self.bias})"
```
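As a minimal usage sketch (self-contained: TensorBias is redefined here in simplified form so the snippet runs on its own), a TensorBias simply wraps an existing mask tensor, and materializing it routes the call through the existing tensor-mask path of scaled_dot_product_attention:

```python
import torch
import torch.nn.functional as F

# Simplified stand-in for the TensorBias class above, so this snippet
# is self-contained.
class TensorBias:
    def __init__(self, bias: torch.Tensor):
        self.bias = bias

    def materialize(self, device=None):
        return self.bias

    def needs_materialization(self) -> bool:
        return True

q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)
mask = torch.randn(1, 1, 4, 4)  # additive bias, broadcast over batch/heads

bias = TensorBias(mask)
# The prototype sdpa would materialize the bias and fall back to the
# existing tensor-mask path, so these two calls are equivalent.
out_bias = F.scaled_dot_product_attention(q, k, v, attn_mask=bias.materialize(q.device))
out_ref = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
assert torch.allclose(out_bias, out_ref)
```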
CausalBias Variants
An example that would utilize all features of the AttnBias class can be found here.
Changes to torch.nn.functional.scaled_dot_product_attention()
A quick summary along with implementation can be found below:
- If the passed `attn_mask` is a tensor, the old method of computation is followed, along with a warning.
- If no `attn_mask` is provided, the function computes without any mask.
- If an `attn_mask` derived from `AttnBias` is provided and requires materialization (like `TensorBias`), the function materializes it and then computes.
- If an `attn_mask` derived from `AttnBias` doesn't need materialization, it uses the `dispatch` method of that mask class to handle the computation.
```python
def sdpa_prototype(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[AttnBias] = None,
    causal: bool = False,
    scale: Optional[float] = None,
    dropout_p: float = 0.0,
):
    assert attn_mask is None or isinstance(attn_mask, (AttnBias, torch.Tensor))
    if isinstance(attn_mask, torch.Tensor):
        warn("Passing a tensor as an attn_mask is deprecated. Please use TensorBias instead.")
        return scaled_dot_product_attention(query, key, value, attn_mask, dropout_p, causal, scale)
    if attn_mask is None or attn_mask.needs_materialization():
        return scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attn_mask.materialize(query.device) if attn_mask else None,
            dropout_p=dropout_p,
            is_causal=causal,
            scale=scale,
        )
    # After this point all AttnBias are required to have defined their own dispatching logic
    return attn_mask.dispatch(query, key, value, attn_mask, causal, scale, dropout_p)
```

Metrics
- Performance: Measure the performance improvement for popular additive biases like ALiBi.
- Kernel Dispatch Overhead: Quantify the overhead introduced due to the new dispatch mechanism. Ideally, this overhead should be negligible.
- API Adoption Rate: Track how many developers switch to using the new API over a specified time frame. A higher adoption rate signifies a positive reception.
- User Feedback: Collect qualitative feedback from developers. This can provide insights into the usability and potential areas of improvement.
Drawbacks
Please consider:
- This is not immediately user-breaking, but we would enable a deprecation cycle for passing in regular tensors.
- I think this would meaningfully improve the UX, but it would also introduce some complexity.
- This would add a non-trivial amount of new code to an existing, highly exercised code path, and with that could introduce new bugs. However, part of these changes enables much more of the logic to live entirely in Python, which users find much more understandable and easier to debug than C++.
Alternatives
- Introducing Specialized Functions:
  - One approach is to introduce specialized functions similar to `scaled_dot_product_attention`. An example could be `sdpa_causal_lower_right()`.
  - However, this could lead to challenges in maintainability and potentially complicate the user experience.
- Expanding Existing Function with More Arguments:
  - We could enhance the primary SDPA function by adding more arguments to accommodate the desired features.
  - This could result in a bloated function signature reminiscent of `nn.mha`, which could overwhelm users with a plethora of keyword arguments.
- Tensor Subclasses:
  - This is somewhat analogous to tensor subclasses. We could define `__torch_function__` for the individual classes in question and overload the logic for calling into SDPA. Part of me feels like this is more machinery than needed, and I have concerns about the handling with torch.compile. That being said, I do think this is the most attractive alternative to the proposed implementation.
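To make the tensor-subclass alternative more concrete, here is a minimal sketch (hypothetical class name and behavior, not a proposed API) using the `__torch_function__` protocol on a non-Tensor marker object to reroute an SDPA call onto the fused causal path instead of materializing a mask:

```python
import torch
import torch.nn.functional as F

class CausalMarker:
    """Hypothetical non-Tensor object using the __torch_function__ protocol.

    When passed as attn_mask, it intercepts the call to
    F.scaled_dot_product_attention and reroutes it to the fused
    is_causal=True path instead of materializing a mask.
    """

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = dict(kwargs or {})
        if func is F.scaled_dot_product_attention:
            # Keep only query/key/value; drop the marker whether it was
            # passed positionally or as a keyword (extra positional args
            # such as dropout_p are omitted for brevity in this sketch).
            query, key, value = args[:3]
            kwargs.pop("attn_mask", None)
            kwargs["is_causal"] = True
            return func(query, key, value, **kwargs)
        # Let torch raise for any other function we don't handle.
        return NotImplemented
```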
Prior Art
This implementation is heavily borrowed from and inspired by xFormers, which has a notion of an AttentionBias. I think the community's reception of this technique is generally positive, and users enjoy the abstraction.
How we teach this
- What names and terminology work best for these concepts and why? How is this idea best presented?
- I think updating the function documentation and providing a tutorial example of how users can interact with a KV cache using the causal variants would be helpful, similar to what this user was asking for in #110144 ("The masking matrix seems to be wrong in the torch.nn.functional.scaled_dot_product_attention function").
- Would the acceptance of this proposal mean the PyTorch documentation must be re-organized or altered? Yes
- How should this feature be taught to existing PyTorch users? I think the hardest thing to teach would be the lambda mask, but its design/implementation can be independent of this proposal and would likely be a feature of its own that this API change would enable.
Unresolved Questions
- Design Clarifications Through the RFC Process:
  - One area of potential improvement is the coexistence of both the `causal` flag and the `AttnBias`. This might inadvertently reveal too much about the kernel's implementation to users.
  - While there is an inclination towards retaining this due to its potential utility with the `score_mod` API, it warrants further discussion and consensus.
- Design Resolutions Before Stabilization:
  - Feedback is sought on the envisioning of user-defined biases. Specifically, how they might be structured and if the proposed API is versatile enough to encompass them.
Addendum
A related issue for the AttnBias variants can be found here: #110702
Which tensor subclass?
When attempting to pass a dispatch tensor subclass as an attn_bias to torch.scaled_dot_product_attention(), we would likely encounter a compilation issue (prior to Brian's work). Besides this, there is also a question of how AttnBias subclasses would override SDPA's behavior. The primary reason is the composite nature of torch.scaled_dot_product_attention(): when calling the CompositeImplicit torch.nn.functional.scaled_dot_product_attention(), the function internally calls one of sdpa_math, sdpa_mem_eff, or sdpa_flash. So which op would the subclass match on?
A solution to the above problem would be to detect the subclass and route to a dummy aten op that could be overloaded. That being said, I think this adds complexity and overhead to a performance-critical path, and I am not sure it provides more value than the main implementation listed above.
torch_function:
Currently this would likely not work with torch.compile; however, there is active work to change this.
I wrote up a tensor-subclass version here. It is essentially the same as the one listed above (it uses the same machinery), but it does not update SDPA and instead uses `__torch_function__` to override dispatch.
Indeed, these are not torch.compile-able today.
cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki @machellazos @cpuhrsch