11import torch
22import unittest
33from torch .testing ._internal .common_utils import TestCase , run_tests , TEST_WITH_ROCM , TEST_WITH_SLOW
4- from torch .testing ._internal .common_device_type import instantiate_device_type_tests , dtypes , skipCUDAIfRocm
4+ from torch .testing ._internal .common_device_type import \
5+ (instantiate_device_type_tests , dtypes , skipCUDAIfRocm , ops )
56from torch ._six import inf , nan
7+ from torch .testing ._internal .common_methods_invocations import foreach_unary_op_db
68
79# Includes some values such that N * N won't be a multiple of 4,
810# which should ensure we test the vectorized and non-vectorized
@@ -17,39 +19,6 @@ class TestForeach(TestCase):
1719 (torch ._foreach_div , torch ._foreach_div_ , torch .div ),
1820 ]
1921
20- unary_ops = [
21- # foreach_op, foreach_op_, torch_op, bf16, complex64/128
22- (torch ._foreach_sqrt , torch ._foreach_sqrt_ , torch .sqrt , True , True ),
23- (torch ._foreach_exp , torch ._foreach_exp_ , torch .exp , True , True ),
24- (torch ._foreach_acos , torch ._foreach_acos_ , torch .acos , False , True ),
25- (torch ._foreach_asin , torch ._foreach_asin_ , torch .asin , False , True ),
26- (torch ._foreach_atan , torch ._foreach_atan_ , torch .atan , False , True ),
27- (torch ._foreach_cos , torch ._foreach_cos_ , torch .cos , True , True ),
28- (torch ._foreach_cosh , torch ._foreach_cosh_ , torch .cosh , False , True ),
29- (torch ._foreach_log , torch ._foreach_log_ , torch .log , True , True ),
30- (torch ._foreach_log10 , torch ._foreach_log10_ , torch .log10 , True , True ),
31- (torch ._foreach_log2 , torch ._foreach_log2_ , torch .log2 , True , True ),
32- (torch ._foreach_neg , torch ._foreach_neg_ , torch .neg , True , True ),
33- (torch ._foreach_tan , torch ._foreach_tan_ , torch .tan , False , True ),
34- (torch ._foreach_tanh , torch ._foreach_tanh_ , torch .tanh , True , True ),
35- (torch ._foreach_sin , torch ._foreach_sin_ , torch .sin , False , True ),
36- (torch ._foreach_sinh , torch ._foreach_sinh_ , torch .sinh , False , True ),
37- (torch ._foreach_ceil , torch ._foreach_ceil_ , torch .ceil , False , False ),
38- (torch ._foreach_erf , torch ._foreach_erf_ , torch .erf , True , False ),
39- (torch ._foreach_erfc , torch ._foreach_erfc_ , torch .erfc , False , False ),
40- (torch ._foreach_expm1 , torch ._foreach_expm1_ , torch .expm1 , False , False ),
41- (torch ._foreach_floor , torch ._foreach_floor_ , torch .floor , False , False ),
42- (torch ._foreach_log1p , torch ._foreach_log1p_ , torch .log1p , True , False ),
43- (torch ._foreach_round , torch ._foreach_round_ , torch .round , False , False ),
44- (torch ._foreach_frac , torch ._foreach_frac_ , torch .frac , False , False ),
45- (torch ._foreach_reciprocal , torch ._foreach_reciprocal_ , torch .reciprocal , True , True ),
46- (torch ._foreach_sigmoid , torch ._foreach_sigmoid_ , torch .sigmoid , True , False ),
47- (torch ._foreach_trunc , torch ._foreach_trunc_ , torch .trunc , False , False ),
48-
49- # See test_abs
50- # (torch._foreach_abs, torch._foreach_abs_, torch.abs, True, True),
51- ]
52-
5322 def _get_test_data (self , device , dtype , N ):
5423 if dtype in [torch .bfloat16 , torch .bool , torch .float16 ]:
5524 tensors = [torch .randn (N , N , device = device ).to (dtype ) for _ in range (N )]
@@ -157,90 +126,27 @@ def _test_bin_op_list_alpha(self, device, dtype, foreach_op, foreach_op_, torch_
157126 else :
158127 self .assertEqual (tensors1 , expected )
159128
160- #
161- # Unary ops
162- #
163- @dtypes (* (torch .testing .floating_and_complex_types_and (torch .bfloat16 , torch .half )))
164- def test_unary_ops (self , device , dtype ):
165- for fe_op , fe_op_ , torch_op , support_bfloat16 , support_complex in self .unary_ops :
166- for N in N_values :
167- tensors1 = self ._get_test_data (device , dtype , N )
168- # Mimics cuda kernel dtype flow. With fp16/bf16 input, runs in fp32 and casts output back to fp16/bf16.
169- control_dtype = torch .float32 if (self .device_type == 'cuda' and
170- (dtype is torch .float16 or dtype is torch .bfloat16 )) else dtype
171-
172- if self .device_type == 'cpu' and dtype == torch .half and torch_op not in [torch .neg , torch .frac , torch .reciprocal ]:
173- with self .assertRaisesRegex (RuntimeError , r"not implemented for \'Half\'" ):
174- expected = [torch_op (tensors1 [i ]) for i in range (N )]
175-
176- with self .assertRaisesRegex (RuntimeError , r"not implemented for \'Half\'" ):
177- res = fe_op (tensors1 )
178- break
179-
180- if dtype == torch .bfloat16 and not support_bfloat16 :
181- if self .device_type == 'cuda' or torch_op in [torch .sinh , torch .cosh ]:
182- with self .assertRaisesRegex (RuntimeError , r"not implemented for \'BFloat16\'" ):
183- expected = [torch_op (tensors1 [i ]) for i in range (N )]
184-
185- with self .assertRaisesRegex (RuntimeError , r"not implemented for \'BFloat16\'" ):
186- res = fe_op (tensors1 )
187- break
188-
189- if dtype in [torch .complex64 , torch .complex128 ] and not support_complex :
190- if not (self .device_type == 'cpu' and torch_op in [torch .sigmoid ]):
191- # not using assertRaisesRegex due to different error messages
192- with self .assertRaises (RuntimeError ):
193- expected = [torch_op (tensors1 [i ]) for i in range (N )]
194-
195- with self .assertRaises (RuntimeError ):
196- res = fe_op (tensors1 )
197- break
198-
199- expected = [torch_op (tensors1 [i ].to (dtype = control_dtype )).to (dtype = dtype ) for i in range (N )]
200- res = fe_op (tensors1 )
201- if (dtype is torch .float16 or dtype is torch .bfloat16 ) and TEST_WITH_ROCM :
202- self .assertEqual (res , expected , atol = 1.e-3 , rtol = self .dtype_precisions [dtype ][0 ])
203-
204- fe_op_ (tensors1 )
205- self .assertEqual (res , tensors1 )
206- else :
207- self .assertEqual (res , expected )
208-
209- fe_op_ (tensors1 )
210- self .assertEqual (res , tensors1 )
211-
212- # Separate test for abs due to a lot of special cases
213- # Absolute value of a complex number a + bj is defined as sqrt(a^2 + b^2), i.e. a floating point
214- @dtypes (* torch .testing .get_all_dtypes ())
215- def test_abs (self , device , dtype ):
129+ @ops (foreach_unary_op_db )
130+ def test_unary (self , device , dtype , op ):
216131 for N in N_values :
217- tensors1 = self ._get_test_data (device , dtype , N )
218- # Mimics cuda kernel dtype flow. With fp16/bf16 input, runs in fp32 and casts output back to fp16/bf16.
219- control_dtype = torch .float32 if (self .device_type == 'cuda' and
220- (dtype is torch .float16 or dtype is torch .bfloat16 )) else dtype
221-
222- if dtype == torch .bool and self .device_type == 'cpu' :
223- with self .assertRaisesRegex (RuntimeError , r"not implemented for" ):
224- expected = [torch .abs (tensors1 [i ].to (dtype = control_dtype )).to (dtype = dtype ) for i in range (N )]
225- continue
226-
227- expected = [torch .abs (tensors1 [i ].to (dtype = control_dtype )).to (dtype = dtype ) for i in range (N )]
228- res = torch ._foreach_abs (tensors1 )
229- if (dtype is torch .float16 or dtype is torch .bfloat16 ) and TEST_WITH_ROCM :
230- self .assertEqual (res , expected , atol = 1.e-3 , rtol = self .dtype_precisions [dtype ][0 ])
231-
232- torch ._foreach_abs_ (tensors1 )
233- self .assertEqual (res , tensors1 )
132+ tensors = op .sample_inputs (device , dtype , N )
133+ expected = [op .ref (t ) for t in tensors ]
134+
135+ method = op .get_method ()
136+ inplace = op .get_inplace ()
137+ actual = method (tensors )
138+ self .assertEqual (expected , actual )
139+
140+ if op .safe_casts_outputs and dtype in torch .testing .integral_types_and (torch .bool ):
141+ with self .assertRaisesRegex (RuntimeError , "can't be cast to the desired output type" ):
142+ inplace (tensors )
143+ elif dtype in [torch .complex64 , torch .complex128 ] and inplace == torch ._foreach_abs_ :
144+ # Special case for abs
145+ with self .assertRaisesRegex (RuntimeError , r"In-place abs is not supported for complex tensors." ):
146+ inplace (tensors )
234147 else :
235- expected = [torch .abs (tensors1 [i ]) for i in range (N )]
236- self .assertEqual (res , expected )
237-
238- if dtype in [torch .complex64 , torch .complex128 ]:
239- with self .assertRaisesRegex (RuntimeError , r"In-place abs is not supported for complex tensors." ):
240- torch ._foreach_abs_ (tensors1 )
241- else :
242- torch ._foreach_abs_ (tensors1 )
243- self .assertEqual (res , tensors1 )
148+ inplace (tensors )
149+ self .assertEqual (tensors , actual )
244150
245151 #
246152 # Pointwise ops
@@ -294,7 +200,6 @@ def test_min_max(self, device, dtype):
294200 res_min = torch ._foreach_minimum (tensors1 , tensors2 )
295201 self .assertEqual (res_min , expected_min )
296202
297-
298203 @dtypes (* (torch .testing .get_all_fp_dtypes (include_half = True , include_bfloat16 = False )))
299204 def test_max_min_float_inf_nan (self , device , dtype ):
300205 a = [
0 commit comments