@@ -347,6 +347,7 @@ def test_train_with_pad_and_catch_error(self, device):
     @parametrize("key_padding_mask_dim", [2, None])
     @parametrize("mask_dtype", [torch.bool, torch.float32])
     def test_multiheadattention_fastpath_attn_mask(self, device, attn_mask_dim, key_padding_mask_dim, mask_dtype):
+        # MHA converts all
         with torch.no_grad():
             B = 2
             L = 4
@@ -356,7 +357,7 @@ def test_multiheadattention_fastpath_attn_mask(self, device, attn_mask_dim, key_
             if attn_mask_dim == 2:
                 attn_mask = make_tensor((L, L), dtype=mask_dtype, device=device)
             elif attn_mask_dim == 3:
-                attn_mask = make_tensor((B * H, L, L), dtype=mask_dtype, device=device)
+                attn_mask = make_tensor((B, 1, L, L), dtype=mask_dtype, device=device).expand(B, H, L, L).reshape(B * H, L, L)
             elif attn_mask_dim is None:
                 attn_mask = None
 
@@ -372,7 +373,9 @@ def test_multiheadattention_fastpath_attn_mask(self, device, attn_mask_dim, key_
             out, _ = mha(X, X, X, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)
             mha.eval()  # enable fast path
             out_fp, _ = mha(X, X, X, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)
-            self.assertEqual(out, out_fp)
+            # The FP kernel will return NaNs, while the SDPA kernel, which is run when the fast path is
+            # turned off, returns 0 instead of NaNs for fully masked rows
+            torch.testing.assert_close(out, out_fp.nan_to_num())
 
     @parametrize("nhead", [1, 4, 8])
     def test_transformerencoderlayer_src_mask(self, device, nhead):
@@ -1156,6 +1159,25 @@ def rand_tensor(*shape):
         else:
             actual = torch.nn.functional.scaled_dot_product_attention(
                 query, key, value, attn_mask, dropout_p, is_causal)
+        # This tests the fully masked out rows case
+        if torch.isnan(expected).any():
+            row_sums = attn_mask.sum(dim=-1)
+            masked_out_rows = (row_sums == 0)
+
+            for _ in range((input_dim - attn_mask_dim) - 1):
+                masked_out_rows = masked_out_rows.unsqueeze(0)
+
+            masked_out_rows = masked_out_rows.expand(expected.shape[:-1])
+            # Slice out the fully masked rows from expected and actual
+            expected_masked_out = expected[masked_out_rows]
+            actual_masked_out = actual[masked_out_rows]
+
+            expected_all_nan = torch.isnan(expected_masked_out).all()
+            actual_all_zero = (actual_masked_out.abs().sum() == 0)
+
+            self.assertTrue(expected_all_nan)
+            self.assertTrue(actual_all_zero)
+            return
 
         self.assertEqual(actual, expected)
 
@@ -1961,7 +1983,7 @@ def test_fused_sdp_choice_cpu(self, device, type: str, dropout: float, dtype: to
     @parametrize("n_head", [1, 3])
     @parametrize("head_dim", [8])
     @parametrize("mask_dim", [2, 4])
-    @parametrize("bool_mask", [0, 1])
+    @parametrize("bool_mask", [False, True])
     @parametrize("train", [True, False])
     @parametrize("casual", [True, False])
     @parametrize("set_attn_mask", [True, False])
@@ -2036,6 +2058,9 @@ def test_scaled_dot_product_fused_attention_mask_vs_math_cpu(
         if dtype in [torch.bfloat16, torch.float16]:
             math_ref = math_ref.to(dtype)
 
+        self.assertFalse(torch.isnan(math_ref).any())
+        self.assertFalse(torch.isnan(actual).any())
+
         self.assertEqual(actual, math_ref, atol=tol.atol, rtol=tol.rtol)
 
         if train:
@@ -2064,6 +2089,104 @@ def test_scaled_dot_product_fused_attention_with_inf(self, device):
         actual = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
         self.assertEqual(math_ref, actual)
 
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
+    @parametrize("backend", [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION])
+    @parametrize("seq_len", [32, 64, 128])
+    @parametrize("head_dim", [16, 32])
+    @parametrize("dtype", [torch.float32, torch.float16])
+    def test_fully_masked_out_rows(self, backend, device, seq_len, head_dim, dtype):
+        def attention_inputs(seq_len, head_dim, device, dtype, mask_every_n_rows=4):
+            query = torch.rand(1, 1, seq_len, head_dim, requires_grad=True, device=device, dtype=dtype)
+            key = torch.rand(1, 1, seq_len, head_dim, requires_grad=True, device=device, dtype=dtype)
+            value = torch.rand(1, 1, seq_len, head_dim, requires_grad=True, device=device, dtype=dtype)
+
+            # Create a mask with deterministic row masking
+            mask = torch.ones(1, 1, seq_len, seq_len, dtype=torch.bool, device=device)
+
+            # Mask every nth row
+            mask[0, 0, ::mask_every_n_rows, :] = False
+
+            # Create a fixed pattern for element-wise masking
+            element_mask = torch.zeros(seq_len, seq_len, dtype=torch.bool, device=device)
+            element_mask[torch.arange(seq_len)[:, None] % 5 == torch.arange(seq_len) % 5] = True
+
+            # Combine row masking and element-wise masking
+            mask = mask & element_mask.unsqueeze(0).unsqueeze(0)
+
+            return query, key, value, mask
+
+        def compute_output_and_grads(query, key, value, mask, backend):
+            with sdpa_kernel(backend):
+                masked_out = scaled_dot_product_attention(query, key, value, attn_mask=mask)
+                loss = masked_out.sum()
+                grads = torch.autograd.grad(loss, [query, key, value])
+            return masked_out, grads
+
+        if backend == SDPBackend.FLASH_ATTENTION and "cuda" in str(device):
+            unittest.skip("FlashAttention does not support masks on cuda")
+            return
+        if backend == SDPBackend.EFFICIENT_ATTENTION and "cpu" in str(device):
+            unittest.skip("EfficientAttention does not support masks on cpu")
+            return
+        query, key, value, mask = attention_inputs(seq_len, head_dim, device, dtype)
+
+        # Compute results for the tested backend
+        backend_out, backend_grads = compute_output_and_grads(query, key, value, mask, backend)
+
+        # Compute results for the Math backend
+        math_out, math_grads = compute_output_and_grads(query, key, value, mask, SDPBackend.MATH)
+
+        # Compare outputs
+        torch.testing.assert_close(backend_out, math_out, atol=5e-3, rtol=0)
+        self.assertFalse(backend_out.isnan().any())
+        self.assertFalse(math_out.isnan().any())
+        # Compare gradients
+        for bg, mg in zip(backend_grads, math_grads):
+            torch.testing.assert_close(bg, mg, atol=3e-3, rtol=0)
+            self.assertFalse(bg.isnan().any())
+            self.assertFalse(mg.isnan().any())
+
+        # Check if masked rows are zero in output
+        mask_sum = mask.sum(dim=-1, keepdim=True)
+        masked_rows = (mask_sum == 0).expand_as(backend_out)
+        self.assertTrue((mask_sum == 0).sum() > 0, "No fully masked out rows found")
+        assert torch.all(backend_out[masked_rows] == 0), \
+            f"Non-zero values in fully masked rows for {backend=}"
+
+        # Check if gradients for masked rows are zero
+        grad_query = backend_grads[0]
+        assert torch.all(grad_query[masked_rows] == 0), f"Non-zero gradients in fully masked rows for {backend=}"
+
+    @parametrize("dtype", [torch.float32, torch.float16])
+    @parametrize("fill_val", [float("inf")])
+    def test_non_masked_rows_nan_props(self, device, dtype, fill_val):
+        query = torch.randn(1, 2, 4, 16, device=device, dtype=dtype)
+        # a single inf in the query input should propagate NaNs to the output
+        query[0, 1, 2, 3] = fill_val
+        query = query.detach().requires_grad_(True)
+        key = torch.randn(1, 2, 4, 16, device=device, dtype=dtype, requires_grad=True)
+        value = torch.randn(1, 2, 4, 16, device=device, dtype=dtype, requires_grad=True)
+
+        out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
+        self.assertTrue(torch.isnan(out).any())
+        out.sum().backward()
+        self.assertTrue(torch.isnan(query.grad).any())
+
+    @parametrize("kernel", [SDPBackend.MATH])
+    def test_scaled_dot_product_attention_math_with_negative_scale(self, device, kernel: SDPBackend):
+        # https://github.com/pytorch/pytorch/issues/105190.
+        def ref(x):
+            v1 = torch.matmul(x, x.transpose(-1, -2))
+            v2 = v1 / -0.0001
+            v3 = v2.softmax(dim=-1)
+            v4 = torch.matmul(v3, x)
+            return v4
+
+        x = torch.randn(1, 3, 64, 64, device=device)
+        ref_result = ref(x)
+        with sdpa_kernel(backends=[kernel]):
+            sdp_math = torch.nn.functional.scaled_dot_product_attention(x, x, x, scale=-1.0 / 0.0001)
+        self.assertEqual(ref_result, sdp_math)
 
 class TestSDPACudaOnly(NNTestCase):
     """ Used to test CUDA only functionality of scaled_dot_product_attention