🐛 Bug
MultiheadAttention cannot run in auto mixed precision (autocast) mode.
Steps to reproduce the behavior:
```python
import torch
import torch.nn as nn
import torch_xla
import torch_xla.core.xla_model as xm

xla_device = xm.xla_device()
embed_dim = 1024
num_heads = 64

input = torch.ones([4, 32, 1024], dtype=torch.float32).to(xla_device)
attn_mask = torch.ones([32, 32], dtype=torch.float32).to(xla_device)
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).to(xla_device)

with torch.amp.autocast("xla", dtype=torch.float16):
    attn_output = multihead_attn(input, input, input, attn_mask=attn_mask, need_weights=False)
xm.mark_step()
print(attn_output[0].dtype)
print(attn_output)
```
Running the snippet fails with:

```
RuntimeError: Expected attn_mask dtype to be bool or to match query dtype, but got attn_mask.dtype: float and query.dtype: c10::Half instead.
```
Expected behavior
The MultiheadAttention module runs successfully, and the output tensor has the dtype selected by autocast (torch.float16 here).
Environment
- Reproducible on XLA backend [CPU/TPU/CUDA]: CPU
Additional context
Although I reproduced the bug on CPU, I believe it occurs on every kind of PJRT device except CUDA; I can also reproduce it on an Intel GPU. To fix it, we only need to register a lower-precision autocast policy for scaled_dot_product_attention, and I have verified that this works. Why isn't this registered today, and is there a known problem with doing so?
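Until such a policy is registered, the error can be worked around by hand. The sketch below reuses the variables from the repro above; the explicit `.to(torch.float16)` cast is my own addition (not a torch_xla API), and it simply aligns the additive mask's dtype with the half-precision query that autocast produces:

```python
# Workaround sketch: cast the additive float mask to the autocast dtype so it
# matches the half-precision query inside scaled_dot_product_attention.
# Assumes multihead_attn, input, and attn_mask from the repro are in scope.
with torch.amp.autocast("xla", dtype=torch.float16):
    attn_output = multihead_attn(input, input, input,
                                 attn_mask=attn_mask.to(torch.float16),
                                 need_weights=False)
xm.mark_step()
print(attn_output[0].dtype)  # torch.float16, as expected under autocast
```

A boolean mask would also pass the dtype check, but it changes the semantics from an additive bias to a keep/drop mask, so the explicit dtype cast is the less invasive workaround.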