Supporting Different head dims in FlexAttention #133674
Closed
Labels
module: flex attention, oncall: pt2, triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Description
🐛 Describe the bug
I am running a FlexAttention operation and it returns different output shapes with and without torch.compile. The correct output shape is the one returned without compile.
import torch
from torch.nn.attention.flex_attention import flex_attention
from torch.nn.attention.flex_attention import create_block_mask

B = 1
H = 8
S = 256
D = 128
D_L = 512

query = torch.randn(
    B, H, S, D, device="cuda", dtype=torch.bfloat16, requires_grad=False
)
key = torch.randn(
    B, H, S, D, device="cuda", dtype=torch.bfloat16, requires_grad=False
)
value = torch.randn(
    B, H, S, D_L, device="cuda", dtype=torch.bfloat16, requires_grad=False
)

def mask_fn(b, h, q_idx, kv_idx):
    return q_idx == kv_idx

def noop(b, h, q_idx, kv_idx):
    return True

block_mask = create_block_mask(mask_fn, B=B, H=H, Q_LEN=S, KV_LEN=S)
attn_out = flex_attention(query, key, value, block_mask=block_mask)
print("Attn_out shape without compile", attn_out.shape)

# Compile the flex_attention function
flex_attention = torch.compile(flex_attention, dynamic=False)
block_mask = create_block_mask(mask_fn, B=B, H=H, Q_LEN=S, KV_LEN=S)
attn_out = flex_attention(query, key, value, block_mask=block_mask)
print("Attn_out shape with compile", attn_out.shape)

Output:
Attn_out shape without compile torch.Size([1, 8, 256, 512])
Attn_out shape with compile torch.Size([1, 8, 256, 128])
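The eager shape can be confirmed independently of flex_attention by writing the attention math out by hand: the query/key head dim D only enters the score matrix, while the value head dim D_L determines the output's last axis. A minimal sketch (CPU and float32 here just to keep it self-contained; the diagonal mask mirrors mask_fn above):

```python
import torch

B, H, S, D, D_L = 1, 8, 256, 128, 512
q = torch.randn(B, H, S, D)
k = torch.randn(B, H, S, D)
v = torch.randn(B, H, S, D_L)

# Scores have shape (B, H, S, S); D cancels out here.
scores = (q @ k.transpose(-2, -1)) / D**0.5

# Diagonal mask, equivalent to mask_fn's q_idx == kv_idx.
mask = torch.eye(S, dtype=torch.bool)
scores = scores.masked_fill(~mask, float("-inf"))

# Output inherits the value head dim: (B, H, S, D_L).
out = torch.softmax(scores, dim=-1) @ v
print(out.shape)
```

This prints torch.Size([1, 8, 256, 512]), matching the uncompiled flex_attention result; with only the diagonal unmasked, each softmax row collapses to a single weight of 1, so out equals v exactly.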
Versions
Versions of relevant libraries:
[pip3] numpy==1.26.3
[pip3] pytorch-triton==3.0.0+dedb7bdf33
[pip3] torch==2.5.0.dev20240814+cu121
[pip3] triton==3.0.0
[conda] numpy 1.26.3 pypi_0 pypi
[conda] pytorch-triton 3.0.0+dedb7bdf33 pypi_0 pypi
[conda] torch 2.5.0.dev20240814+cu121 pypi_0 pypi
[conda] triton 3.0.0 pypi_0 pypi
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu