scaled_dot_product_attention produces NaNs with a boolean attn_mask that has all-False rows #103963

@nikitaved

🐛 Describe the bug

As per the title: when a row of a boolean attn_mask is entirely False, the corresponding output row is NaN. The repro is:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(13)

q = torch.rand(1, 1, 8, 8, device='cuda')
k = torch.rand(1, 1, 8, 8, device='cuda')
v = torch.rand(1, 1, 8, 8, device='cuda')

attn_mask = torch.rand(1, 1, 8, 8, device='cuda')
# make only attn_mask[..., :4, :4] nonzero
attn_mask[..., 4:, :] = 0  # rows 4..7 are all zero
attn_mask[..., :, 4:] = 0  # columns 4..7 are all zero

def run(attn_mask):
    return F.scaled_dot_product_attention(q, k, v, attn_mask)

# 1) float mask: added to the attention scores
print(run(attn_mask))
print()
# 2) float 0/1 mask: also additive, so no position is masked out
print(run(attn_mask.to(torch.bool).to(torch.float)))
print()
# 3) boolean mask: False means "do not attend", so rows 4..7 have no allowed key
print(run(attn_mask.to(torch.bool)))
```
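To see where the NaNs come from, here is a naive reference version of the computation. It is a sketch, not PyTorch's implementation, and `sdpa_reference` is a hypothetical name. Per the documented semantics, a float `attn_mask` is added to the scaled scores, while a boolean mask marks the positions allowed to attend, which amounts to adding 0 where it is True and -inf where it is False. That is why the three calls above differ: the 0/1 float mask in the second call only adds 1 to some scores and masks nothing, whereas the boolean mask leaves rows 4 to 7 with no allowed key at all, and softmax over a row that is entirely -inf yields NaN. The sketch runs on CPU.

```python
import math

import torch
import torch.nn.functional as F

def sdpa_reference(q, k, v, attn_mask=None):
    """Hypothetical, naive SDPA written out step by step (no dropout, not causal)."""
    # Scaled scores with the default scale 1/sqrt(E)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # Boolean mask: True = may attend, False = add -inf
            bias = torch.zeros_like(scores).masked_fill(~attn_mask, float("-inf"))
        else:
            # Float mask: added to the scores as-is (a 0/1 mask masks nothing)
            bias = attn_mask
        scores = scores + bias
    # A row with no finite score has no valid softmax; torch.softmax returns NaN for it
    return torch.softmax(scores, dim=-1) @ v

torch.manual_seed(13)
q, k, v = (torch.rand(1, 1, 8, 8) for _ in range(3))
mask = torch.zeros(1, 1, 8, 8, dtype=torch.bool)
mask[..., :4, :4] = True  # rows 4..7 are entirely False

out_ref = sdpa_reference(q, k, v, mask)
out_sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

print(torch.isnan(out_ref).any(dim=-1))  # True exactly for rows 4..7
# Rows with at least one allowed key agree with the built-in op
print(torch.allclose(out_ref[..., :4, :], out_sdpa[..., :4, :], atol=1e-5))
```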

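Result (the two float masks give finite outputs; with the boolean mask, rows 4 to 7, whose mask rows are entirely False, come out as NaN):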

```
tensor([[[[0.3294, 0.5102, 0.5570, 0.3931, 0.5681, 0.4310, 0.4364, 0.5873],
          [0.3122, 0.5634, 0.5390, 0.4812, 0.5480, 0.5196, 0.3442, 0.5177],
          [0.3113, 0.5429, 0.5753, 0.4279, 0.5462, 0.4614, 0.4262, 0.5070],
          [0.3308, 0.5409, 0.5552, 0.4361, 0.5701, 0.4857, 0.3748, 0.5113],
          [0.3355, 0.4949, 0.5479, 0.4647, 0.5509, 0.4838, 0.3739, 0.5645],
          [0.3417, 0.4965, 0.5595, 0.4552, 0.5323, 0.4618, 0.3968, 0.5559],
          [0.3436, 0.5026, 0.5588, 0.4464, 0.5421, 0.4609, 0.3855, 0.5657],
          [0.3363, 0.4980, 0.5521, 0.4547, 0.5463, 0.4718, 0.3850, 0.5666]]]],
       device='cuda:0')

tensor([[[[0.3269, 0.5471, 0.5455, 0.4189, 0.5830, 0.4809, 0.3835, 0.5237],
          [0.3242, 0.5440, 0.5454, 0.4198, 0.5904, 0.4857, 0.3798, 0.5267],
          [0.3234, 0.5435, 0.5605, 0.4166, 0.5690, 0.4675, 0.4054, 0.5148],
          [0.3260, 0.5490, 0.5539, 0.4136, 0.5779, 0.4720, 0.3908, 0.5235],
          [0.3355, 0.4949, 0.5479, 0.4647, 0.5509, 0.4838, 0.3739, 0.5645],
          [0.3417, 0.4965, 0.5595, 0.4552, 0.5323, 0.4618, 0.3968, 0.5559],
          [0.3436, 0.5026, 0.5588, 0.4464, 0.5421, 0.4609, 0.3855, 0.5657],
          [0.3363, 0.4980, 0.5521, 0.4547, 0.5463, 0.4718, 0.3850, 0.5666]]]],
       device='cuda:0')

tensor([[[[0.3101, 0.6073, 0.5391, 0.3732, 0.6267, 0.4901, 0.3819, 0.4780],
          [0.3111, 0.6065, 0.5381, 0.3717, 0.6280, 0.4894, 0.3814, 0.4802],
          [0.2993, 0.6003, 0.5579, 0.3733, 0.6151, 0.4797, 0.4109, 0.4651],
          [0.3040, 0.6067, 0.5467, 0.3717, 0.6201, 0.4837, 0.3963, 0.4759],
          [   nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan],
          [   nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan],
          [   nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan],
          [   nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan]]]],
       device='cuda:0')
```
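If the intended output for a query with no allowed keys is an all-zero row (an assumption; the issue does not say which result is expected), one possible workaround is to let such rows temporarily attend to every key, so the softmax stays finite, and then zero those rows afterwards. A minimal sketch with the hypothetical helper `sdpa_zero_fully_masked`, run on CPU:

```python
import torch
import torch.nn.functional as F

def sdpa_zero_fully_masked(q, k, v, attn_mask):
    """Hypothetical helper: SDPA with a boolean mask that returns zeros
    (instead of NaN) for query rows whose mask row is entirely False."""
    # True for query rows with no allowed key; shape (..., L, 1)
    fully_masked = ~attn_mask.any(dim=-1, keepdim=True)
    # Temporarily allow every key in those rows so the softmax is finite
    safe_mask = attn_mask | fully_masked
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=safe_mask)
    # Overwrite the placeholder rows with zeros
    return out.masked_fill(fully_masked, 0.0)

torch.manual_seed(13)
q, k, v = (torch.rand(1, 1, 8, 8) for _ in range(3))
mask = torch.zeros(1, 1, 8, 8, dtype=torch.bool)
mask[..., :4, :4] = True  # rows 4..7 are entirely False

out = sdpa_zero_fully_masked(q, k, v, mask)
print(torch.isnan(out).any())        # no NaNs anywhere
print(out[..., 4:, :].abs().sum())   # fully masked rows are zero
```

Rows that have at least one allowed key are computed with an unchanged mask, so they match a plain `F.scaled_dot_product_attention` call.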

Versions

Current master.

cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki

Metadata

Assignees: no one assigned

Labels: module: edge cases (adversarial inputs unlikely to occur in practice); module: nn (related to torch.nn); triaged (looked at by a team member, triaged, and prioritized into an appropriate module)
