Summary
OutOfResources: out of resource: shared memory, Required: 395520, Hardware limit: 232448. Reducing block sizes or num_stages may help.
Repro
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask, or_masks
B = 8
H = 1
S = 1024
# Note: the last dimension is the head dimension, so head_dim = S = 1024 here,
# which drives the large shared-memory requirement in the Triton kernel.
query = torch.randn(B, H, S, S, device=torch.device("cuda"), dtype=torch.float32)
key = torch.randn(B, H, S, S, device=torch.device("cuda"), dtype=torch.float32)
value = torch.randn(B, H, S, S, device=torch.device("cuda"), dtype=torch.float32)
def causal_mask(b, h, q_idx, kv_idx):
    # Standard causal masking: a query attends to keys at or before its position.
    return q_idx >= kv_idx

prefix_length = torch.arange(B, device=torch.device("cuda"), dtype=torch.int32)

def prefix_mask(b, h, q_idx, kv_idx):
    return kv_idx <= prefix_length[b]
prefix_lm_causal = or_masks(prefix_mask, causal_mask)
# In this case, as our mask is different per sequence, we set B equal to our batch size
block_mask = create_block_mask(prefix_lm_causal, B, H, S, S)
flex_compiled = torch.compile(flex_attention)
out = flex_compiled(query, key, value, block_mask=block_mask)