
torch.compile should not recompile when .requires_grad=True under torch.no_grad() context #131975

@gau-nernst

Description


🐛 Describe the bug

import torch

@torch.compile()
def f(a, b):
    return (a + b).clip(0)

a = torch.randn(1024, 1024)
b = torch.randn(1024, 1024)

with torch.no_grad():
    f(a, b)                 # first call: compiles, guarding on a.requires_grad == False
    a.requires_grad_(True)
    f(a, b)                 # second call: guard fails and f recompiles, despite no_grad()

Run with TORCH_LOGS="recompiles":

V0727 08:08:44.508000 30548 torch/_dynamo/guards.py:2688] [0/1] [__recompiles] Recompiling function f in /teamspace/studios/this_studio/debug.py:3
V0727 08:08:44.508000 30548 torch/_dynamo/guards.py:2688] [0/1] [__recompiles]     triggered by the following guard failure(s):
V0727 08:08:44.508000 30548 torch/_dynamo/guards.py:2688] [0/1] [__recompiles]     - 0/0: tensor 'L['a']' requires_grad mismatch. expected requires_grad=0

Putting torch.no_grad() inside the compiled function yields the same result (sketch below). Although this is not a bug per se, minimizing recompiles would be beneficial, especially to avoid hitting the recompile limit. From my understanding, there shouldn't be any correctness problem if we don't recompile, since .requires_grad is irrelevant inside the torch.no_grad() context.
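For reference, here is the variant with torch.no_grad() inside the compiled function, a minimal sketch that reproduces the same guard failure (run with TORCH_LOGS="recompiles" to confirm):

import torch

@torch.compile()
def f(a, b):
    # no_grad is entered inside the compiled region this time
    with torch.no_grad():
        return (a + b).clip(0)

a = torch.randn(1024, 1024)
b = torch.randn(1024, 1024)

f(a, b)
a.requires_grad_(True)
f(a, b)  # still recompiles on the requires_grad guard

A possible stopgap, if the recompile limit becomes a problem, is raising torch._dynamo.config.cache_size_limit, but that only postpones hitting the limit rather than avoiding the redundant compile.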

Versions

Latest torch nightly 2.5.0.dev20240726

cc @ezyang @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames
