Update fused kernels and call _safe_softmax from SDPA#131863

Closed
drisspg wants to merge 40 commits into gh/drisspg/17/base from gh/drisspg/17/head

Conversation

@drisspg
Contributor

@drisspg drisspg commented Jul 26, 2024

Stack from ghstack (oldest at bottom):

# Summary

Changes the stance of SDPA on what to do for fully masked-out rows.

## Current Behavior

Several PyTorch users have expressed frustration over this issue:
- #41508
- #103749
- #103963

These are significant issues with extensive discussion but no satisfactory resolution. The PyTorch team's consensus, as stated in #24816 (comment), can be paraphrased as follows:

When passing in fully masked-out rows, attention becomes ambiguous. We have two main options:

1. Uniformly attend to all values:

   ```python
   scores[masked_out_rows] = 1 / len(row)
   out[masked_out_rows] = 1 / len(row) * value
   ```

2. Decide that attention between no queries (masked) and no keys (masked) is meaningless:

   ```python
   output[fully_masked_rows] = NaN
   ```

We went with option 2, partially because it was easier to implement, but also because people argued that users can slice the output to remove the NaNs:

```python
>fill_value = -float("inf")
>row0 = torch.randn(4)
>row1 = torch.tensor([fill_value for _ in range(4)])
>matrix = torch.stack([row0, row1]).requires_grad_(True)
>out = torch.softmax(matrix, 1)
>out = out[0]
>print(out)
tensor([0.5377, 0.2729, 0.0692, 0.1201])
```

Cool, problem solved. But what happens when you call backward?

```python
>out.backward(torch.ones_like(out))
>print(matrix.grad)
tensor([[3.0957e-08, 1.4157e-08, 7.7802e-10, 1.3713e-08],
        [       nan,        nan,        nan,        nan]])
```

Those pesky NaNs are back!

## Why do we see NaNs today?

The core of the problem revolves around the use of the softmax function in SDPA:

```python
> row = torch.tensor([(-float("inf")) for _ in range(4)])
> torch.softmax(row, 0)
tensor([nan, nan, nan, nan])
```
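The NaN can be traced step by step with a plain-Python softmax (a sketch of the standard max-subtraction formulation, not PyTorch's actual kernel):

```python
import math

def naive_softmax(row):
    m = max(row)                           # for a fully masked row, m == -inf
    exps = [math.exp(x - m) for x in row]  # (-inf) - (-inf) is NaN, and exp(NaN) is NaN
    total = sum(exps)                      # NaN
    return [e / total for e in exps]       # NaN / NaN is NaN

print(naive_softmax([-math.inf] * 4))  # [nan, nan, nan, nan]
```

Once one NaN appears in the max-subtraction step, it propagates through the entire row.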

## Quick Aside: Masking in Attention

Attention itself doesn't have a concept of masking. The `sdpa` function has an argument called `attn_mask`, which would be more accurately named `attn_bias`. This is because we don't actually "mask" entries when computing attention. Instead, due to implementation details (performance), we add a value to the masked-out query/key pairs.

We use a large negative number (typically -inf) to decrease the attention weight, as softmax assigns more weight to larger values.
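Concretely, a single attention row with an additive bias can be sketched in plain Python (toy numbers, no batching or scaling, just to show that the "mask" is really an added bias):

```python
import math

def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    s = sum(exps)
    return [e / s for e in exps]

# One query attending to three keys; "masking" key 2 means adding -inf to its score.
scores = [0.5, 1.25, 2.0]
bias   = [0.0, 0.0, -math.inf]     # attn_mask behaves like an additive bias
weights = softmax([s + b for s, b in zip(scores, bias)])
values  = [10.0, 20.0, 30.0]
out = sum(w * v for w, v in zip(weights, values))  # masked key contributes nothing
```

As long as at least one entry in the row is finite, the -inf entries simply get zero weight and the output stays well-defined.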

## Alternative Approaches

If we use a very large negative number instead of -inf:

```python
> row = torch.tensor([(-1e6) for _ in range(4)])
> torch.softmax(row, 0)
tensor([0.2500, 0.2500, 0.2500, 0.2500])
```
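The same plain-Python sketch shows why a large-but-finite fill value stays well-defined: after max subtraction every entry in the fully masked row is exp(0), so the row comes out uniform:

```python
import math

def softmax(row):
    m = max(row)                           # m == -1e6 for the fully masked row
    exps = [math.exp(x - m) for x in row]  # x - m == 0 everywhere, so each exp is 1.0
    s = sum(exps)
    return [e / s for e in exps]

print(softmax([-1e6] * 4))  # [0.25, 0.25, 0.25, 0.25]
```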

However, if users always remembered to "slice" out their outputs, i.e.:

```python
>fill_value = -1e6
>...
>out.backward(torch.ones_like(out))
>print(matrix.grad)
tensor([[-0.0563, -0.0564,  0.1613, -0.0486],
        [ 0.0000,  0.0000,  0.0000,  0.0000]])
```

This would bring us back into a better state.

## A Third Option

We don't necessarily need to alter the behavior of softmax for -inf or very large negative numbers. The fundamental goal is to exclude certain query/key pairs from attention, regardless of the underlying implementation.

This PR implements the new semantic for masking with attention in fully masked-out rows:

```python
out[masked_out_rows] = 0
```
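In plain Python the new semantic amounts to a per-row guard (a sketch only; the real fused kernels detect this condition inside the kernel rather than post-hoc):

```python
import math

def safe_softmax(row):
    m = max(row)
    if m == -math.inf:            # fully masked row: attend to nothing
        return [0.0] * len(row)
    exps = [math.exp(x - m) for x in row]
    s = sum(exps)
    return [e / s for e in exps]

safe_softmax([-math.inf] * 4)  # [0.0, 0.0, 0.0, 0.0] instead of NaNs
```

Unmasked rows behave exactly like the ordinary softmax; only the fully masked rows change.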

**Important Note**: This idea isn't entirely new. The MaskedTensor prototype, a tensor subclass, was designed to handle such cases. However, it remains a prototype feature and hasn't gained widespread adoption.

## Details

This PR stack does 3 things:

1. Adds a PRIVATE `_safe_softmax` op
2. Updates the semantics of the flash_cpu fused kernel
3. Updates the semantics of the efficient_cuda fused kernel

`_safe_softmax` is not meant to be used generically; it is only meant to be used within the context of SDPA. Because of this, instead of decomposing softmax and checking for -inf rows, we "cheat" and use `nan_to_num`.
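The nan_to_num shortcut can be mimicked in plain Python: run the ordinary softmax and map any resulting NaNs to zero afterwards (a sketch of the idea, not the actual `torch.nan_to_num` call):

```python
import math

def naive_softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]  # NaN throughout for a fully masked row
    s = sum(exps)
    return [e / s for e in exps]

def softmax_nan_to_num(row):
    # Replace any NaN in the softmax output with 0, mirroring nan_to_num.
    return [0.0 if math.isnan(p) else p for p in naive_softmax(row)]

softmax_nan_to_num([-math.inf] * 4)  # [0.0, 0.0, 0.0, 0.0]
```

This is cheaper than an explicit -inf check, at the cost of also zeroing NaNs that came from bad inputs, which is exactly the trade-off discussed below.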

Why do I think this is okay? (Please find a counterpoint if available.)
There are multiple ways NaNs can emerge. For the fully-masked-rows case, `nan_to_num` works. But what if there were other NaNs; wouldn't this silently remove them?

The only case where this can happen is if the input itself had a NaN or an Inf.
For example:

```python
a = torch.ones([4], requires_grad=False, dtype=torch.float16)
a[1] = torch.finfo(torch.float16).max
print(a.softmax(-1))
```

will return

`tensor([0., 1., 0., 0.], dtype=torch.float16)`

whereas

```python
a = torch.ones([4], requires_grad=False, dtype=torch.float16)
a[1] = float("inf")
a.softmax(-1)
```

returns

`tensor([nan, nan, nan, nan], dtype=torch.float16)`

If we don't want to even allow for the possibility of inf or NaN attention scores being converted to 0, we could instead implement it as something like this:

```python
row_max = torch.max(a, dim=-1, keepdim=True)
exp = torch.exp(a - row_max.values)
denom = torch.sum(exp, dim=-1, keepdim=True)
softmax = exp / denom
softmax = torch.where(row_max.values == float('-inf'), 0.0, softmax)
```

However, we would be paying for this in math performance.

## Why Now

One point that has substantially changed where PyTorch should stand on this question is that we now have fused implementations for SDPA, and these fused implementations allow us to easily and performantly support the new semantic.

cc @ezyang @gchanan @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @Valentine233

Differential Revision: D61418679

@pytorch-bot

pytorch-bot bot commented Jul 26, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/131863

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ac89132 with merge base 80ed3e9:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the module: cpu CPU specific problem (e.g., perf, algorithm) label Jul 26, 2024
@drisspg drisspg requested a review from jgong5 July 26, 2024 17:19
@izaitsevfb
Contributor

@pytorchbot revert -m "breaks executorch test executorch/backends/apple/coreml:test - test_vit_skip_conv (executorch.backends.apple.coreml.test.test_coreml_partitioner.TestCoreMLPartitioner)" -c ghfirst

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.

pytorchmergebot added a commit that referenced this pull request Aug 15, 2024
This reverts commit caba37e.

Reverted #131863 on behalf of https://github.com/izaitsevfb due to breaks executorch test executorch/backends/apple/coreml:test - test_vit_skip_conv (executorch.backends.apple.coreml.test.test_coreml_partitioner.TestCoreMLPartitioner) ([comment](#131863 (comment)))
@pytorchmergebot
Collaborator

@drisspg your PR has been successfully reverted.

@drisspg
Contributor Author

drisspg commented Aug 16, 2024

@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

drisspg added a commit to drisspg/executorch that referenced this pull request Aug 17, 2024
Summary:
X-link: pytorch/pytorch#131863

cc ezyang gchanan jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 Valentine233

imported-using-ghimport

Reviewed By: Chillee

Differential Revision: D61418679

Pulled By: drisspg
drisspg added a commit that referenced this pull request Aug 19, 2024
# UPDATE:
This is take 3 of #131863, which was landed via co-dev but did not apply correctly.

Differential Revision: [D61418679](https://our.internmc.facebook.com/intern/diff/D61418679)
pytorchmergebot pushed a commit that referenced this pull request Aug 19, 2024
# UPDATE:
This is take 3 of #131863, which was landed via co-dev but did not apply correctly.
Differential Revision: [D61418679](https://our.internmc.facebook.com/intern/diff/D61418679)
Pull Request resolved: #133882
Approved by: https://github.com/soulitzer

Labels

- ci-no-td: Do not run TD on this PR
- ciflow/trunk: Trigger trunk jobs on your pull request
- Merged
- module: bc-breaking: Related to a BC-breaking change
- module: cpu: CPU specific problem (e.g., perf, algorithm)
- release notes: nn: release notes category
- Reverted
- topic: bc breaking: topic category

9 participants