Supporting Different head dims in FlexAttention #133674

@kiddyboots216

Description

🐛 Describe the bug

I am running a FlexAttention operation, and it returns different output shapes with and without torch.compile. The shape returned without compile is the correct one: the output head dimension should follow value (D_L = 512), but the compiled version instead returns the query head dimension (D = 128).

import torch

from torch.nn.attention.flex_attention import flex_attention
from torch.nn.attention.flex_attention import create_block_mask

B = 1
H = 8
S = 256
D = 128
D_L = 512

query = torch.randn(
    B, H, S, D, device="cuda", dtype=torch.bfloat16, requires_grad=False
)
key = torch.randn(
    B, H, S, D, device="cuda", dtype=torch.bfloat16, requires_grad=False
)
value = torch.randn(
    B, H, S, D_L, device="cuda", dtype=torch.bfloat16, requires_grad=False
)

def mask_fn(b, h, q_idx, kv_idx):
    return q_idx == kv_idx

# Unused in this repro; kept from the original script.
def noop(b, h, q_idx, kv_idx):
    return True

block_mask = create_block_mask(mask_fn, B=B, H=H, Q_LEN=S, KV_LEN=S)
attn_out = flex_attention(query, key, value, block_mask=block_mask)
print("Attn_out shape without compile", attn_out.shape)

# Compile the flex_attention function
flex_attention = torch.compile(flex_attention, dynamic=False)

block_mask = create_block_mask(mask_fn, B=B, H=H, Q_LEN=S, KV_LEN=S)
attn_out = flex_attention(query, key, value, block_mask=block_mask)
print("Attn_out shape with compile", attn_out.shape)

Output:

Attn_out shape without compile torch.Size([1, 8, 256, 512])
Attn_out shape with compile torch.Size([1, 8, 256, 128])
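For reference, computing the same attention by hand in eager mode (a minimal sketch on CPU in float32, using the same diagonal mask as mask_fn above) shows that the output head dimension should follow value, i.e. D_L = 512, matching the uncompiled result:

```python
import torch

B, H, S, D, D_L = 1, 8, 256, 128, 512
q = torch.randn(B, H, S, D)
k = torch.randn(B, H, S, D)
v = torch.randn(B, H, S, D_L)

# Scaled dot-product scores: [B, H, S, S]
scores = (q @ k.transpose(-2, -1)) / (D ** 0.5)

# Diagonal mask equivalent to mask_fn (q_idx == kv_idx)
mask = torch.eye(S, dtype=torch.bool)
scores = scores.masked_fill(~mask, float("-inf"))

# Output inherits value's head dim: [B, H, S, D_L]
out = scores.softmax(dim=-1) @ v
print(out.shape)  # torch.Size([1, 8, 256, 512])

# With a diagonal mask, each query attends only to its own
# position with weight 1.0, so the output equals v exactly.
```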

Versions

Versions of relevant libraries:
[pip3] numpy==1.26.3
[pip3] pytorch-triton==3.0.0+dedb7bdf33
[pip3] torch==2.5.0.dev20240814+cu121
[pip3] triton==3.0.0
[conda] numpy 1.26.3 pypi_0 pypi
[conda] pytorch-triton 3.0.0+dedb7bdf33 pypi_0 pypi
[conda] torch 2.5.0.dev20240814+cu121 pypi_0 pypi
[conda] triton 3.0.0 pypi_0 pypi

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu

Metadata

Labels

module: flex attention
oncall: pt2
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
