
Commit cc7a28d

Iurii Zdebskyi authored and facebook-github-bot committed
Refactor Unary Ops tests (#49712)
Summary: Pull Request resolved: #49712
Test Plan: Imported from OSS
Reviewed By: zou3519
Differential Revision: D25673712
Pulled By: izdeby
fbshipit-source-id: 4420d5d129026195097d914e410b75b144bea795
1 parent 645a3e9 commit cc7a28d

3 files changed

Lines changed: 159 additions & 118 deletions


aten/src/ATen/native/cuda/ForeachUnaryOp.cu

Lines changed: 8 additions & 0 deletions
@@ -211,6 +211,10 @@ OP_CUSTOM_FUNCTOR(floating_complex_half_bfloat16, reciprocal, Reciprocal)
 
 std::vector<Tensor> foreach_tensor_neg_cuda(TensorList tensors) {
   check_foreach_api_restrictions(tensors);
+  TORCH_CHECK(tensors[0].scalar_type() != kBool,
+              "_foreach_neg: There is a bool tensor in the passed-in TensorList. "
+              "Negation on a bool tensor is not supported. If you are trying to invert a mask, please use the `~` "
+              "or `logical_not()` operator on the individual tensors instead.");
 
   if (!can_use_fast_route(tensors)) {
     return at::native::foreach_tensor_neg_slow(tensors);

@@ -221,6 +225,10 @@ std::vector<Tensor> foreach_tensor_neg_cuda(TensorList tensors) {
 
 void foreach_tensor_neg_cuda_(TensorList tensors) {
   check_foreach_api_restrictions(tensors);
+  TORCH_CHECK(tensors[0].scalar_type() != kBool,
+              "_foreach_neg: There is a bool tensor in the passed-in TensorList. "
+              "Negation on a bool tensor is not supported. If you are trying to invert a mask, please use the `~` "
+              "or `logical_not()` operator on the individual tensors instead.");
 
   if (!can_use_fast_route(tensors)) {
     return at::native::foreach_tensor_neg_slow_(tensors);
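
Both CUDA entry points now reject bool TensorLists up front, matching what eager torch.neg already does for bool tensors. A minimal sketch of the failure and the workaround the message suggests (the mask values and the CPU fallback are illustrative only, so the snippet runs without a GPU):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    masks = [torch.tensor([True, False], device=device) for _ in range(3)]

    try:
        torch._foreach_neg(masks)  # bool TensorLists are rejected with the message above
    except RuntimeError as err:
        print(err)

    # Workaround named in the error message: invert each mask individually.
    inverted = [~m for m in masks]  # same result as [m.logical_not() for m in masks]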

test/test_foreach.py

Lines changed: 22 additions & 117 deletions
@@ -1,8 +1,10 @@
 import torch
 import unittest
 from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_ROCM, TEST_WITH_SLOW
-from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes, skipCUDAIfRocm
+from torch.testing._internal.common_device_type import \
+    (instantiate_device_type_tests, dtypes, skipCUDAIfRocm, ops)
 from torch._six import inf, nan
+from torch.testing._internal.common_methods_invocations import foreach_unary_op_db
 
 # Includes some values such that N * N won't be a multiple of 4,
 # which should ensure we test the vectorized and non-vectorized
@@ -17,39 +19,6 @@ class TestForeach(TestCase):
         (torch._foreach_div, torch._foreach_div_, torch.div),
     ]
 
-    unary_ops = [
-        # foreach_op, foreach_op_, torch_op, bf16, complex64/128
-        (torch._foreach_sqrt, torch._foreach_sqrt_, torch.sqrt, True, True),
-        (torch._foreach_exp, torch._foreach_exp_, torch.exp, True, True),
-        (torch._foreach_acos, torch._foreach_acos_, torch.acos, False, True),
-        (torch._foreach_asin, torch._foreach_asin_, torch.asin, False, True),
-        (torch._foreach_atan, torch._foreach_atan_, torch.atan, False, True),
-        (torch._foreach_cos, torch._foreach_cos_, torch.cos, True, True),
-        (torch._foreach_cosh, torch._foreach_cosh_, torch.cosh, False, True),
-        (torch._foreach_log, torch._foreach_log_, torch.log, True, True),
-        (torch._foreach_log10, torch._foreach_log10_, torch.log10, True, True),
-        (torch._foreach_log2, torch._foreach_log2_, torch.log2, True, True),
-        (torch._foreach_neg, torch._foreach_neg_, torch.neg, True, True),
-        (torch._foreach_tan, torch._foreach_tan_, torch.tan, False, True),
-        (torch._foreach_tanh, torch._foreach_tanh_, torch.tanh, True, True),
-        (torch._foreach_sin, torch._foreach_sin_, torch.sin, False, True),
-        (torch._foreach_sinh, torch._foreach_sinh_, torch.sinh, False, True),
-        (torch._foreach_ceil, torch._foreach_ceil_, torch.ceil, False, False),
-        (torch._foreach_erf, torch._foreach_erf_, torch.erf, True, False),
-        (torch._foreach_erfc, torch._foreach_erfc_, torch.erfc, False, False),
-        (torch._foreach_expm1, torch._foreach_expm1_, torch.expm1, False, False),
-        (torch._foreach_floor, torch._foreach_floor_, torch.floor, False, False),
-        (torch._foreach_log1p, torch._foreach_log1p_, torch.log1p, True, False),
-        (torch._foreach_round, torch._foreach_round_, torch.round, False, False),
-        (torch._foreach_frac, torch._foreach_frac_, torch.frac, False, False),
-        (torch._foreach_reciprocal, torch._foreach_reciprocal_, torch.reciprocal, True, True),
-        (torch._foreach_sigmoid, torch._foreach_sigmoid_, torch.sigmoid, True, False),
-        (torch._foreach_trunc, torch._foreach_trunc_, torch.trunc, False, False),
-
-        # See test_abs
-        # (torch._foreach_abs, torch._foreach_abs_, torch.abs, True, True),
-    ]
-
     def _get_test_data(self, device, dtype, N):
         if dtype in [torch.bfloat16, torch.bool, torch.float16]:
             tensors = [torch.randn(N, N, device=device).to(dtype) for _ in range(N)]
@@ -157,90 +126,27 @@ def _test_bin_op_list_alpha(self, device, dtype, foreach_op, foreach_op_, torch_
         else:
             self.assertEqual(tensors1, expected)
 
-    #
-    # Unary ops
-    #
-    @dtypes(*(torch.testing.floating_and_complex_types_and(torch.bfloat16, torch.half)))
-    def test_unary_ops(self, device, dtype):
-        for fe_op, fe_op_, torch_op, support_bfloat16, support_complex in self.unary_ops:
-            for N in N_values:
-                tensors1 = self._get_test_data(device, dtype, N)
-                # Mimics cuda kernel dtype flow. With fp16/bf16 input, runs in fp32 and casts output back to fp16/bf16.
-                control_dtype = torch.float32 if (self.device_type == 'cuda' and
-                                                  (dtype is torch.float16 or dtype is torch.bfloat16)) else dtype
-
-                if self.device_type == 'cpu' and dtype == torch.half and torch_op not in [torch.neg, torch.frac, torch.reciprocal]:
-                    with self.assertRaisesRegex(RuntimeError, r"not implemented for \'Half\'"):
-                        expected = [torch_op(tensors1[i]) for i in range(N)]
-
-                    with self.assertRaisesRegex(RuntimeError, r"not implemented for \'Half\'"):
-                        res = fe_op(tensors1)
-                    break
-
-                if dtype == torch.bfloat16 and not support_bfloat16:
-                    if self.device_type == 'cuda' or torch_op in [torch.sinh, torch.cosh]:
-                        with self.assertRaisesRegex(RuntimeError, r"not implemented for \'BFloat16\'"):
-                            expected = [torch_op(tensors1[i]) for i in range(N)]
-
-                        with self.assertRaisesRegex(RuntimeError, r"not implemented for \'BFloat16\'"):
-                            res = fe_op(tensors1)
-                        break
-
-                if dtype in [torch.complex64, torch.complex128] and not support_complex:
-                    if not (self.device_type == 'cpu' and torch_op in [torch.sigmoid]):
-                        # not using assertRaisesRegex due to different error messages
-                        with self.assertRaises(RuntimeError):
-                            expected = [torch_op(tensors1[i]) for i in range(N)]
-
-                        with self.assertRaises(RuntimeError):
-                            res = fe_op(tensors1)
-                        break
-
-                expected = [torch_op(tensors1[i].to(dtype=control_dtype)).to(dtype=dtype) for i in range(N)]
-                res = fe_op(tensors1)
-                if (dtype is torch.float16 or dtype is torch.bfloat16) and TEST_WITH_ROCM:
-                    self.assertEqual(res, expected, atol=1.e-3, rtol=self.dtype_precisions[dtype][0])
-
-                    fe_op_(tensors1)
-                    self.assertEqual(res, tensors1)
-                else:
-                    self.assertEqual(res, expected)
-
-                    fe_op_(tensors1)
-                    self.assertEqual(res, tensors1)
-
-    # Separate test for abs due to a lot of special cases
-    # Absolute value of a complex number a + bj is defined as sqrt(a^2 + b^2), i.e. a floating point
-    @dtypes(*torch.testing.get_all_dtypes())
-    def test_abs(self, device, dtype):
+    @ops(foreach_unary_op_db)
+    def test_unary(self, device, dtype, op):
         for N in N_values:
-            tensors1 = self._get_test_data(device, dtype, N)
-            # Mimics cuda kernel dtype flow. With fp16/bf16 input, runs in fp32 and casts output back to fp16/bf16.
-            control_dtype = torch.float32 if (self.device_type == 'cuda' and
-                                              (dtype is torch.float16 or dtype is torch.bfloat16)) else dtype
-
-            if dtype == torch.bool and self.device_type == 'cpu':
-                with self.assertRaisesRegex(RuntimeError, r"not implemented for"):
-                    expected = [torch.abs(tensors1[i].to(dtype=control_dtype)).to(dtype=dtype) for i in range(N)]
-                continue
-
-            expected = [torch.abs(tensors1[i].to(dtype=control_dtype)).to(dtype=dtype) for i in range(N)]
-            res = torch._foreach_abs(tensors1)
-            if (dtype is torch.float16 or dtype is torch.bfloat16) and TEST_WITH_ROCM:
-                self.assertEqual(res, expected, atol=1.e-3, rtol=self.dtype_precisions[dtype][0])
-
-                torch._foreach_abs_(tensors1)
-                self.assertEqual(res, tensors1)
+            tensors = op.sample_inputs(device, dtype, N)
+            expected = [op.ref(t) for t in tensors]
+
+            method = op.get_method()
+            inplace = op.get_inplace()
+            actual = method(tensors)
+            self.assertEqual(expected, actual)
+
+            if op.safe_casts_outputs and dtype in torch.testing.integral_types_and(torch.bool):
+                with self.assertRaisesRegex(RuntimeError, "can't be cast to the desired output type"):
+                    inplace(tensors)
+            elif dtype in [torch.complex64, torch.complex128] and inplace == torch._foreach_abs_:
+                # Special case for abs
+                with self.assertRaisesRegex(RuntimeError, r"In-place abs is not supported for complex tensors."):
+                    inplace(tensors)
             else:
-                expected = [torch.abs(tensors1[i]) for i in range(N)]
-                self.assertEqual(res, expected)
-
-                if dtype in [torch.complex64, torch.complex128]:
-                    with self.assertRaisesRegex(RuntimeError, r"In-place abs is not supported for complex tensors."):
-                        torch._foreach_abs_(tensors1)
-                else:
-                    torch._foreach_abs_(tensors1)
-                    self.assertEqual(res, tensors1)
+                inplace(tensors)
+                self.assertEqual(tensors, actual)
 
     #
     # Pointwise ops
@@ -294,7 +200,6 @@ def test_min_max(self, device, dtype):
         res_min = torch._foreach_minimum(tensors1, tensors2)
         self.assertEqual(res_min, expected_min)
 
-
     @dtypes(*(torch.testing.get_all_fp_dtypes(include_half=True, include_bfloat16=False)))
     def test_max_min_float_inf_nan(self, device, dtype):
        a = [
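
The hand-maintained unary_ops table is replaced by the @ops decorator, which instantiates test_unary once per (operator, device, dtype) combination drawn from foreach_unary_op_db. A rough sketch of what a single instantiation boils down to, assuming a CPU float32 run; the lookup by name and the allclose check are illustrative, since the test harness does this wiring itself:

    import torch
    from torch.testing._internal.common_methods_invocations import foreach_unary_op_db

    # ForeachUnaryFuncInfo prefixes names with "_foreach_", so look up the full name.
    op = next(o for o in foreach_unary_op_db if o.name == "_foreach_exp")

    tensors = op.sample_inputs("cpu", torch.float32, 5)  # five 5x5 tensors
    expected = [op.ref(t) for t in tensors]              # per-tensor torch.Tensor.exp
    actual = op.get_method()(tensors)                    # torch._foreach_exp
    assert all(torch.allclose(e, a) for e, a in zip(expected, actual))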

torch/testing/_internal/common_methods_invocations.py

Lines changed: 129 additions & 1 deletion
@@ -987,6 +987,47 @@ def __init__(self,
                                              **kwargs)
         self.ref = ref
 
+def sample_inputs_foreach(self, device, dtype, N):
+    tensors = [make_tensor((N, N), device, dtype) for _ in range(N)]
+    return tensors
+
+
+def get_foreach_method_names(name):
+    # get torch inplace reference function
+    method_name = "_foreach_" + name
+    method_name_inplace = "_foreach_" + name + "_"
+
+    method = getattr(torch, method_name, None)
+    method_inplace = getattr(torch, method_name_inplace, None)
+
+    ref = getattr(torch.Tensor, name, None)
+
+    return method, method_inplace, ref
+
+class ForeachUnaryFuncInfo(OpInfo):
+    """Early version of a specialized OpInfo for foreach unary functions"""
+    def __init__(self,
+                 name,
+                 dtypes=floating_and_complex_types(),
+                 dtypesIfCPU=all_types_and_complex(),
+                 dtypesIfCUDA=floating_and_complex_types_and(torch.half),
+                 dtypesIfROCM=None,
+                 safe_casts_outputs=True,
+                 sample_inputs_func=sample_inputs_foreach,
+                 **kwargs):
+        super(ForeachUnaryFuncInfo, self).__init__("_foreach_" + name,
+                                                   dtypes=dtypes,
+                                                   dtypesIfCPU=dtypesIfCPU,
+                                                   dtypesIfCUDA=dtypesIfCUDA,
+                                                   dtypesIfROCM=dtypesIfROCM,
+                                                   safe_casts_outputs=safe_casts_outputs,
+                                                   sample_inputs_func=sample_inputs_func,
+                                                   **kwargs)
+
+        foreach_method, foreach_method_inplace, torch_ref_method = get_foreach_method_names(name)
+        self.method_variant = foreach_method
+        self.inplace_variant = foreach_method_inplace
+        self.ref = torch_ref_method
 
 class HermitianOpInfo(OpInfo):
     """Operator information for Hermitian functions
@@ -1561,7 +1602,6 @@ def _make_tensor_helper(shape, low=None, high=None):
 
     return samples
 
-
 def sample_inputs_lerp(op_info, device, dtype, requires_grad):
     def _make_tensor_helper(shape, low=None, high=None):
         return make_tensor(shape, device, dtype, low=low, high=high, requires_grad=requires_grad)
@@ -1606,6 +1646,94 @@ def _make_tensor_helper(shape, low=None, high=None):
 
     return samples
 
+foreach_unary_op_db: List[OpInfo] = [
+    ForeachUnaryFuncInfo('exp'),
+    ForeachUnaryFuncInfo('acos'),
+    ForeachUnaryFuncInfo('asin'),
+    ForeachUnaryFuncInfo('atan'),
+    ForeachUnaryFuncInfo('cos'),
+    ForeachUnaryFuncInfo('cosh'),
+    ForeachUnaryFuncInfo('log'),
+    ForeachUnaryFuncInfo('log10'),
+    ForeachUnaryFuncInfo('log2'),
+    ForeachUnaryFuncInfo('tan'),
+    ForeachUnaryFuncInfo('tanh'),
+    ForeachUnaryFuncInfo('sin'),
+    ForeachUnaryFuncInfo('sinh'),
+
+    ForeachUnaryFuncInfo('neg',
+                         dtypes=all_types_and_complex(),
+                         dtypesIfCPU=all_types_and_complex(),
+                         dtypesIfCUDA=all_types_and_complex(),
+                         sample_inputs_func=sample_inputs_foreach,
+                         safe_casts_outputs=False),
+
+    ForeachUnaryFuncInfo('sqrt',
+                         dtypes=floating_types(),
+                         dtypesIfCPU=floating_and_complex_types_and(torch.bfloat16),
+                         dtypesIfCUDA=floating_and_complex_types_and(torch.half)),
+
+    ForeachUnaryFuncInfo('ceil',
+                         dtypes=floating_types(),
+                         dtypesIfCPU=floating_types_and(torch.bfloat16),
+                         dtypesIfCUDA=floating_types_and(torch.half)),
+
+    ForeachUnaryFuncInfo('erf',
+                         dtypes=floating_types(),
+                         dtypesIfCPU=floating_types_and(torch.bfloat16),
+                         dtypesIfCUDA=floating_types_and(torch.half)),
+
+    ForeachUnaryFuncInfo('erfc',
+                         dtypes=floating_types(),
+                         dtypesIfCPU=floating_types_and(torch.bfloat16),
+                         dtypesIfCUDA=floating_types_and(torch.half)),
+
+    ForeachUnaryFuncInfo('expm1',
+                         dtypes=floating_types(),
+                         dtypesIfCPU=floating_types_and(torch.bfloat16),
+                         dtypesIfCUDA=floating_types_and(torch.half)),
+
+    ForeachUnaryFuncInfo('floor',
+                         dtypes=floating_types(),
+                         dtypesIfCPU=floating_types_and(torch.bfloat16),
+                         dtypesIfCUDA=floating_types_and(torch.half)),
+
+    ForeachUnaryFuncInfo('log1p',
+                         dtypes=floating_types(),
+                         dtypesIfCPU=floating_types_and(torch.bfloat16),
+                         dtypesIfCUDA=floating_types_and(torch.half)),
+
+    ForeachUnaryFuncInfo('round',
+                         dtypes=floating_types(),
+                         dtypesIfCPU=floating_types_and(torch.bfloat16),
+                         dtypesIfCUDA=floating_types_and(torch.half)),
+
+    ForeachUnaryFuncInfo('frac',
+                         dtypes=floating_types(),
+                         dtypesIfCPU=floating_types_and(torch.bfloat16),
+                         dtypesIfCUDA=floating_types_and(torch.half)),
+
+    ForeachUnaryFuncInfo('reciprocal',
+                         dtypes=floating_types(),
+                         dtypesIfCPU=floating_types_and(torch.bfloat16),
+                         dtypesIfCUDA=floating_types_and(torch.half)),
+
+    ForeachUnaryFuncInfo('sigmoid',
+                         dtypes=floating_types(),
+                         dtypesIfCPU=floating_types_and(torch.bfloat16),
+                         dtypesIfCUDA=floating_types_and(torch.half)),
+
+    ForeachUnaryFuncInfo('trunc',
+                         dtypes=floating_types(),
+                         dtypesIfCPU=floating_types_and(torch.bfloat16),
+                         dtypesIfCUDA=floating_types_and(torch.half)),
+
+    ForeachUnaryFuncInfo('abs',
+                         dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool),
+                         dtypesIfCPU=all_types_and_complex_and(torch.bfloat16, torch.half),
+                         dtypesIfCUDA=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool),
+                         safe_casts_outputs=False)
+]
 
 # Operator database (sorted alphabetically)
 op_db: List[OpInfo] = [
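
These entries replace the bf16 and complex boolean columns of the deleted unary_ops table with per-backend dtype sets. A small illustration, assuming the constructor arguments above are stored on the OpInfo under the same names (dtypesIfCPU, dtypesIfCUDA):

    import torch
    from torch.testing._internal.common_methods_invocations import foreach_unary_op_db

    sqrt_info = next(op for op in foreach_unary_op_db if op.name == "_foreach_sqrt")
    print(torch.bfloat16 in sqrt_info.dtypesIfCPU)  # bf16 sqrt is exercised on CPU
    print(torch.half in sqrt_info.dtypesIfCUDA)     # fp16 sqrt is exercised on CUDA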
