🐛 Describe the bug
import torch

@torch.compile()
def f(a, b):
    return (a + b).clip(0)

a = torch.randn(1024, 1024)
b = torch.randn(1024, 1024)

with torch.no_grad():
    f(a, b)
a.requires_grad_(True)
f(a, b)
Running with TORCH_LOGS="recompiles" shows:
V0727 08:08:44.508000 30548 torch/_dynamo/guards.py:2688] [0/1] [__recompiles] Recompiling function f in /teamspace/studios/this_studio/debug.py:3
V0727 08:08:44.508000 30548 torch/_dynamo/guards.py:2688] [0/1] [__recompiles] triggered by the following guard failure(s):
V0727 08:08:44.508000 30548 torch/_dynamo/guards.py:2688] [0/1] [__recompiles] - 0/0: tensor 'L['a']' requires_grad mismatch. expected requires_grad=0
Putting torch.no_grad() inside the compiled function also yields the same result. Although this is not a bug per se, minimizing recompiles would be beneficial, especially to avoid hitting the recompile limit. From my understanding, there should not be any correctness problem if we skip the recompile, since .requires_grad is irrelevant inside the torch.no_grad() context.
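For reference, a minimal sketch of the variant with torch.no_grad() inside the compiled function (same shapes and call sequence as above); as noted, it behaves the same way and recompiles on the requires_grad change:

import torch

@torch.compile()
def f(a, b):
    # no_grad is now inside the compiled region, so the output never
    # requires grad regardless of the inputs' requires_grad flags
    with torch.no_grad():
        return (a + b).clip(0)

a = torch.randn(1024, 1024)
b = torch.randn(1024, 1024)

f(a, b)                  # first compilation, requires_grad=False inputs
a.requires_grad_(True)
f(a, b)                  # still recompiles on the requires_grad guard for 'a'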
Versions
Latest torch nightly 2.5.0.dev20240726
cc @ezyang @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames