🐛 Describe the bug
When I use flex attention on a single RTX 4090, I get a shared-memory out-of-resources error during the backward pass.
A minimal repro:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

flex_attention = torch.compile(flex_attention, dynamic=False)

B = 1
H = 16
S = 2048
D = 128
query = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
key = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
value = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
gradOut = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)

flex_out = flex_attention(query, key, value)
flex_out.backward(gradOut, retain_graph=True)  # fails here
```
It works with D=64, or without `torch.compile`. (I am aware of #132075, but this error occurs during the backward pass.)
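As a possible workaround until this is fixed: `flex_attention` accepts a `kernel_options` dict that is forwarded to the generated Triton kernels, and shrinking the backward tile sizes may bring the kernel under the 4090's shared-memory limit. This is an untested sketch; the key names (`BLOCK_M1`, `BLOCK_N1`, `BLOCK_M2`, `BLOCK_N2`) are my assumption and may differ across nightlies.

```python
# Untested sketch: request smaller backward tiles via kernel_options.
# The key names below are assumptions and may differ across nightlies.
kernel_options = {
    "BLOCK_M1": 32, "BLOCK_N1": 64,  # dK/dV kernel tiles (assumed keys)
    "BLOCK_M2": 64, "BLOCK_N2": 32,  # dQ kernel tiles (assumed keys)
}

try:
    import torch
    from torch.nn.attention.flex_attention import flex_attention

    if torch.cuda.is_available():
        compiled = torch.compile(flex_attention, dynamic=False)
        q, k, v = (torch.randn(1, 16, 2048, 128, device="cuda",
                               dtype=torch.bfloat16, requires_grad=True)
                   for _ in range(3))
        out = compiled(q, k, v, kernel_options=kernel_options)
        out.sum().backward()
except ImportError:
    pass  # torch not installed; the dict above only documents the idea
```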
Error logs

```
triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 131074, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.
```
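The required figure scales linearly with the head dimension, which is consistent with D=64 working. A rough back-of-the-envelope (my own arithmetic, not the exact Triton accounting; `tile_rows=512` is picked so that D=128 reproduces the reported figure up to 2 bytes of padding):

```python
# Rough arithmetic (not the exact Triton accounting) for why D=128
# overflows shared memory while D=64 fits on this GPU.
BYTES_PER_BF16 = 2
HW_LIMIT = 101376  # shared-memory limit reported by Triton on the RTX 4090

def approx_required_smem(head_dim, tile_rows=512):
    # Assumption: the backward kernel stages ~tile_rows rows of
    # head_dim-wide bf16 tile data in shared memory.
    return BYTES_PER_BF16 * head_dim * tile_rows

for d in (64, 128):
    req = approx_required_smem(d)
    print(d, req, "overflows" if req > HW_LIMIT else "fits")
# D=64 -> 65536 (fits), D=128 -> 131072 (overflows the 101376 limit)
```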
Versions

```
torch==2.5.0.dev20240810+cu121
```
cc @ezyang @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @drisspg