@@ -3930,9 +3930,26 @@ def check(fast_mode):
39303930 x = x .expand ((2 , 2 ))
39313931 with self .assertRaisesRegex (RuntimeError , 'The 0th input has a dimension with stride 0' ):
39323932 gradcheck (lambda x : x , (x ,), raise_exception = False , fast_mode = fast_mode )
3933+
39333934 check (fast_mode = True )
39343935 check (fast_mode = False )
39353936
@unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
def test_gradcheck_validates_input_mkldnn(self):
    """gradcheck must refuse forward-mode AD when given MKLDNN inputs.

    The loose atol/rtol keep the numeric gradient comparison meaningful in
    single precision, and assertWarnsRegex swallows the expected
    "not a double precision" warning that float32 inputs trigger.
    """
    x = torch.ones(1).to_mkldnn().requires_grad_()
    # Same expectation for both the slow and the fast gradcheck paths.
    for use_fast in (False, True):
        with self.assertWarnsRegex(UserWarning, "Input #0 requires gradient and is not a double precision"):
            with self.assertRaisesRegex(ValueError, 'MKLDNN inputs are not support for forward AD gradcheck.'):
                gradcheck(lambda x: x.to_dense(), (x,), raise_exception=False, fast_mode=use_fast,
                          check_forward_ad=True, atol=1e-1, rtol=1e-1)
39363953 @unittest .skipIf (not torch ._C .has_mkldnn , "MKL-DNN build is disabled" )
39373954 def test_gradcheck_test_outputs (self ):
39383955 def check (fast_mode ):
@@ -4223,6 +4240,49 @@ def fn2(x):
42234240 check (fast_mode = True )
42244241 check (fast_mode = False )
42254242
def test_gradcheck_forward_ad(self):
    """Forward-mode gradcheck passes on a correct function and raises on one
    whose forward gradient is deliberately wrong, across real/complex inputs
    and both fast and slow modes."""
    def fn(x, y):
        return x + y, y

    def bad_fn(x, y):
        # Hacky way to check if we're currently inside a forward ad level
        inside_forward_ad = fwAD._current_level >= 0
        if inside_forward_ad:
            # Scale the tangent so the forward-mode Jacobian no longer
            # matches the one gradcheck computes numerically.
            primal, tangent = fwAD.unpack_dual(y)
            y = fwAD.make_dual(primal, tangent * 1.1)
        return x + y, y

    err_msg = "Jacobian computed with forward mode mismatch for output 0 with respect to input 1"

    for fast_mode in [True, False]:
        def check_pass_then_fail(x, y):
            # fn must pass; bad_fn must be caught with the mismatch error.
            gradcheck(fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)
            with self.assertRaisesRegex(RuntimeError, err_msg):
                gradcheck(bad_fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)

        # All inputs and outputs real
        x = torch.rand(2, dtype=torch.double, requires_grad=True)
        y = torch.rand(2, dtype=torch.double, requires_grad=True)
        check_pass_then_fail(x, y)

        def basic_mul(x):
            return torch.view_as_real(x * 1j)
        gradcheck(basic_mul, x, check_forward_ad=True, fast_mode=fast_mode)

        # One input and one output complex
        x = torch.rand(2, dtype=torch.cdouble, requires_grad=True)
        check_pass_then_fail(x, y)

        # All inputs and outputs complex
        y = torch.rand(2, dtype=torch.cdouble, requires_grad=True)
        check_pass_then_fail(x, y)
42264286 def test_version_counter (self ):
42274287 x = torch .randn (1 , 2 )
42284288
0 commit comments