Repro
```python
from functools import partial

import torch
import torchdynamo

print = partial(print, flush=True)


def reversible(x):
    print("Hello world")  # Cause graph break so inline fails
    return torch.sin(torch.cos(x))


def fn(x):
    torch._C._set_grad_enabled(False)
    with torch.enable_grad():
        a = torch.sin(x)
        b = reversible(a)
        c = torch.sigmoid(b)
        c.sum().backward()
        return x.grad


x = torch.randn(4, requires_grad=True)
x.grad = None
ref = fn(x)
print("Eager done")

# torchdynamo.config.trace = True
# torchdynamo.config.debug = True
x.grad = None
with torchdynamo.optimize("eager"):
    res = fn(x)
print(res)
```
This fails with the following error:

```
Traceback (most recent call last):
  File "with_test.py", line 39, in <module>
    res = fn(x)
  File "with_test.py", line 17, in fn
    def fn(x):
  File "with_test.py", line 17, in fn
    def fn(x):
  File "/data/home/anijain/miniconda/envs/pytorch_dev/lib/python3.8/site-packages/torch/_tensor.py", line 399, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/data/home/anijain/miniconda/envs/pytorch_dev/lib/python3.8/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
```
Following is the bytecode:

```
ORIGINAL BYTECODE fn with_test.py 13

 14           0 LOAD_GLOBAL              0 (torch)
              2 LOAD_ATTR                1 (_C)
              4 LOAD_METHOD              2 (_set_grad_enabled)
              6 LOAD_CONST               1 (False)
              8 CALL_METHOD              1
             10 POP_TOP

 16          12 LOAD_GLOBAL              0 (torch)
             14 LOAD_METHOD              3 (enable_grad)
             16 CALL_METHOD              0
             18 SETUP_WITH              60 (to 80)
             20 POP_TOP

 17          22 LOAD_GLOBAL              0 (torch)
             24 LOAD_METHOD              4 (sin)
             26 LOAD_FAST                0 (x)
             28 CALL_METHOD              1
             30 STORE_FAST               1 (a)

 18          32 LOAD_GLOBAL              5 (reversible)
             34 LOAD_FAST                1 (a)
             36 CALL_FUNCTION            1
             38 STORE_FAST               2 (b)

 19          40 LOAD_GLOBAL              0 (torch)
             42 LOAD_METHOD              6 (sigmoid)
             44 LOAD_FAST                2 (b)
             46 CALL_METHOD              1
             48 STORE_FAST               3 (c)

 20          50 LOAD_FAST                3 (c)
             52 LOAD_METHOD              7 (sum)
             54 CALL_METHOD              0
             56 LOAD_METHOD              8 (backward)
             58 CALL_METHOD              0
             60 POP_TOP

 21          62 LOAD_FAST                0 (x)
             64 LOAD_ATTR                9 (grad)
             66 POP_BLOCK
             68 ROT_TWO
             70 BEGIN_FINALLY
             72 WITH_CLEANUP_START
             74 WITH_CLEANUP_FINISH
             76 POP_FINALLY              0
             78 RETURN_VALUE
        >>   80 WITH_CLEANUP_START
             82 WITH_CLEANUP_FINISH
             84 END_FINALLY
             86 LOAD_CONST               0 (None)
             88 RETURN_VALUE
```
```
MODIFIED BYTECODE

 13           0 LOAD_GLOBAL             11 (__compiled_fn_0)
              2 LOAD_FAST                0 (x)
              4 CALL_FUNCTION            1
              6 STORE_FAST               4 (___graph_out_0)
              8 LOAD_GLOBAL             10 (__import_torch)
             10 LOAD_ATTR                3 (enable_grad)
             12 LOAD_GLOBAL              5 (reversible)
             14 LOAD_FAST                4 (___graph_out_0)
             16 LOAD_CONST               2 (0)
             18 BINARY_SUBSCR
             20 CALL_FUNCTION            1
             22 LOAD_GLOBAL             12 (__resume_at_38_1)
             24 ROT_THREE
             26 LOAD_FAST                0 (x)
             28 CALL_FUNCTION            3
             30 RETURN_VALUE
```
In the modified bytecode, the call to the `reversible` function does not happen inside the `with` context. Therefore, `reversible` is called with the grad flag disabled, which triggers the error above.

Note that `__compiled_fn_0` and the resume function both correctly keep track of the `with` context. We are missing the `with` context only for the instruction(s) that cause the graph break (`CALL_FUNCTION` here).
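The underlying autograd behavior can be reproduced without Dynamo: an op executed while grad mode is disabled produces a tensor with no `grad_fn`, and calling `backward()` on it raises exactly this error. A minimal sketch:

```python
import torch

x = torch.randn(4, requires_grad=True)

# With grad mode disabled, ops record no autograd graph.
with torch.no_grad():
    a = torch.sin(x)
print(a.requires_grad, a.grad_fn)  # False None

# Inside enable_grad, the same op does record a grad_fn.
with torch.enable_grad():
    b = torch.sin(x)
print(b.requires_grad)  # True

# backward() on the grad-less tensor reproduces the RuntimeError above.
try:
    a.sum().backward()
except RuntimeError as e:
    print(type(e).__name__)  # RuntimeError
```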
Also note that Dynamo has special handling for `no_grad` and `enable_grad` - https://github.com/pytorch/torchdynamo/blob/main/torchdynamo/symbolic_convert.py#L421. For any other type of context manager, we just break the graph on `SETUP_WITH`. So, a quick fix (but with poorer coverage) is to always break on `SETUP_WITH`.
@jansel suggested hand-writing the Python bytecode for a `try ... finally` block and conditionally inserting it at https://github.com/pytorch/torchdynamo/blob/main/torchdynamo/symbolic_convert.py#L137:
```python
>>> import dis
>>> def foo():
...     set_grad_true()
...     try:
...         user_inst()
...     finally:
...         set_grad_false()
...
>>> dis.dis(foo)
  2           0 LOAD_GLOBAL              0 (set_grad_true)
              2 CALL_FUNCTION            0
              4 POP_TOP

  3           6 SETUP_FINALLY           16 (to 24)

  4           8 LOAD_GLOBAL              1 (user_inst)
             10 CALL_FUNCTION            0
             12 POP_TOP
             14 POP_BLOCK

  6          16 LOAD_GLOBAL              2 (set_grad_false)
             18 CALL_FUNCTION            0
             20 POP_TOP
             22 JUMP_FORWARD             8 (to 32)
        >>   24 LOAD_GLOBAL              2 (set_grad_false)
             26 CALL_FUNCTION            0
             28 POP_TOP
             30 RERAISE
        >>   32 LOAD_CONST               0 (None)
             34 RETURN_VALUE
```
My initial experience is that this is a little more tedious, as we have to set up the resume call and its arguments as well.
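At the source level, the guarantee the hand-written `try ... finally` bytecode would provide is roughly the following: run the graph-breaking call under the grad mode that was active inside the `with` block, and restore the previous mode no matter what. A sketch of that invariant (the helper name is illustrative, not part of Dynamo):

```python
import torch

def call_with_grad_mode(fn, grad_mode, *args):
    # Illustrative helper: roughly what the inserted try/finally
    # bytecode would guarantee around the graph-breaking call.
    prev = torch.is_grad_enabled()
    torch._C._set_grad_enabled(grad_mode)
    try:
        return fn(*args)
    finally:
        torch._C._set_grad_enabled(prev)

# Even with grad globally disabled at the break point, the guarded
# call still runs with grad enabled, and the old mode is restored.
torch._C._set_grad_enabled(False)
x = torch.randn(4, requires_grad=True)
out = call_with_grad_mode(torch.sin, True, x)
print(out.requires_grad)        # True
print(torch.is_grad_enabled())  # False (previous mode restored)
torch._C._set_grad_enabled(True)
```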