@@ -28,6 +28,7 @@ def wrapped_fn(*args, **kwargs):
 
     return wrapped_fn
 
+
 class TestTypePromotion(TestCase):
 
     # In-place operations don't promote.
@@ -80,44 +81,14 @@ def test_int_promotion(self, device):
 
     @float_double_default_dtype
     def test_float_promotion(self, device):
-        def test_promotion(dtype_float, dtype_double):
-            a = torch.ones([4, 4, 4], dtype=dtype_float, device=device)
-            b = torch.ones([4, 4, 4], dtype=dtype_double, device=device)
-            c = a + b
-            self.assertEqual(c, b + b)
-            self.assertEqual(c.dtype, dtype_double)
-            c = b + a
-            self.assertEqual(c, b + b)
-            self.assertEqual(c.dtype, dtype_double)
-        test_promotion(torch.float, torch.double)
-
-    @float_double_default_dtype
-    def test_complex_promotion(self, device):
-        def test_promotion(dtype_float, dtype_double):
-            a = torch.ones([4, 4, 4], dtype=dtype_float, device=device)
-            b = torch.ones([4, 4, 4], dtype=dtype_double, device=device)
-            c = a + b
-            self.assertEqual(c, b + b)
-            self.assertEqual(c.dtype, dtype_double)
-            c = b + a
-            self.assertEqual(c, b + b)
-            self.assertEqual(c.dtype, dtype_double)
-
-        test_promotion(torch.complex64, torch.complex128)
-
-        a = torch.randn(3, dtype=torch.complex64, device=device)
-        self.assertEqual((a * 5).dtype, torch.complex64)
-        # not a "wrapped number"
-        other = torch.tensor(5.5, dtype=torch.double, device=device)
-        self.assertEqual((a + other).dtype, torch.complex64)
-
-    @float_double_default_dtype
-    def test_complex_scalar_mult_tensor_promotion(self, device):
-        a = 1j * torch.ones(2, device=device)
-        a = a + 1j
-        b = torch.tensor([2j, 2j], device=device)
-        self.assertEqual(a, b)
-        self.assertEqual(a.dtype, b.dtype)
+        a = torch.ones([4, 4, 4], dtype=torch.float, device=device)
+        b = torch.ones([4, 4, 4], dtype=torch.double, device=device)
+        c = a + b
+        self.assertEqual(c, b + b)
+        self.assertEqual(c.dtype, torch.double)
+        c = b + a
+        self.assertEqual(c, b + b)
+        self.assertEqual(c.dtype, torch.double)
 
     @float_double_default_dtype
     def test_add_wrapped(self, device):
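
The rewritten test pins the float/double case directly instead of going through a nested helper. A minimal standalone sketch of the promotion rule it asserts, assuming any recent PyTorch build:

```python
import torch

a = torch.ones([4, 4, 4], dtype=torch.float)
b = torch.ones([4, 4, 4], dtype=torch.double)
# Mixing float and double promotes to double, regardless of operand order.
assert (a + b).dtype == torch.double
assert (b + a).dtype == torch.double
# torch.result_type exposes the same rule without materializing a result.
assert torch.result_type(a, b) == torch.double
```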
@@ -205,17 +176,12 @@ def _get_test_tensor(self, device, dtype, remove_zeros=False):
         shape = [5, 5, 5]
         if dtype == torch.bool:
             tensor = torch.randint(int(remove_zeros), 2, shape, device=device, dtype=dtype)
-        elif dtype.is_floating_point or dtype.is_complex:
+        elif dtype.is_floating_point:
             # "_th_normal_ not supported on CPUType for Half" so simpler create and convert
             tensor = torch.randn(shape, device=device)
             tensor = tensor.to(dtype)
             if remove_zeros:
-                tensor_abs = torch.abs(tensor)
-                if dtype == torch.complex64:
-                    tensor_abs = tensor_abs.to(torch.float)
-                elif dtype == torch.complex128:
-                    tensor_abs = tensor_abs.to(torch.double)
-                tensor[tensor_abs < 0.05] = 5
+                tensor[torch.abs(tensor) < 0.05] = 5
         else:
             tensor = torch.randint(-5 if dtype.is_signed else 0, 10, shape, device=device, dtype=dtype)
         if remove_zeros:
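
With complex dtypes gone, the `remove_zeros` branch handles real dtypes only, so the near-zero mask collapses into a single masked assignment. A minimal sketch of that masking, assuming default CPU tensors:

```python
import torch

tensor = torch.randn(5, 5, 5)
# Replace values too close to zero so later division-style ops stay stable.
tensor[torch.abs(tensor) < 0.05] = 5
assert (torch.abs(tensor) >= 0.05).all()
```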
@@ -228,9 +194,8 @@ def _get_test_tensor(self, device, dtype, remove_zeros=False):
     def test_many_promotions(self, device):
         # Can also include half on CPU in cases where it will be promoted to a
         # supported dtype
-        complex_dtypes = get_all_complex_dtypes()
-        dtypes1 = torch.testing.get_all_math_dtypes('cuda') + complex_dtypes
-        dtypes2 = torch.testing.get_all_math_dtypes(device) + complex_dtypes
+        dtypes1 = torch.testing.get_all_math_dtypes('cuda')
+        dtypes2 = torch.testing.get_all_math_dtypes(device)
         ops = [torch.add, torch.sub, torch.mul, torch.div, torch.rsub]
         for dt1, dt2 in itertools.product(dtypes1, dtypes2):
             for op, non_contiguous in itertools.product(ops, [True, False]):
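
For context, this test sweeps every dtype pair from the two lists against each binary op. A hedged sketch of the same idea using the public `torch.result_type` in place of the internal `torch.testing` helpers (the dtype list below is illustrative, not the one the test uses):

```python
import itertools
import torch

dtypes = [torch.float, torch.double, torch.int32, torch.int64]
ops = [torch.add, torch.sub, torch.mul]
for dt1, dt2 in itertools.product(dtypes, dtypes):
    a = torch.ones(2, dtype=dt1)
    b = torch.ones(2, dtype=dt2)
    for op in ops:
        # The computed dtype should match the documented promotion rule.
        assert op(a, b).dtype == torch.result_type(a, b)
```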