
Why not register low precision autocast for scaled dot product attention? #7177

@ghost

Description

🐛 Bug

nn.MultiheadAttention cannot run under automatic mixed precision (autocast) on XLA devices.

Steps to reproduce the behavior:

import torch
import torch.nn as nn
import torch_xla
import torch_xla.core.xla_model as xm

xla_device = xm.xla_device()
embed_dim = 1024
num_heads = 64
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).to(xla_device)
input = torch.ones([4, 32, 1024], dtype=torch.float32).to(xla_device)
# The additive attention mask stays float32; under autocast the query is cast to
# float16, which triggers the dtype mismatch below.
attn_mask = torch.ones([32, 32], dtype=torch.float32).to(xla_device)
with torch.amp.autocast("xla", dtype=torch.float16):
  attn_output = multihead_attn(input, input, input, attn_mask=attn_mask, need_weights=False)
xm.mark_step()
print(attn_output[0].dtype)
print(attn_output)

RuntimeError: Expected attn_mask dtype to be bool or to match query dtype, but got attn_mask.dtype: float and query.dtype: c10::Half instead.

Expected behavior

The MultiheadAttention module runs successfully under autocast and returns an output tensor with the expected low precision dtype.

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: CPU

Additional context

Although I reproduced the bug on CPU, I believe it will occur on any PJRT device except CUDA; I can also reproduce it on an Intel GPU. To fix it, we only need to register a low precision autocast policy for scaled_dot_product_attention, and I have verified that this works. I would like to ask why this registration is missing and whether there is any problem with adding it. Until it lands, a user-level workaround is to cast the additive mask manually, as sketched below.
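A minimal workaround sketch, assuming the mask is an additive float mask (tensor names reuse the repro above): casting the mask to the autocast compute dtype avoids the dtype mismatch.

with torch.amp.autocast("xla", dtype=torch.float16):
    # Cast the additive mask so its dtype matches the autocast dtype of the query
    attn_output = multihead_attn(
        input, input, input,
        attn_mask=attn_mask.to(torch.float16),
        need_weights=False,
    )

Note that converting the float mask to bool is not an equivalent workaround: a bool mask marks positions to exclude from attention, whereas a float mask is added to the attention scores.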
