Initial private SDP interface and naive composite impl #81956
jbschlosser wants to merge 11 commits into gh/jbschlosser/44/base
Conversation
✅ No Failures (0 Pending) as of commit 677c21b (more details on the Dr. CI page).
… impl"
Adds an initial version of the SDP interface with:
* A flag `need_attn_weights` for indicating that attention weights do not need to be computed; this is useful for flash attention, which does not materialize the weights, making it more expensive to return them
* Floating-point attention mask support (this needs to change, as flash attention only supports causal masks and doesn't require them to be materialized)
[ghstack-poisoned]
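To make the boolean-vs-floating-point mask distinction above concrete: a boolean mask in which `True` means "attend" is equivalent to an additive float mask holding `-inf` at the masked-out positions. A minimal conversion helper (the name is hypothetical, for illustration only) might look like:

```python
import torch

def bool_to_additive_mask(attn_mask: torch.Tensor,
                          dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # True -> 0.0 (position participates), False -> -inf (position is masked out).
    # Adding the result to the raw attention scores before softmax has the same
    # effect as scores.masked_fill(~attn_mask, float("-inf")).
    out = torch.zeros(attn_mask.shape, dtype=dtype)
    return out.masked_fill(~attn_mask, float("-inf"))
```

Note this is the opposite convention from `nn.MultiheadAttention`, whose boolean masks use `True` to mask *out* positions.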
Adds an initial private API version of the SDP interface.
Signature:
```
_scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None,
float dropout_p=0.0, bool need_attn_weights=True, bool is_causal=False) -> (Tensor, Tensor)
```
Returns a tuple of `(output, attn_weights)`.
Note the following:
* `need_attn_weights`: flag indicating that attention weights should be computed. This is useful to toggle off for flash attention, as it does not materialize the weights by default, making it more expensive to return them.
* Boolean attention mask support only; `True` values within `attn_mask` indicate that the element should take part in attention (notably, this is the reverse of MHA, which uses `True` to mask *out* values). The mask is optional.
* `is_causal`: temporary flag indicating whether to use causal attention weighting. If this is set to `True`, it takes precedence over any value passed in for `attn_mask`. Longer term, the `is_causal` flag can be subsumed into the `attn_mask` arg via tensor subclassing (see e.g. [CausalTensor](https://github.com/facebookresearch/xformers/blob/sparse_cleanup/xformers/sparse/causal_tensor.py) in xFormers).
* Testing is currently done by comparing against the existing Python impl of `F._scaled_dot_product_attention` as a reference.
* This PR does not yet drop the new SDP in anywhere; a future PR can hook it up in BT or MHA.
[ghstack-poisoned]
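For reference, the semantics described above can be sketched in plain Python against public torch ops. This is an illustrative approximation only, not the actual ATen implementation, and the helper name is hypothetical:

```python
import math
import torch
import torch.nn.functional as F

def sdp_attention_sketch(query, key, value, attn_mask=None,
                         dropout_p=0.0, need_attn_weights=True, is_causal=False):
    # query/key/value: (..., seq_len, head_dim)
    L, S = query.size(-2), key.size(-2)
    scores = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
    if is_causal:
        # is_causal takes precedence over any explicit attn_mask
        attn_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril()
    if attn_mask is not None:
        # Boolean mask: True means "participate in attention"
        scores = scores.masked_fill(~attn_mask, float("-inf"))
    attn_weights = F.dropout(scores.softmax(dim=-1), p=dropout_p)
    output = attn_weights @ value
    # need_attn_weights=False is what lets a fused kernel skip materializing
    # the weights; in this naive sketch they are computed either way.
    return output, (attn_weights if need_attn_weights else None)
```

With `is_causal=True`, position 0 attends only to itself, so its output row equals the corresponding row of `value`, and the strictly upper-triangular attention weights are exactly zero.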
erichan1
left a comment
The API looks good to me overall. Is there anything we need to do to check Q, K, V inputs as NestedTensors? I.e., is any mixture of NestedTensor and regular Tensor fine for Q, K, V? Can we arbitrarily pair NestedTensor Q/K/V with `attn_mask` and/or `is_causal=True`?
Hey @jbschlosser.
@pytorchbot revert -m "broke all configs on test_scaled_dot_product_attention (__main__.TestNestedTensorAutograd) https://hud.pytorch.org/pytorch/pytorch/commit/f15c5bf13387bde01af48f27d44ad853c004308d" -c weird
Looks like a landrace but am not sure
@pytorchbot successfully started a revert job. Check the current status here
@jbschlosser your PR has been successfully reverted.
This reverts commit f15c5bf. Reverted #81956 on behalf of https://github.com/janeyx99 due to broke all configs on test_scaled_dot_product_attention (__main__.TestNestedTensorAutograd) https://hud.pytorch.org/pytorch/pytorch/commit/f15c5bf13387bde01af48f27d44ad853c004308d
Summary: Adds an initial private API version of the SDP interface.
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/f15c5bf13387bde01af48f27d44ad853c004308d
Reviewed By: osalpekar
Differential Revision: D38226895
Pulled By: jbschlosser
fbshipit-source-id: c6b37a42991960b4c5fa98bd4bb6df6d49e3c198
Summary: This reverts commit f15c5bf. Reverted #81956 on behalf of https://github.com/janeyx99 due to broke all configs on test_scaled_dot_product_attention (__main__.TestNestedTensorAutograd) https://hud.pytorch.org/pytorch/pytorch/commit/f15c5bf13387bde01af48f27d44ad853c004308d
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/26776d628cada8988862e9fd4f98afab5cad5d43
Reviewed By: osalpekar
Differential Revision: D38227572
Pulled By: osalpekar
fbshipit-source-id: eec706e1620a27c2169f4e1d731f26546dc9429d
@pytorchbot merge
@pytorchbot successfully started a merge job. Check the current status here
Summary: Adds an initial private API version of the SDP interface.
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/6ca95547ac216dde9af29003648064b3569ade10
Reviewed By: kit1980
Differential Revision: D38359346
Pulled By: jbschlosser
fbshipit-source-id: 9915894e2cbdf3278f080c28a8e1c5fd70f52b14