Commit caba37e

drisspg authored and pytorchmergebot committed
Update fused kernels and call _safe_softmax from SDPA (#131863)
Pull Request resolved: #131863 Approved by: https://github.com/jbschlosser, https://github.com/Chillee
1 parent 9de023d commit caba37e

13 files changed: 194 additions & 19 deletions

aten/src/ATen/native/cpu/FlashAttentionKernel.cpp

Lines changed: 6 additions & 0 deletions
@@ -452,9 +452,15 @@ void cpu_flash_attention(
           dst_data,
           headSize);
       }
+
       // dst <- dst / sum[row]
       // reorder MHA output with strides
       for (int64_t row = 0; row < qBlockSize; ++row) {
+        // Row sums for fully masked out rows are 0; we set them to 1
+        // in order to avoid NaNs in the output and instead set fully
+        // masked out rows to 0.
+        qk_max_data[row] = qk_max_data[row] == -std::numeric_limits<accum_t>::infinity() ? 0 : qk_max_data[row];
+        qk_sum_data[row] = qk_sum_data[row] == 0 ? 1 : qk_sum_data[row];
         accum_t sum_reciprocal = 1 / qk_sum_data[row];
         vec::map<scalar_t>(
           [sum_reciprocal](Vec x) { return x * Vec(sum_reciprocal); },
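The clamp only touches the final normalization step. A minimal Python sketch with assumed values (not the kernel itself) of why it matters: a fully masked query row leaves both the output accumulator and the exp-sum at 0, so the naive division is 0 / 0 = NaN, whereas clamping the sum to 1 leaves the row at 0.

import torch

# Sketch of the normalization step only, with assumed values (not the kernel).
dst_row = torch.zeros(8)     # accumulated output for a fully masked query row
qk_sum = torch.tensor(0.0)   # sum of exp(score - max) for that row

naive = dst_row / qk_sum                                                     # 0 / 0 -> NaN
safe = dst_row / torch.where(qk_sum == 0, torch.ones_like(qk_sum), qk_sum)   # stays 0

print(naive)  # tensor of NaNs
print(safe)   # tensor of zeros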

aten/src/ATen/native/native_functions.yaml

Lines changed: 1 addition & 0 deletions
@@ -8890,6 +8890,7 @@
   variants: method, function
   dispatch:
     QuantizedCPU: eq_quantized_cpu
+    NestedTensorCPU, NestedTensorCUDA: eq_tensor_nested
   tags: [core, pointwise]

- func: ge.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)

aten/src/ATen/native/nested/NestedTensorBinaryOps.cpp

Lines changed: 9 additions & 0 deletions
@@ -322,5 +322,14 @@ Tensor eq_scalar_nested(const Tensor& self, const Scalar& other) {
       });
 }

+Tensor eq_tensor_nested(const Tensor& self, const Tensor& other) {
+  TORCH_CHECK(!other.is_nested(), "eq does not support nested tensor as other value.");
+  return NestedTensor_elementwise_Tensor(
+      self, other, "eq", false /*supports_striding*/,
+      [](const Tensor& b1, const Tensor& b2) {
+        return b1.eq(b2);
+      });
+}
+
 } // namespace native
 } // namespace at
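Together with the yaml entry above, this enables elementwise eq between a (strided-layout) nested tensor and a plain dense tensor, which _safe_softmax uses to build its -inf mask; a nested `other` is rejected by the TORCH_CHECK. A hedged Python-level sketch of the behavior this unlocks (exact dispatch depends on layout and build):

import torch

# Elementwise eq of a nested tensor against a non-nested 0-dim tensor (e.g. -inf).
nt = torch.nested.nested_tensor([
    torch.tensor([0.0, float("-inf")]),
    torch.tensor([float("-inf")]),
])
neg_inf = torch.tensor(float("-inf"))
masked = nt.eq(neg_inf)          # expected to hit eq_tensor_nested on CPU/CUDA
print([t for t in masked.unbind()])
# Passing a nested tensor as `other` raises: "eq does not support nested tensor as other value."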

aten/src/ATen/native/transformers/attention.cpp

Lines changed: 5 additions & 3 deletions
@@ -647,9 +647,11 @@ Tensor _safe_softmax(
     int64_t dim,
     std::optional<ScalarType> dtype) {
   auto out = at::softmax(self, dim, dtype);
-  const auto masked = self.eq(-std::numeric_limits<float>::infinity());
+  const auto neg_inf = at::scalar_tensor(-std::numeric_limits<float>::infinity(), at::TensorOptions().dtype(out.dtype()).device(out.device()));
+  const auto masked = self.eq(neg_inf);
   const auto masked_rows = all(masked, dim, true);
-  return at::where(masked_rows, at::scalar_tensor(0.0, at::TensorOptions().dtype(out.dtype()).device(out.device())), out);
+  const auto zero = at::scalar_tensor(0.0, at::TensorOptions().dtype(out.dtype()).device(out.device()));
+  return at::where(masked_rows, zero, out);
 }
 // Computes scaled dot product attention on query, key and value tensors, using
 // an optional attention mask if passed, and applying dropout if a probability
@@ -837,7 +839,7 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
       attn.add_(*attn_mask);
     }
   }
-  attn = at::softmax(attn, -1);
+  attn = at::_safe_softmax(attn, -1);
   if (dropout_p > 0.0) {
     if (dropout_mask.has_value()) {
       // In order to validate the correctness of the fused kernels, we need to
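_safe_softmax differs from softmax only for rows that are entirely -inf: instead of NaN, those rows come out as 0. A reference-level Python sketch mirroring the ATen code above (not the op itself):

import torch

def safe_softmax_ref(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Regular softmax, then fully masked (-inf) rows are replaced with 0 instead of NaN.
    out = torch.softmax(x, dim)
    masked = x == float("-inf")
    masked_rows = masked.all(dim, keepdim=True)
    return torch.where(masked_rows, torch.zeros_like(out), out)

scores = torch.tensor([[0.0, 1.0], [float("-inf"), float("-inf")]])
print(torch.softmax(scores, -1))     # second row is all NaN
print(safe_softmax_ref(scores, -1))  # second row is all 0

The second hunk routes the math backend of scaled_dot_product_attention through this op, so fully masked rows produce 0 rather than NaN.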

aten/src/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_rescale_output.h

Lines changed: 8 additions & 2 deletions
@@ -144,7 +144,10 @@ class MemoryEfficientAttentionNormalize {
     multiplies<ComputeFragment> mul_add_source;
     multiply_add<ComputeFragment> mul_add_accumulator;

-    ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1;
+    // Row sums for fully masked out rows are 0; we set them to 1
+    // in order to avoid NaNs in the output and set fully masked out rows to 0.
+    ElementCompute denom = s_prime_[row] == 0 ? 1 : s_prime_[row];
+    ElementCompute alpha = isLast ? (1 / denom) : 1;
     ElementCompute beta = alpha * m_prime_[row];

     intermediate = mul_add_source(beta, converted_source); // X = beta * C
@@ -174,7 +177,10 @@ class MemoryEfficientAttentionNormalize {
     ComputeFragment intermediate;
     multiplies<ComputeFragment> mul_accumulator;

-    ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1;
+    // Row sums for fully masked out rows are 0; we set them to 1
+    // in order to avoid NaNs in the output and set fully masked out rows to 0.
+    ElementCompute denom = s_prime_[row] == 0 ? 1 : s_prime_[row];
+    ElementCompute alpha = isLast ? (1 / denom) : 1;

     intermediate = mul_accumulator(
         alpha, converted_accumulator); // X = alpha * C + uniform

aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h

Lines changed: 4 additions & 0 deletions
@@ -1166,6 +1166,10 @@ struct AttentionKernel {
     auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE;
     constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
     if (thread_id() < p.num_queries) {
+      // Fully masked out rows are set to 0, so their sumexp is 0.
+      // We update it to 1 prior to calling log so that log(1) = 0.
+      s_prime[thread_id()] = (s_prime[thread_id()] == 0) ? 1 : s_prime[thread_id()];
+      mi[thread_id()] = (mi[thread_id()] == -cutlass::platform::numeric_limits<accum_t>::infinity()) ? 0 : mi[thread_id()];
       p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()] / kLog2e) +
           cutlass::fast_log(accum_t(s_prime[thread_id()]));
     } else if (thread_id() < lse_dim) {
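A small sketch of the logsumexp bookkeeping this protects, with assumed values and a natural log in place of the kernel's log2-scaled accumulators: without the clamp a fully masked row would store an LSE of -inf, which later becomes NaN once masked scores are recombined with it (exp(-inf - (-inf))); with the clamp it stores log(1) = 0.

import math

mi, s_prime = float("-inf"), 0.0   # running max and exp-sum for a fully masked row

# Clamp as in the kernel: mi -> 0, s_prime -> 1.
mi = 0.0 if mi == float("-inf") else mi
s_prime = 1.0 if s_prime == 0.0 else s_prime

lse = mi + math.log(s_prime)
print(lse)  # 0.0, and exp(score - lse) stays finite for masked scores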

test/functorch/test_ops.py

Lines changed: 0 additions & 6 deletions
@@ -1791,9 +1791,6 @@ def get_vjp(cotangents, *primals):
         ),  # NYI: forward-AD for soft_margin_loss_backward
         xfail("nn.functional.ctc_loss", ""),  # NYI: forward-AD for _ctc_loss
         xfail("nn.functional.pdist", ""),  # NYI: forward-AD with _pdist_forward
-        xfail(
-            "torch.ops.aten._safe_softmax.default"
-        ),  # NYI: forward-AD for _safe_softmax
         skip("nn.functional.scaled_dot_product_attention"),
         xfail("torch.ops.aten._efficient_attention_forward"),  # outputs ints
         xfail(
@@ -1976,9 +1973,6 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents):
         xfail(
             "nn.functional.ctc_loss"
         ),  # ForwardAD not implemented and no decomposition
-        xfail(
-            "torch.ops.aten._safe_softmax.default"
-        ),  # ForwardAD not implemented
         xfail("nn.functional.dropout2d"),  # calls random op
         xfail("nn.functional.dropout3d"),  # calls random op
         xfail("nn.functional.dropout"),  # calls random op

test/test_nn.py

Lines changed: 11 additions & 3 deletions
@@ -12385,12 +12385,20 @@ def perm_fn(x):
         result = model(encoder_input, src_key_padding_mask=mask)
         self.assertEqual(result.shape, ref_output.shape)
         torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol)
-        # 1 values are masked. Since there is only 1 input embedding this
-        # will result in nan.
         mask = torch.tensor([[1]], device=device) == 1
         result = model(encoder_input, src_key_padding_mask=mask)
+        fast_path_device = result.is_cuda or result.is_cpu
         result = result.cpu().detach().numpy()
-        self.assertTrue(np.isnan(result).all())
+        # Non fast paths
+        if training or not batch_first or TEST_WITH_CROSSREF or not fast_path_device:
+            # We changed the semantics on the non-fast path so that fully masked out rows return
+            # 0 from attention; NaNs should no longer be present, and the output should be nonzero
+            # due to skip connections.
+            self.assertTrue(not np.isnan(result).any())
+        else:
+            # Fast paths
+            self.assertTrue(np.isnan(result).all())
+

         # deterministic input
         encoder_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]],
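A hedged end-to-end illustration of the behavior this test distinguishes, using hypothetical toy sizes rather than the test's fixtures: with a fully masked src_key_padding_mask, the non-fast path now returns a finite result because attention contributes 0 and the residual connection carries the input through, while the fused fast path may still yield NaN.

import torch
import torch.nn as nn

layer = nn.TransformerEncoderLayer(d_model=4, nhead=1, batch_first=True)
x = torch.randn(1, 1, 4)
mask = torch.tensor([[True]])        # the only token is masked out

layer.train()                        # training mode forces the non-fast path
out = layer(x, src_key_padding_mask=mask)
print(torch.isnan(out).any())        # expected: False after this change

layer.eval()                         # eval + no_grad may take the fused fast path
with torch.no_grad():
    out_fp = layer(x, src_key_padding_mask=mask)
print(torch.isnan(out_fp).any())     # may still be True on fast-path devices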

test/test_transformers.py

Lines changed: 126 additions & 3 deletions
@@ -347,6 +347,7 @@ def test_train_with_pad_and_catch_error(self, device):
     @parametrize("key_padding_mask_dim", [2, None])
     @parametrize("mask_dtype", [torch.bool, torch.float32])
     def test_multiheadattention_fastpath_attn_mask(self, device, attn_mask_dim, key_padding_mask_dim, mask_dtype):
+        # MHA converts all
         with torch.no_grad():
             B = 2
             L = 4
@@ -356,7 +357,7 @@ def test_multiheadattention_fastpath_attn_mask(self, device, attn_mask_dim, key_
             if attn_mask_dim == 2:
                 attn_mask = make_tensor((L, L), dtype=mask_dtype, device=device)
             elif attn_mask_dim == 3:
-                attn_mask = make_tensor((B * H, L, L), dtype=mask_dtype, device=device)
+                attn_mask = make_tensor((B, 1, L, L), dtype=mask_dtype, device=device).expand(B, H, L, L).reshape(B * H, L, L)
             elif attn_mask_dim is None:
                 attn_mask = None

@@ -372,7 +373,9 @@ def test_multiheadattention_fastpath_attn_mask(self, device, attn_mask_dim, key_
             out, _ = mha(X, X, X, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)
             mha.eval()  # enable fast path
             out_fp, _ = mha(X, X, X, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)
-            self.assertEqual(out, out_fp)
+            # The fast-path kernel will return NaNs, while the SDPA kernel that is run when the fast path
+            # is turned off returns 0 instead of NaNs for fully masked rows.
+            torch.testing.assert_close(out, out_fp.nan_to_num())

     @parametrize("nhead", [1, 4, 8])
     def test_transformerencoderlayer_src_mask(self, device, nhead):
@@ -1156,6 +1159,25 @@ def rand_tensor(*shape):
             else:
                 actual = torch.nn.functional.scaled_dot_product_attention(
                     query, key, value, attn_mask, dropout_p, is_causal)
+            # This tests the fully masked out rows case
+            if torch.isnan(expected).any():
+                row_sums = attn_mask.sum(dim=-1)
+                masked_out_rows = (row_sums == 0)
+
+                for _ in range((input_dim - attn_mask_dim) - 1):
+                    masked_out_rows = masked_out_rows.unsqueeze(0)
+
+                masked_out_rows = masked_out_rows.expand(expected.shape[:-1])
+                # Slice out the fully masked rows from expected and actual
+                expected_masked_out = expected[masked_out_rows]
+                actual_masked_out = actual[masked_out_rows]
+
+                expected_all_nan = torch.isnan(expected_masked_out).all()
+                actual_all_zero = (actual_masked_out.abs().sum() == 0)
+
+                self.assertTrue(expected_all_nan)
+                self.assertTrue(actual_all_zero)
+                return

             self.assertEqual(actual, expected)

@@ -1961,7 +1983,7 @@ def test_fused_sdp_choice_cpu(self, device, type: str, dropout: float, dtype: to
     @parametrize("n_head", [1, 3])
     @parametrize("head_dim", [8])
     @parametrize("mask_dim", [2, 4])
-    @parametrize("bool_mask", [0, 1])
+    @parametrize("bool_mask", [False, True])
     @parametrize("train", [True, False])
     @parametrize("casual", [True, False])
     @parametrize("set_attn_mask", [True, False])
@@ -2036,6 +2058,9 @@ def test_scaled_dot_product_fused_attention_mask_vs_math_cpu(
         if dtype in [torch.bfloat16, torch.float16]:
            math_ref = math_ref.to(dtype)

+        self.assertFalse(torch.isnan(math_ref).any())
+        self.assertFalse(torch.isnan(actual).any())
+
         self.assertEqual(actual, math_ref, atol=tol.atol, rtol=tol.rtol)

         if train:
@@ -2064,6 +2089,104 @@ def test_scaled_dot_product_fused_attention_with_inf(self, device):
         actual = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
         self.assertEqual(math_ref, actual)

+    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
+    @parametrize("backend", [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION])
+    @parametrize("seq_len", [32, 64, 128])
+    @parametrize("head_dim", [16, 32])
+    @parametrize("dtype", [torch.float32, torch.float16])
+    def test_fully_masked_out_rows(self, backend, device, seq_len, head_dim, dtype):
+        def attention_inputs(seq_len, head_dim, device, dtype, mask_every_n_rows=4):
+            query = torch.rand(1, 1, seq_len, head_dim, requires_grad=True, device=device, dtype=dtype)
+            key = torch.rand(1, 1, seq_len, head_dim, requires_grad=True, device=device, dtype=dtype)
+            value = torch.rand(1, 1, seq_len, head_dim, requires_grad=True, device=device, dtype=dtype)
+
+            # Create a mask with deterministic row masking
+            mask = torch.ones(1, 1, seq_len, seq_len, dtype=torch.bool, device=device)
+
+            # Mask every nth row
+            mask[0, 0, ::mask_every_n_rows, :] = False
+
+            # Create a fixed pattern for element-wise masking
+            element_mask = torch.zeros(seq_len, seq_len, dtype=torch.bool, device=device)
+            element_mask[torch.arange(seq_len)[:, None] % 5 == torch.arange(seq_len) % 5] = True
+
+            # Combine row masking and element-wise masking
+            mask = mask & element_mask.unsqueeze(0).unsqueeze(0)
+
+            return query, key, value, mask
+
+        def compute_output_and_grads(query, key, value, mask, backend):
+            with sdpa_kernel(backend):
+                masked_out = scaled_dot_product_attention(query, key, value, attn_mask=mask)
+                loss = masked_out.sum()
+            grads = torch.autograd.grad(loss, [query, key, value])
+            return masked_out, grads
+
+        if backend == SDPBackend.FLASH_ATTENTION and "cuda" in str(device):
+            unittest.skip("FlashAttention does not support masks on cuda")
+            return
+        if backend == SDPBackend.EFFICIENT_ATTENTION and "cpu" in str(device):
+            unittest.skip("EfficientAttention does not support masks on cpu")
+            return
+        query, key, value, mask = attention_inputs(seq_len, head_dim, device, dtype)
+
+        # Compute results for the tested backend
+        backend_out, backend_grads = compute_output_and_grads(query, key, value, mask, backend)
+
+        # Compute results for the Math backend
+        math_out, math_grads = compute_output_and_grads(query, key, value, mask, SDPBackend.MATH)
+
+        # Compare outputs
+        torch.testing.assert_close(backend_out, math_out, atol=5e-3, rtol=0)
+        self.assertFalse(backend_out.isnan().any())
+        self.assertFalse(math_out.isnan().any())
+        # Compare gradients
+        for bg, mg in zip(backend_grads, math_grads):
+            torch.testing.assert_close(bg, mg, atol=3e-3, rtol=0)
+            self.assertFalse(bg.isnan().any())
+            self.assertFalse(mg.isnan().any())
+
+        # Check if masked rows are zero in output
+        mask_sum = mask.sum(dim=-1, keepdim=True)
+        masked_rows = (mask_sum == 0).expand_as(backend_out)
+        self.assertTrue((mask_sum == 0).sum() > 0, "No fully masked out rows found")
+        assert torch.all(backend_out[masked_rows] == 0), \
+            f"Non-zero values in fully masked rows for {backend=}"
+
+        # Check if gradients for masked rows are zero
+        grad_query = backend_grads[0]
+        assert torch.all(grad_query[masked_rows] == 0), f"Non-zero gradients in fully masked rows for {backend=}"
+
+    @parametrize("dtype", [torch.float32, torch.float16])
+    @parametrize("fill_val", [float("inf")])
+    def test_non_masked_rows_nan_props(self, device, dtype, fill_val):
+        query = torch.randn(1, 2, 4, 16, device=device, dtype=dtype)
+        # a single inf in the query input, which propagates to NaN in the output
+        query[0, 1, 2, 3] = fill_val
+        query = query.detach().requires_grad_(True)
+        key = torch.randn(1, 2, 4, 16, device=device, dtype=dtype, requires_grad=True)
+        value = torch.randn(1, 2, 4, 16, device=device, dtype=dtype, requires_grad=True)
+
+        out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
+        self.assertTrue(torch.isnan(out).any())
+        out.sum().backward()
+        self.assertTrue(torch.isnan(query.grad).any())
+
+    @parametrize("kernel", [SDPBackend.MATH])
+    def test_scaled_dot_product_attention_math_with_negative_scale(self, device, kernel: SDPBackend):
+        # https://github.com/pytorch/pytorch/issues/105190.
+        def ref(x):
+            v1 = torch.matmul(x, x.transpose(-1, -2))
+            v2 = v1 / -0.0001
+            v3 = v2.softmax(dim=-1)
+            v4 = torch.matmul(v3, x)
+            return v4
+
+        x = torch.randn(1, 3, 64, 64, device=device)
+        ref_result = ref(x)
+        with sdpa_kernel(backends=[kernel]):
+            sdp_math = torch.nn.functional.scaled_dot_product_attention(x, x, x, scale=-1.0 / 0.0001)
+        self.assertEqual(ref_result, sdp_math)

 class TestSDPACudaOnly(NNTestCase):
     """ Used to test CUDA only functionality of scaled_dot_product_attention

tools/autograd/derivatives.yaml

Lines changed: 1 addition & 0 deletions
@@ -2845,6 +2845,7 @@
 # Transformer
 - name: _safe_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor
   self: _softmax_backward_data(grad, result, dim, self.scalar_type())
+  result: result * (self_t - safe_logsumexp_jvp(self_p, self_t, {dim}, true))

 - name: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)
   output_differentiability: [True, False, False, False]
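The added forward-mode formula is the standard softmax JVP: for y = softmax(x), dy = y * (dx - sum(y * dx, dim)), where sum(y * dx, dim) is the JVP of logsumexp. A quick numeric check against plain softmax, a hedged sketch using the public forward-AD API rather than the private op:

import torch
import torch.autograd.forward_ad as fwAD

x = torch.randn(3, 5, dtype=torch.float64)
dx = torch.randn_like(x)

y = torch.softmax(x, -1)
expected = y * (dx - (y * dx).sum(-1, keepdim=True))   # y * (dx - d logsumexp)

with fwAD.dual_level():
    dual = fwAD.make_dual(x, dx)
    out = torch.softmax(dual, -1)
    jvp = fwAD.unpack_dual(out).tangent

torch.testing.assert_close(jvp, expected)
print("softmax JVP matches the closed form")

This is also why the two _safe_softmax forward-AD xfails in test/functorch/test_ops.py are removed above.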
