Skip to content

Commit 1676c7d

Browse files
anjali411facebook-github-bot
authored andcommitted
Added autograd tests, disabled jit autograd tests for complex and added a separate list for tests for complex dtype only (#38399)
Summary: Pull Request resolved: #38399 Test Plan: Imported from OSS Differential Revision: D21555941 Pulled By: anjali411 fbshipit-source-id: ea9f5a76590c5bab3df6a540617b074238bfb535
1 parent 53439be commit 1676c7d

3 files changed

Lines changed: 37 additions & 6 deletions

File tree

test/test_autograd.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4185,11 +4185,18 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks,
41854185
test_case.assertEqualTypeString(self_variable, self_variable.grad)
41864186
test_case.assertEqual(self_variable.size(), self_variable.grad.size())
41874187

4188+
# this list corresponds to ops which have separate tests defined for complex dtypes in
4189+
# common_methods_invocations.py
4190+
# test for these ops with 'complex' in variant should only run for complex and
4191+
# the tests for these ops which do not have 'complex' in variant should not run for complex
4192+
# and only run for floating point
4193+
separate_complex_tests = ['log', 'log10', 'log1p', 'log2', 'reciprocal']
4194+
41884195
# white list for complex
41894196
complex_list = ['t', 'view', 'reshape', 'reshape_as', 'view_as', 'zero_', 'clone',
41904197
'tril', 'triu', 'fill_', 'eq_', 'ne_', 'permute', 'squeeze', 'unsqueeze',
41914198
'chunk', 'split', 'split_with_sizes', 'resize', 'resize_as', 'sin', 'cos',
4192-
'__rmul__', '__rdiv__']
4199+
'__rmul__', '__rdiv__', 'transpose', 'round'] + separate_complex_tests
41934200

41944201
def add_test(
41954202
name,
@@ -4206,17 +4213,29 @@ def add_test(
42064213
if variant_name != '':
42074214
basic_test_name += '_' + variant_name
42084215

4216+
if name in separate_complex_tests and 'complex' in variant_name:
4217+
run_only_complex = True
4218+
else:
4219+
run_only_complex = False
4220+
42094221
for dtype in [torch.double, torch.cdouble]:
42104222
for dim_perm in product([-1, 1], repeat=len(dim_args_idx)):
42114223
test_name = basic_test_name
42124224
new_args = [arg * dim_perm[dim_args_idx.index(i)] if i in dim_args_idx else arg for i, arg in enumerate(args)]
42134225
test_name = basic_test_name + ''.join('_neg' + str(i) for i, idx in enumerate(dim_perm) if idx < 0)
4226+
42144227
if dtype.is_complex:
42154228
# TODO: remove this. this is temporary while we ramp up the complex support.
42164229
if name in complex_list and 'scalar' not in test_name and 'constant' not in test_name:
4217-
test_name = test_name + '_complex'
4230+
if name in separate_complex_tests and 'complex' not in variant_name:
4231+
continue
4232+
if not run_only_complex:
4233+
test_name = test_name + '_complex'
42184234
else:
42194235
continue
4236+
elif run_only_complex:
4237+
continue
4238+
42204239
new_args = tuple(new_args)
42214240

42224241
# for-loop bodies don't define scopes, so we have to save the variables
@@ -4239,8 +4258,6 @@ def check(name):
42394258
output_variable = getattr(self_variable, name)(*args_variable, **kwargs_variable)
42404259
output_tensor = getattr(self_tensor, name)(*args_tensor, **kwargs_variable)
42414260
if not isinstance(output_tensor, torch.Tensor) and not istuple(output_tensor):
4242-
# TODO: I'm not sure why we insert an outer dimension
4243-
# here, seems a bit strange
42444261
if dtype.is_complex:
42454262
output_tensor = torch.tensor((output_tensor, ), dtype=torch.cfloat, device=device)
42464263
else:

test/test_jit.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17948,6 +17948,12 @@ def add_autograd_test(
1794817948
skipTestIf=(),
1794917949
output_process_fn=lambda x: x,
1795017950
kwargs=None):
17951+
17952+
# Disable complex tests
17953+
# TODO: Add complex support for jit
17954+
if 'complex' in variant_name:
17955+
return
17956+
1795117957
basic_test_name = 'test_' + name
1795217958
if variant_name != '':
1795317959
basic_test_name += '_' + variant_name

torch/testing/_internal/common_methods_invocations.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,8 @@ def method_tests():
176176
('transpose', (1, 2, 3), (1, 2), 'dim', (True,), [0, 1]),
177177
('transpose', (), (0, 0), 'scalar', (True,)),
178178
('transpose', (1,), (0, 0), '1d', (True,)),
179-
('transpose', torch.rand(L, L), (0, 1), '2d', (True,)),
180-
('transpose', torch.rand(S, S, S), (2, 0), '3d', (True,)),
179+
('transpose', (L, L), (0, 1), '2d', (True,)),
180+
('transpose', (S, S, S), (2, 0), '3d', (True,)),
181181
('t', (1, 2), NO_ARGS, '', (True,)),
182182
('view', (S, S, S), (S * S, S), '', (True,)),
183183
('view', (S, S, S), (torch.Size([S * S, S]),), 'size', (True,)),
@@ -238,6 +238,12 @@ def method_tests():
238238
('log1p', uniform_scalar(requires_grad=True), NO_ARGS, 'scalar', (True,)),
239239
('log2', torch.rand(S, S, S) + 1e-2, NO_ARGS, '', (True,)),
240240
('log2', uniform_scalar(1e-2, requires_grad=True), NO_ARGS, 'scalar', (True,)),
241+
('log', torch.randn(S, S, S, dtype=torch.cfloat) + 1e-2, NO_ARGS, 'complex', (True,)),
242+
('log', uniform_scalar(1e-2j, requires_grad=True), NO_ARGS, 'complex_scalar', (True,)),
243+
('log10', torch.randn(S, S, S, dtype=torch.cfloat) + 1e-2, NO_ARGS, 'complex', (True,)),
244+
('log10', uniform_scalar(1e-2j, requires_grad=True), NO_ARGS, 'complex_scalar', (True,)),
245+
('log2', torch.randn(S, S, S, dtype=torch.cfloat) + 1e-2, NO_ARGS, 'complex', (True,)),
246+
('log2', uniform_scalar(1e-2j, requires_grad=True), NO_ARGS, 'complex_scalar', (True,)),
241247
('tanh', (S, S, S), NO_ARGS, '', (True,)),
242248
('tanh', (), NO_ARGS, 'scalar', (True,)),
243249
('sigmoid', (S, S, S), NO_ARGS, '', (True,)),
@@ -273,6 +279,8 @@ def method_tests():
273279
('atan2', (S, 1, S), ((S, S),), 'broadcast_all'),
274280
('reciprocal', torch.rand(S, S, S) + 0.1, NO_ARGS, '', (True,)),
275281
('reciprocal', uniform_scalar(0.1, requires_grad=True), NO_ARGS, 'scalar', (True,)),
282+
('reciprocal', torch.randn(S, S, S, dtype=torch.cdouble) + 0.1, NO_ARGS, 'complex', (True,)),
283+
('reciprocal', uniform_scalar(0.1j), NO_ARGS, 'complex_scalar', (True,)),
276284
('round', (S, S, S), NO_ARGS, '', (True,)),
277285
('round', (), NO_ARGS, 'scalar', (True,)),
278286
('sign', (S, S, S), NO_ARGS),

0 commit comments

Comments
 (0)