@@ -4973,7 +4973,7 @@ def test_cuda(self, test_case):
49734973 # Run backwards on CPU and GPU and compare results
49744974 for _ in range (5 ):
49754975 cpu_gradOutput = cpu_output .clone ().normal_ ()
4976- gpu_gradOutput = cpu_gradOutput .type ( 'torch.cuda.FloatTensor' )
4976+ gpu_gradOutput = cpu_gradOutput .type_as ( gpu_output )
49774977 cpu_gradInput = test_case ._backward (cpu_module , cpu_input_tuple , cpu_output , cpu_gradOutput )
49784978 gpu_gradInput = test_case ._backward (gpu_module , gpu_input_tuple , gpu_output , gpu_gradOutput )
49794979 # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
@@ -5047,6 +5047,7 @@ def __init__(self, *args, **kwargs):
50475047 self .check_inplace = kwargs .get ('check_inplace' , False )
50485048 self .check_gradgrad = kwargs .get ('check_gradgrad' , True )
50495049 self .skip_double = kwargs .get ('skip_double' , False )
5050+ self .skip_half = kwargs .get ('skip_half' , False )
50505051 self .with_tf32 = kwargs .get ('with_tf32' , False )
50515052 self .tf32_precision = kwargs .get ('tf32_precision' , 0.001 )
50525053 self .test_cpu = kwargs .get ('test_cpu' , True )
@@ -5136,15 +5137,32 @@ def assert_module_parameters_are(tensor_type, device_id=None):
51365137 assert_module_parameters_are (torch .cuda .FloatTensor , 1 ) # type: ignore[attr-defined]
51375138 else :
51385139 # check that float()/double() casters work correctly
5140+ def to_type (tensor , real , complex ):
5141+ if tensor .is_complex ():
5142+ return tensor .to (complex )
5143+ elif tensor .is_floating_point ():
5144+ return tensor .to (real )
5145+ else :
5146+ return tensor
5147+
5148+ def to_half (x ):
5149+ # TODO: torch.complex32 when properly supported
5150+ return to_type (x , torch .float16 , None )
5151+
5152+ def to_single (x ):
5153+ return to_type (x , torch .float32 , torch .complex64 )
5154+
5155+ def to_double (x ):
5156+ return to_type (x , torch .float64 , torch .complex128 )
51395157
51405158 # to float
5141- input_tuple = tuple (t . float () if not isinstance ( t , torch . LongTensor ) else t for t in input_tuple )
5159+ input_tuple = tuple (to_single ( t ) for t in input_tuple )
51425160 module .float ()
51435161 module (* input_tuple )
51445162 assert_module_parameters_are (torch .FloatTensor )
51455163
51465164 # and back to double
5147- input_tuple = tuple (t . double () if not isinstance ( t , torch . LongTensor ) else t for t in input_tuple )
5165+ input_tuple = tuple (to_double ( t ) for t in input_tuple )
51485166 module .double ()
51495167 module (* input_tuple )
51505168 assert_module_parameters_are (torch .DoubleTensor )
@@ -5154,8 +5172,7 @@ def assert_module_parameters_are(tensor_type, device_id=None):
51545172 # and that float() casts parameters correctly
51555173
51565174 # to GPU0
5157- input_tuple = tuple (
5158- t .float ().cuda () if not isinstance (t , torch .LongTensor ) else t .cuda () for t in input_tuple )
5175+ input_tuple = tuple (to_single (t ).cuda () for t in input_tuple )
51595176 module .float ().cuda ()
51605177 module (* input_tuple )
51615178 assert_module_parameters_are (torch .cuda .FloatTensor , 0 ) # type: ignore[attr-defined]
@@ -5189,18 +5206,17 @@ def assert_module_parameters_are(tensor_type, device_id=None):
51895206
51905207 if not self .skip_double :
51915208 # test double()
5192- input_tuple = tuple (
5193- t .double ().cuda () if not isinstance (t , torch .LongTensor ) else t .cuda () for t in input_tuple )
5209+ input_tuple = tuple (to_double (t ).cuda () for t in input_tuple )
51945210 module .double ().cuda ()
51955211 module (* input_tuple )
51965212 assert_module_parameters_are (torch .cuda .DoubleTensor , 0 ) # type: ignore[attr-defined]
51975213
51985214 # test half()
5199- input_tuple = tuple (
5200- t . half (). cuda () if not isinstance ( t , torch . LongTensor ) else t .cuda () for t in input_tuple )
5201- module .half ().cuda ()
5202- module (* input_tuple )
5203- assert_module_parameters_are (torch .cuda .HalfTensor , 0 ) # type: ignore[attr-defined]
5215+ if not self . skip_half :
5216+ input_tuple = tuple ( to_half ( t ) .cuda () for t in input_tuple )
5217+ module .half ().cuda ()
5218+ module (* input_tuple )
5219+ assert_module_parameters_are (torch .cuda .HalfTensor , 0 ) # type: ignore[attr-defined]
52045220 torch .set_num_threads (num_threads )
52055221
52065222 def _get_target (self ):
0 commit comments