@@ -9118,6 +9118,45 @@ def _inps():
91189118 self .assertEqual (ref_inps_after_fw , inps_after_fw )
91199119 self .assertEqual (ref_inps_after_bw , inps_after_bw )
91209120
9121+ def test_mutations_in_bw_requires_grad_input (self ):
9122+ # Inplace mutation of a requires_grad=True forward input during the
9123+ # backward pass should not trigger the check_inplace error
9124+ class AF (torch .autograd .Function ):
9125+ @staticmethod
9126+ def forward (ctx , dummy , inplace_tensor ):
9127+ ctx .save_for_backward (inplace_tensor )
9128+ return dummy .clone ()
9129+
9130+ @staticmethod
9131+ def backward (ctx , grad_output ):
9132+ (inplace_tensor ,) = ctx .saved_tensors
9133+ inplace_tensor .mul_ (2.0 )
9134+ return grad_output , None
9135+
9136+ def fn (dummy , inplace_tensor ):
9137+ return AF .apply (dummy , inplace_tensor )
9138+
9139+ def _inps ():
9140+ dummy = torch .zeros ((2 ,), requires_grad = True )
9141+ # requires_grad=True is what triggered the bug
9142+ inplace_tensor = torch .ones ((2 ,), requires_grad = True )
9143+ return dummy , inplace_tensor
9144+
9145+ inps = _inps ()
9146+ out = fn (* inps )
9147+ ref_inps_after_fw = [x .clone ().detach () for x in inps ]
9148+ out .sum ().backward ()
9149+ ref_inps_after_bw = [x .clone ().detach () for x in inps ]
9150+
9151+ inps = _inps ()
9152+ out = torch .compile (fn , backend = "aot_eager" , fullgraph = True )(* inps )
9153+ inps_after_fw = [x .clone ().detach () for x in inps ]
9154+ out .sum ().backward ()
9155+ inps_after_bw = [x .clone ().detach () for x in inps ]
9156+
9157+ self .assertEqual (ref_inps_after_fw , inps_after_fw )
9158+ self .assertEqual (ref_inps_after_bw , inps_after_bw )
9159+
91219160 def test_mutation_of_input_in_fw_and_bw (self ):
91229161 class AF (torch .autograd .Function ):
91239162 @staticmethod
0 commit comments