
Shared memory out of resource when using flex attention #133254

@TechxGenus


🐛 Describe the bug

When I run flex attention on a single RTX 4090, the backward pass fails with a shared-memory `OutOfResources` error.
A minimal repro:

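```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Compile flex_attention with static shapes; the failure only occurs with torch.compile.
flex_attention = torch.compile(flex_attention, dynamic=False)

# Batch size, number of heads, sequence length, head dimension.
# D=128 fails in the backward pass; D=64 works.
B = 1
H = 16
S = 2048
D = 128

query = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
key = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
value = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
# Upstream gradient fed into the backward pass.
gradOut = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)

flex_out = flex_attention(query, key, value)  # forward pass succeeds
flex_out.backward(gradOut, retain_graph=True)  # raises OutOfResources (shared memory)
```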

The same script works with D=64, or without `torch.compile`.
(I have seen #132075, but this failure happens during the backward pass.)
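Since the repro runs fine without `torch.compile`, one possible stopgap, sketched here under assumptions, is to try the compiled path and fall back to eager flex attention if it raises. `run_with_fallback` is a hypothetical helper, not a PyTorch API. Which exception type actually surfaces from the compiled backward (the raw `triton.runtime.errors.OutOfResources` shown in the log, or a wrapper added by `torch.compile`) is an assumption to verify, so the caller passes `exc_types` explicitly.

```python
from typing import Callable, Tuple, Type, TypeVar

T = TypeVar("T")


def run_with_fallback(
    primary: Callable[[], T],
    fallback: Callable[[], T],
    exc_types: Tuple[Type[BaseException], ...],
) -> Tuple[T, bool]:
    """Hypothetical helper: run `primary`, and if it raises one of `exc_types`,
    run `fallback` instead. Returns (result, used_fallback)."""
    try:
        return primary(), False
    except exc_types:
        # Only the listed exception types trigger the fallback; anything else propagates.
        return fallback(), True
```

In the repro, `primary` would be a closure running the compiled forward and `backward`, and `fallback` the same steps with the uncompiled `flex_attention`. This assumes the failed backward did not write partial gradients into the leaf tensors before raising.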

Error logs

```
triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 131074, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.
```
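The numbers in the log give a rough target: whatever change is made to block sizes or `num_stages`, the kernel's shared-memory footprint must shrink by at least the excess over the hardware limit. This minimal sketch (the `smem_shortfall` helper is hypothetical) only parses the message above and does that arithmetic; it does not model how Inductor chooses block sizes or `num_stages`.

```python
import re

# Error text copied from the log above.
MSG = (
    "out of resource: shared memory, Required: 131074, "
    "Hardware limit: 101376. Reducing block sizes or `num_stages` may help."
)


def smem_shortfall(msg: str) -> tuple[int, int, int, float]:
    """Hypothetical helper: return (required, limit, excess_bytes, min_reduction)."""
    m = re.search(r"Required: (\d+), Hardware limit: (\d+)", msg)
    if m is None:
        raise ValueError("message does not contain shared-memory figures")
    required, limit = int(m.group(1)), int(m.group(2))
    excess = max(required - limit, 0)
    # Minimum fraction by which the footprint must shrink to fit under the limit.
    min_reduction = excess / required
    return required, limit, excess, min_reduction


required, limit, excess, min_reduction = smem_shortfall(MSG)
print(f"required={required} B, limit={limit} B, over by {excess} B ({min_reduction:.1%})")
```

Running it prints `required=131074 B, limit=101376 B, over by 29698 B (22.7%)`, so the failing backward kernel would need at least 22.7% less shared memory to fit on this GPU.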

Versions

torch==2.5.0.dev20240810+cu121

cc @ezyang @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @drisspg
