[FlexAttention] OutOfResources: out of resource: shared memory #132075

@drisspg

Description

Summary

OutOfResources: out of resource: shared memory, Required: 395520, Hardware limit: 232448. Reducing block sizes or num_stages may help.
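The numbers in the error are worth unpacking (a rough sketch; the 128-row tile size below is an assumed default, not taken from the generated kernel):

```python
# Figures from the error message above.
required, limit = 395_520, 232_448
print(required / 1024)  # 386.25 KiB requested
print(limit / 1024)     # 227.0 KiB available on this GPU

# The repro's tensors have head dim S = 1024 in fp32, so even a single
# query tile of 128 rows (128 is an assumed default block size) would need:
print(128 * 1024 * 4)   # 524288 bytes -- on its own already over the limit
```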

Repro

import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask, or_masks


B = 8
H = 1
S = 1024
# NB: the head dim here is also S (1024), which is what drives the large
# shared-memory requirement.
query = torch.randn(B, H, S, S, device=torch.device("cuda"), dtype=torch.float32)
key = torch.randn(B, H, S, S, device=torch.device("cuda"), dtype=torch.float32)
value = torch.randn(B, H, S, S, device=torch.device("cuda"), dtype=torch.float32)

def causal_mask(b, h, q_idx, kv_idx):
    # each query position attends to itself and earlier key positions
    return q_idx >= kv_idx

prefix_length = torch.arange(B, device=torch.device("cuda"), dtype=torch.int32)
def prefix_mask(b, h, q_idx, kv_idx):
    return kv_idx <= prefix_length[b]

prefix_lm_causal = or_masks(prefix_mask, causal_mask)
# In this case, as our mask is different per sequence, we set B equal to our batch size
block_mask = create_block_mask(prefix_lm_causal, B, H, S, S)


flex_compiled = torch.compile(flex_attention)
out = flex_compiled(query, key, value, block_mask=block_mask)
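The error message suggests reducing block sizes. `flex_attention` accepts a `kernel_options` dict for this, though the specific keys used below (`"BLOCK_M"`, `"BLOCK_N"`) are assumptions and may differ across PyTorch versions:

```python
# Assumed workaround: force smaller Triton tile sizes via kernel_options.
# The option names "BLOCK_M"/"BLOCK_N" are assumptions here; verify them
# against your PyTorch version before relying on this.
kernel_options = {"BLOCK_M": 64, "BLOCK_N": 64}

# On a CUDA build, the repro's call would become (untested sketch):
# out = flex_compiled(query, key, value, block_mask=block_mask,
#                     kernel_options=kernel_options)
print(kernel_options)
```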

Labels: triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)