Skip to content

Commit 47f0bda

Browse files
peterbell10 authored and facebook-github-bot committed
Improve complex support in common_nn test machinery (#50593)
Summary: Pull Request resolved: #50593 There are no equivalent to torch.FloatTensor, torch.cuda.FloatTensor for complex types. So `get_gpu_type` and `get_cpu_type` are broken for complex tensors. Also found a few places that explicitly cast inputs to floating point types, which would drop the imaginary component before running the test. Test Plan: Imported from OSS Reviewed By: ngimel Differential Revision: D25954050 Pulled By: mruberry fbshipit-source-id: 1fa8e5af233aa095c839d5e2f860564baaf92aef
1 parent 9ac30d9 commit 47f0bda

3 files changed

Lines changed: 42 additions & 14 deletions

File tree

test/test_jit.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15885,8 +15885,14 @@ def create_nn_module(*args, **kwargs):
1588515885
return module(*args)
1588615886

1588715887
# Set up inputs from tuple of sizes or constructor fn
15888+
dtype = torch.double
1588815889
if 'input_fn' in kwargs:
1588915890
input = kwargs['input_fn']()
15891+
if isinstance(input, Tensor):
15892+
input = (input,)
15893+
15894+
if all(tensor.is_complex() for tensor in input):
15895+
dtype = torch.cdouble
1589015896
else:
1589115897
input = (kwargs['input_size'],)
1589215898

@@ -15903,7 +15909,7 @@ def create_nn_module(*args, **kwargs):
1590315909
if 'extra_args' in kwargs:
1590415910
input = input + kwargs['extra_args']
1590515911

15906-
args_variable, kwargs_variable = create_input(input)
15912+
args_variable, kwargs_variable = create_input(input, dtype=dtype)
1590715913
f_args_variable = deepcopy(unpack_variables(args_variable))
1590815914

1590915915
# Check against Python module as reference

torch/testing/_internal/common_nn.py

Lines changed: 28 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -4973,7 +4973,7 @@ def test_cuda(self, test_case):
49734973
# Run backwards on CPU and GPU and compare results
49744974
for _ in range(5):
49754975
cpu_gradOutput = cpu_output.clone().normal_()
4976-
gpu_gradOutput = cpu_gradOutput.type('torch.cuda.FloatTensor')
4976+
gpu_gradOutput = cpu_gradOutput.type_as(gpu_output)
49774977
cpu_gradInput = test_case._backward(cpu_module, cpu_input_tuple, cpu_output, cpu_gradOutput)
49784978
gpu_gradInput = test_case._backward(gpu_module, gpu_input_tuple, gpu_output, gpu_gradOutput)
49794979
# TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
@@ -5047,6 +5047,7 @@ def __init__(self, *args, **kwargs):
50475047
self.check_inplace = kwargs.get('check_inplace', False)
50485048
self.check_gradgrad = kwargs.get('check_gradgrad', True)
50495049
self.skip_double = kwargs.get('skip_double', False)
5050+
self.skip_half = kwargs.get('skip_half', False)
50505051
self.with_tf32 = kwargs.get('with_tf32', False)
50515052
self.tf32_precision = kwargs.get('tf32_precision', 0.001)
50525053
self.test_cpu = kwargs.get('test_cpu', True)
@@ -5136,15 +5137,32 @@ def assert_module_parameters_are(tensor_type, device_id=None):
51365137
assert_module_parameters_are(torch.cuda.FloatTensor, 1) # type: ignore[attr-defined]
51375138
else:
51385139
# check that float()/double() casters work correctly
5140+
def to_type(tensor, real, complex):
5141+
if tensor.is_complex():
5142+
return tensor.to(complex)
5143+
elif tensor.is_floating_point():
5144+
return tensor.to(real)
5145+
else:
5146+
return tensor
5147+
5148+
def to_half(x):
5149+
# TODO: torch.complex32 when properly supported
5150+
return to_type(x, torch.float16, None)
5151+
5152+
def to_single(x):
5153+
return to_type(x, torch.float32, torch.complex64)
5154+
5155+
def to_double(x):
5156+
return to_type(x, torch.float64, torch.complex128)
51395157

51405158
# to float
5141-
input_tuple = tuple(t.float() if not isinstance(t, torch.LongTensor) else t for t in input_tuple)
5159+
input_tuple = tuple(to_single(t) for t in input_tuple)
51425160
module.float()
51435161
module(*input_tuple)
51445162
assert_module_parameters_are(torch.FloatTensor)
51455163

51465164
# and back to double
5147-
input_tuple = tuple(t.double() if not isinstance(t, torch.LongTensor) else t for t in input_tuple)
5165+
input_tuple = tuple(to_double(t) for t in input_tuple)
51485166
module.double()
51495167
module(*input_tuple)
51505168
assert_module_parameters_are(torch.DoubleTensor)
@@ -5154,8 +5172,7 @@ def assert_module_parameters_are(tensor_type, device_id=None):
51545172
# and that float() casts parameters correctly
51555173

51565174
# to GPU0
5157-
input_tuple = tuple(
5158-
t.float().cuda() if not isinstance(t, torch.LongTensor) else t.cuda() for t in input_tuple)
5175+
input_tuple = tuple(to_single(t).cuda() for t in input_tuple)
51595176
module.float().cuda()
51605177
module(*input_tuple)
51615178
assert_module_parameters_are(torch.cuda.FloatTensor, 0) # type: ignore[attr-defined]
@@ -5189,18 +5206,17 @@ def assert_module_parameters_are(tensor_type, device_id=None):
51895206

51905207
if not self.skip_double:
51915208
# test double()
5192-
input_tuple = tuple(
5193-
t.double().cuda() if not isinstance(t, torch.LongTensor) else t.cuda() for t in input_tuple)
5209+
input_tuple = tuple(to_double(t).cuda() for t in input_tuple)
51945210
module.double().cuda()
51955211
module(*input_tuple)
51965212
assert_module_parameters_are(torch.cuda.DoubleTensor, 0) # type: ignore[attr-defined]
51975213

51985214
# test half()
5199-
input_tuple = tuple(
5200-
t.half().cuda() if not isinstance(t, torch.LongTensor) else t.cuda() for t in input_tuple)
5201-
module.half().cuda()
5202-
module(*input_tuple)
5203-
assert_module_parameters_are(torch.cuda.HalfTensor, 0) # type: ignore[attr-defined]
5215+
if not self.skip_half:
5216+
input_tuple = tuple(to_half(t).cuda() for t in input_tuple)
5217+
module.half().cuda()
5218+
module(*input_tuple)
5219+
assert_module_parameters_are(torch.cuda.HalfTensor, 0) # type: ignore[attr-defined]
52045220
torch.set_num_threads(num_threads)
52055221

52065222
def _get_target(self):

torch/testing/_internal/jit_metaprogramming_utils.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -527,8 +527,14 @@ def try_get_nn_module_compiled_mod_and_inputs(*args, **kwargs):
527527
constructor_args = kwargs.get('constructor_args', ())
528528

529529
# Set up inputs from tuple of sizes or constructor fn
530+
input_dtype = torch.double
530531
if 'input_fn' in kwargs:
531532
input = kwargs['input_fn']()
533+
if isinstance(input, torch.Tensor):
534+
input = (input,)
535+
536+
if all(tensor.is_complex() for tensor in input):
537+
input_dtype = torch.cdouble
532538
else:
533539
input = (kwargs['input_size'],)
534540

@@ -543,7 +549,7 @@ def try_get_nn_module_compiled_mod_and_inputs(*args, **kwargs):
543549
input = (input,)
544550
input = input + (kwargs['target_fn'](),)
545551

546-
args_variable, kwargs_variable = create_input(input)
552+
args_variable, kwargs_variable = create_input(input, dtype=input_dtype)
547553
f_args_variable = deepcopy(unpack_variables(args_variable))
548554
out_var = deepcopy(f_args_variable)
549555

0 commit comments

Comments (0)