scaled_dot_product_attention produces nans with boolean attn_mask with zero rows. #103963
Closed
Labels
module: edge cases (Adversarial inputs unlikely to occur in practice), module: nn (Related to torch.nn), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Description
🐛 Describe the bug
As per title. The repro is:
```python
import torch
import torch.nn.functional as F

torch.manual_seed(13)
q = torch.rand(1, 1, 8, 8, device='cuda')
k = torch.rand(1, 1, 8, 8, device='cuda')
v = torch.rand(1, 1, 8, 8, device='cuda')
attn_mask = torch.rand(1, 1, 8, 8, device='cuda')
# make only attn_mask[..., :4, :4] nonzero
attn_mask[..., 4:, :] = 0
attn_mask[..., :, 4:] = 0
attn_mask[..., 4:, 4:] = 0

def run(attn_mask):
    return F.scaled_dot_product_attention(q, k, v, attn_mask)

print(run(attn_mask))
print()
print(run(attn_mask.to(torch.bool).to(torch.float)))
print()
print(run(attn_mask.to(torch.bool)))
```

Result:
```
tensor([[[[0.3294, 0.5102, 0.5570, 0.3931, 0.5681, 0.4310, 0.4364, 0.5873],
          [0.3122, 0.5634, 0.5390, 0.4812, 0.5480, 0.5196, 0.3442, 0.5177],
          [0.3113, 0.5429, 0.5753, 0.4279, 0.5462, 0.4614, 0.4262, 0.5070],
          [0.3308, 0.5409, 0.5552, 0.4361, 0.5701, 0.4857, 0.3748, 0.5113],
          [0.3355, 0.4949, 0.5479, 0.4647, 0.5509, 0.4838, 0.3739, 0.5645],
          [0.3417, 0.4965, 0.5595, 0.4552, 0.5323, 0.4618, 0.3968, 0.5559],
          [0.3436, 0.5026, 0.5588, 0.4464, 0.5421, 0.4609, 0.3855, 0.5657],
          [0.3363, 0.4980, 0.5521, 0.4547, 0.5463, 0.4718, 0.3850, 0.5666]]]],
       device='cuda:0')

tensor([[[[0.3269, 0.5471, 0.5455, 0.4189, 0.5830, 0.4809, 0.3835, 0.5237],
          [0.3242, 0.5440, 0.5454, 0.4198, 0.5904, 0.4857, 0.3798, 0.5267],
          [0.3234, 0.5435, 0.5605, 0.4166, 0.5690, 0.4675, 0.4054, 0.5148],
          [0.3260, 0.5490, 0.5539, 0.4136, 0.5779, 0.4720, 0.3908, 0.5235],
          [0.3355, 0.4949, 0.5479, 0.4647, 0.5509, 0.4838, 0.3739, 0.5645],
          [0.3417, 0.4965, 0.5595, 0.4552, 0.5323, 0.4618, 0.3968, 0.5559],
          [0.3436, 0.5026, 0.5588, 0.4464, 0.5421, 0.4609, 0.3855, 0.5657],
          [0.3363, 0.4980, 0.5521, 0.4547, 0.5463, 0.4718, 0.3850, 0.5666]]]],
       device='cuda:0')

tensor([[[[0.3101, 0.6073, 0.5391, 0.3732, 0.6267, 0.4901, 0.3819, 0.4780],
          [0.3111, 0.6065, 0.5381, 0.3717, 0.6280, 0.4894, 0.3814, 0.4802],
          [0.2993, 0.6003, 0.5579, 0.3733, 0.6151, 0.4797, 0.4109, 0.4651],
          [0.3040, 0.6067, 0.5467, 0.3717, 0.6201, 0.4837, 0.3963, 0.4759],
          [   nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan],
          [   nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan],
          [   nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan],
          [   nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan]]]],
       device='cuda:0')
```
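A minimal sketch of what is presumably going on (this reproduces the mechanism on CPU without `scaled_dot_product_attention` itself): a boolean mask row that is entirely `False` turns the whole score row into `-inf` before the softmax, and softmax over an all-`-inf` row is `0/0 = nan`.

```python
import torch

# Two score rows; the second row of the mask has no True entries.
scores = torch.zeros(2, 4)
mask = torch.tensor([[True, True, False, False],
                     [False, False, False, False]])

# Boolean-mask semantics: masked-out positions become -inf.
masked = scores.masked_fill(~mask, float("-inf"))

# Row 0 softmaxes normally; row 1 is exp(-inf)/sum(exp(-inf)) = 0/0 = nan.
probs = masked.softmax(dim=-1)
print(probs)  # row 0: [0.5, 0.5, 0.0, 0.0]; row 1: all nan
```

With a float mask of zeros, by contrast, `0` is *added* to the scores rather than replacing them with `-inf`, which is why the `.to(torch.bool).to(torch.float)` variant above stays finite.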
Versions
Current master.
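As a stopgap until this is fixed, one workaround is to sanitize the mask before the call so no key row is fully masked. `fill_empty_rows` below is a hypothetical helper (not part of PyTorch) that unmasks rows with no `True` entry, making those queries attend uniformly instead of producing `nan`:

```python
import torch

def fill_empty_rows(attn_mask: torch.Tensor) -> torch.Tensor:
    """Unmask any row of a boolean mask that has no True entries."""
    empty = ~attn_mask.any(dim=-1, keepdim=True)  # rows that are all False
    return attn_mask | empty  # broadcast: set the whole empty row to True

mask = torch.tensor([[True, False],
                     [False, False]])
print(fill_empty_rows(mask))
# tensor([[ True, False],
#         [ True,  True]])
```

Whether uniform attention is the right semantics for a fully-masked query is application-dependent; zeroing those output rows afterwards is an alternative.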
cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki