
Initial private SDP interface and naive composite impl#81956

Closed
jbschlosser wants to merge 11 commits into gh/jbschlosser/44/base from gh/jbschlosser/44/head

Conversation

@jbschlosser
Contributor

@jbschlosser jbschlosser commented Jul 22, 2022

Stack from ghstack:

Adds an initial private API version of the SDP interface.

Signature:

_scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None,
    float dropout_p=0.0, bool need_attn_weights=True, bool is_causal=False) -> (Tensor, Tensor)

Returns a tuple of (output, attn_weights).

Note the following:

  • need_attn_weights: flag indicating whether attention weights should be computed. This is useful to toggle off for flash attention, which does not materialize the weights by default, making them more expensive to return.
  • Boolean attention mask support only; True values within attn_mask indicate that the element should take part in attention (notably, this is the reverse of MHA, which uses True to mask out values). The mask is optional.
  • is_causal: Temporary flag indicating whether to use a causal attention weighting. If this is set to True, it takes precedence over any value passed in for attn_mask. Longer term, the is_causal flag can be subsumed into the attn_mask arg via tensor subclassing (see e.g. CausalTensor in xFormers).
  • Testing is currently done by comparison against the existing Python impl of F._scaled_dot_product_attention.
  • This PR does not yet wire the new SDP in anywhere. A future PR can hook it up in BT or MHA.
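For illustration, the naive composite semantics described above can be sketched in NumPy. This is an illustrative reimplementation under stated assumptions, not the actual ATen code: the name `naive_sdp_attention` is hypothetical, dropout is omitted for determinism, and weights are returned as `None` (rather than an empty tensor) when `need_attn_weights=False`.

```python
import numpy as np

def naive_sdp_attention(query, key, value, attn_mask=None,
                        dropout_p=0.0, need_attn_weights=True, is_causal=False):
    """Naive scaled dot-product attention mirroring the private signature.

    query/key/value: arrays of shape (..., seq_len, head_dim).
    attn_mask: optional boolean mask; True means "take part in attention"
    (the reverse of MHA's convention). is_causal takes precedence over it.
    """
    d = query.shape[-1]
    # Scaled similarity scores: (..., L_q, L_k)
    scores = query @ np.swapaxes(key, -1, -2) / np.sqrt(d)

    if is_causal:
        # Lower-triangular mask: position i attends to keys 0..i only.
        L_q, L_k = scores.shape[-2], scores.shape[-1]
        causal = np.tril(np.ones((L_q, L_k), dtype=bool))
        scores = np.where(causal, scores, -np.inf)
    elif attn_mask is not None:
        scores = np.where(attn_mask, scores, -np.inf)

    # Numerically stable softmax over the key dimension.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)

    # dropout_p is accepted but unused in this deterministic sketch.
    output = weights @ value
    return output, (weights if need_attn_weights else None)
```

Note how the boolean mask polarity differs from MHA: an MHA-style mask (True = masked out) would need to be inverted (`~mask`) before being passed as `attn_mask` here.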

@jbschlosser jbschlosser requested a review from bdhirsh as a code owner July 22, 2022 00:00
jbschlosser added a commit that referenced this pull request Jul 22, 2022
@facebook-github-bot
Contributor

facebook-github-bot commented Jul 22, 2022

🔗 Helpful links

✅ No Failures (0 Pending)

As of commit 677c21b (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI (expand for details).


@jbschlosser jbschlosser removed the request for review from bdhirsh July 22, 2022 00:01
… impl"


Adds an initial version of the SDP interface with:
* A flag `need_attn_weights` for indicating that attention weights do not need to be computed; this is useful for flash attention as it does not materialize the weights, making it more expensive to return them
* Floating point attention mask support (this needs to change, as flash attention only supports causal masks and doesn't require them to be materialized)

[ghstack-poisoned]
jbschlosser added a commit that referenced this pull request Jul 22, 2022
jbschlosser added a commit that referenced this pull request Jul 22, 2022
@jbschlosser jbschlosser changed the title Initial SDP interface with float masks and naive composite impl Initial private SDP interface and naive composite impl Jul 22, 2022
@zrphercule zrphercule self-requested a review July 22, 2022 20:47
Contributor

@drisspg drisspg left a comment


👍

Contributor

@erichan1 erichan1 left a comment


The API looks good to me overall. Is there anything we need to do to check Q, K, V inputs as NestedTensors? I.e., is any mixture of NestedTensor and regular Tensor fine for Q, K, V? Can we arbitrarily pair NestedTensor Q/K/V with attn_mask and/or is_causal=True?

jbschlosser added a commit that referenced this pull request Jul 25, 2022
@jbschlosser jbschlosser added the release notes: nested tensor Changes that have a direct impact on nested tensors label Jul 27, 2022
@github-actions
Contributor

Hey @jbschlosser.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

@jbschlosser jbschlosser added topic: improvements topic category topic: not user facing topic category release notes: nn release notes category and removed release notes: nested tensor Changes that have a direct impact on nested tensors topic: improvements topic category labels Jul 27, 2022
@janeyx99
Contributor

janeyx99 commented Jul 27, 2022

@pytorchbot revert -m "broke all configs on test_scaled_dot_product_attention (main.TestNestedTensorAutograd) https://hud.pytorch.org/pytorch/pytorch/commit/f15c5bf13387bde01af48f27d44ad853c004308d" -c weird

Looks like a land race, but I'm not sure.

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here

@pytorchmergebot
Collaborator

@jbschlosser your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Jul 27, 2022
This reverts commit f15c5bf.

Reverted #81956 on behalf of https://github.com/janeyx99 due to broke all configs on test_scaled_dot_product_attention (__main__.TestNestedTensorAutograd) https://hud.pytorch.org/pytorch/pytorch/commit/f15c5bf13387bde01af48f27d44ad853c004308d
@jbschlosser
Contributor Author

@janeyx99 This is because this PR depends on transpose support for nested tensors, and that PR (#80981) was reverted.

@jbschlosser jbschlosser reopened this Jul 27, 2022
facebook-github-bot pushed a commit that referenced this pull request Jul 28, 2022
Summary: Adds an initial private API version of the SDP interface.

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/f15c5bf13387bde01af48f27d44ad853c004308d

Reviewed By: osalpekar

Differential Revision: D38226895

Pulled By: jbschlosser

fbshipit-source-id: c6b37a42991960b4c5fa98bd4bb6df6d49e3c198
facebook-github-bot pushed a commit that referenced this pull request Jul 28, 2022
Summary:
This reverts commit f15c5bf.

Reverted #81956 on behalf of https://github.com/janeyx99 due to broke all configs on test_scaled_dot_product_attention (__main__.TestNestedTensorAutograd) https://hud.pytorch.org/pytorch/pytorch/commit/f15c5bf13387bde01af48f27d44ad853c004308d

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/26776d628cada8988862e9fd4f98afab5cad5d43

Reviewed By: osalpekar

Differential Revision: D38227572

Pulled By: osalpekar

fbshipit-source-id: eec706e1620a27c2169f4e1d731f26546dc9429d
jbschlosser added a commit that referenced this pull request Aug 1, 2022
ghstack-source-id: 9c1060e
Pull Request resolved: #81956
jbschlosser added a commit that referenced this pull request Aug 1, 2022
ghstack-source-id: c1fa943
Pull Request resolved: #81956
@jbschlosser
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

@pytorchbot successfully started a merge job. Check the current status here

facebook-github-bot pushed a commit that referenced this pull request Aug 3, 2022
Summary: Adds an initial private API version of the SDP interface.

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/6ca95547ac216dde9af29003648064b3569ade10

Reviewed By: kit1980

Differential Revision: D38359346

Pulled By: jbschlosser

fbshipit-source-id: 9915894e2cbdf3278f080c28a8e1c5fd70f52b14
@facebook-github-bot facebook-github-bot deleted the gh/jbschlosser/44/head branch August 5, 2022 14:20