```python
def add_test(
        name,
        self_size,
        args,
        variant_name='',
        check_ad=(),  # only used in test_jit
        dim_args_idx=(),
        skipTestIf=(),
        output_process_fn=lambda x: x,
        kwargs=None):
    kwargs = kwargs if kwargs else {}
    basic_test_name = 'test_' + name
    if variant_name != '':
        basic_test_name += '_' + variant_name

    if name in separate_complex_tests and 'complex' in variant_name:
        run_only_complex = True
    else:
        run_only_complex = False

    for dtype in [torch.double, torch.cdouble]:
        for dim_perm in product([-1, 1], repeat=len(dim_args_idx)):
            test_name = basic_test_name
            new_args = [arg * dim_perm[dim_args_idx.index(i)] if i in dim_args_idx else arg for i, arg in enumerate(args)]
            test_name = basic_test_name + ''.join('_neg' + str(i) for i, idx in enumerate(dim_perm) if idx < 0)

            if dtype.is_complex:
                # TODO: remove this. this is temporary while we ramp up the complex support.
                if name in complex_list:
                    if name in separate_complex_tests and 'complex' not in variant_name:
                        continue
                    if not run_only_complex:
                        test_name = test_name + '_complex'
                else:
                    continue
            elif run_only_complex:
                continue

            new_args = tuple(new_args)

            # for-loop bodies don't define scopes, so we have to save the variables
            # we want to close over in some way
            def do_test(self, device, dtype=dtype, name=name, self_size=self_size, args=new_args, test_name=test_name,
                        output_process_fn=output_process_fn):
                def check(name):
                    is_magic_method = name[:2] == '__' and name[-2:] == '__'
                    is_inplace = name[-1] == "_" and not is_magic_method
                    self_variable = create_input((self_size,), dtype=dtype, device=device)[0][0]
                    # FixMe: run grad checks on inplace self
                    if is_inplace:
                        self_variable.requires_grad = False
                    # need to record this because methods can change the size (e.g. unsqueeze)
                    args_variable, kwargs_variable = create_input(args, requires_grad=not is_inplace,
                                                                  call_kwargs=kwargs, dtype=dtype, device=device)
                    self_tensor = deepcopy(self_variable)
                    args_tensor = deepcopy(unpack_variables(args_variable))
                    if not exclude_tensor_method(name, test_name):
                        output_variable = getattr(self_variable, name)(*args_variable, **kwargs_variable)
                        output_tensor = getattr(self_tensor, name)(*args_tensor, **kwargs_variable)
                        if not isinstance(output_tensor, torch.Tensor) and not isinstance(output_tensor, tuple):
                            if dtype.is_complex:
                                output_tensor = torch.tensor((output_tensor, ), dtype=torch.cfloat, device=device)
                            else:
                                output_tensor = torch.tensor((output_tensor, ), dtype=torch.float, device=device)
                        self.assertEqual(unpack_variables(output_variable), output_tensor)
                        # TODO: check that both have changed after adding all inplace ops

                        def fn(*inputs):
                            output = getattr(inputs[0], name)(*inputs[1:], **kwargs)
                            return output_process_fn(output)

                        if not is_inplace and name not in EXCLUDE_GRADCHECK:
                            check_batched_grad = test_name not in EXCLUDE_BATCHED_GRAD_TESTS
                            run_grad_and_gradgrad_checks(self, name, test_name, fn,
                                                         output_variable, (self_variable,) + args_variable,
                                                         check_batched_grad=check_batched_grad)

                    # functional interface tests
                    torch_fn = getattr_qualified(torch, name)
                    if torch_fn is not None and name not in EXCLUDE_FUNCTIONAL:
                        def fn(*inputs):
                            output = torch_fn(*inputs, **kwargs)
                            return output_process_fn(output)

                        f_args_variable = (self_variable,) + args_variable
                        f_args_tensor = (self_tensor,) + args_tensor
                        # could run the gradchecks again, but skip since we did it for the methods above.
                        run_gradcheck = exclude_tensor_method(name, test_name) and not is_inplace and name not in EXCLUDE_GRADCHECK
                        run_functional_checks(self, test_name, name, fn,
                                              run_gradcheck, f_args_variable, f_args_tensor)

                    # check for correct type of input and input.grad
                    if not is_inplace:
                        self_variable = create_input((self_size,), requires_grad=True, dtype=dtype)[0][0]
                        args_variable, kwargs_variable = create_input(args, requires_grad=False, call_kwargs=kwargs, dtype=dtype)
                        if hasattr(self_variable, name):
                            attribute_result = getattr(self_variable, name)
                            if callable(attribute_result):
                                output_variable = attribute_result(*args_variable, **kwargs_variable)
                            else:
                                self.assertTrue(len(args_variable) == 0)
                                self.assertTrue(len(kwargs_variable) == 0)
                                output_variable = attribute_result
                        else:
                            self_and_args_variable = (self_variable,) + args_variable
                            output_variable = torch_fn(*self_and_args_variable, **kwargs_variable)
                        if isinstance(output_variable, torch.autograd.Variable):
                            if output_variable.is_sparse:
                                rand = randn_like(output_variable.to_dense()).to_sparse()
                            else:
                                rand = randn_like(output_variable)
                            output_variable.backward(rand)
                            self.assertTrue(type(self_variable) == type(self_variable.grad))
                            self.assertTrue(self_variable.size() == self_variable.grad.size())

                        # compare grads to inplace grads
                        inplace_name = name + '_'
                        # can't broadcast inplace to left hand side
                        skip_inplace = ('broadcast_lhs' in test_name or
                                        'broadcast_all' in test_name or
                                        'atanh' in test_name or
                                        'acosh' in test_name or
                                        'asinh' in test_name or
                                        'abs_complex' in test_name or
                                        'abs_scalar_complex' in test_name)
                        if hasattr(torch.ones(1), inplace_name) and not skip_inplace:
                            output_variable = getattr(self_variable, name)(*args_variable, **kwargs_variable)
                            if not isinstance(output_variable, tuple):
                                output_variable = (output_variable,)
                            inplace_self_variable = deepcopy(self_variable)
                            inplace_self_variable_copy = tuple(i.clone() if isinstance(i, torch.Tensor) else i
                                                               for i in (inplace_self_variable,))
                            inplace_args_variable = deepcopy(args_variable)
                            inplace_args_variable_copy = tuple(i.clone() if isinstance(i, torch.Tensor) else i
                                                               for i in inplace_args_variable)

                            inplace_output_variable = (
                                getattr(inplace_self_variable_copy[0], inplace_name)(*inplace_args_variable_copy,
                                                                                     **kwargs_variable))
                            if not isinstance(inplace_output_variable, tuple):
                                inplace_output_variable = (inplace_output_variable,)
                            self.assertEqual(inplace_output_variable, output_variable)
                            # Check that gradient is the same
                            for inp_i, i in zip((inplace_self_variable,) + inplace_args_variable,
                                                (self_variable,) + args_variable):
                                if not isinstance(inp_i, torch.Tensor):
                                    assert not isinstance(i, torch.Tensor)
                                    continue
                                if inp_i.grad is not None:
                                    with torch.no_grad():
                                        inp_i.grad.zero_()
                                if i.grad is not None:
                                    with torch.no_grad():
                                        i.grad.zero_()
                            for i_o, o in zip(inplace_output_variable, output_variable):
                                if dtype.is_complex:
                                    grad = randn_like(i_o).to(torch.cdouble)
                                else:
                                    grad = randn_like(i_o).double()
                                i_o.backward(grad)
                                o.backward(grad)
                            for inp_i, i in zip((inplace_self_variable,) + inplace_args_variable,
                                                (self_variable,) + args_variable):
                                if not isinstance(inp_i, torch.Tensor):
                                    continue
                                self.assertEqual(inp_i.grad, i.grad)

                check(name)
                inplace_name = name + '_'
                # can't broadcast inplace to left hand side
                broadcast_skip_inplace = 'broadcast_lhs' in test_name or 'broadcast_all' in test_name
                # skip C -> R inplace tests
                skip_c_to_r_inplace = 'abs_complex' in test_name or 'abs_scalar_complex' in test_name
                skip_inplace = broadcast_skip_inplace or skip_c_to_r_inplace
                if hasattr(torch.ones(1), inplace_name) and not skip_inplace:
                    check(inplace_name)

            assert not hasattr(TestAutograd, test_name), 'Two tests have the same name: ' + test_name

            for skip in skipTestIf:
                do_test = skip(do_test)

            setattr(TestAutogradDeviceType, test_name, do_test)
```
Previous discussion: #755
High Priority
Operator support
- torch.linalg.adjoint()
- torch.linalg.eig (both CPU and CUDA, waiting on TH to ATen port)
- torch.complex and torch.polar
- torch.arange
- torch.triangular_solve
- torch.special
- TH to ATen ports that would greatly help unblock some of the work listed here:
  - mode: Migrate mode from the TH to Aten (CPU) #24731, Migrate _mode from the TH to Aten (CUDA) #24526
  - renorm: Migrate renorm and renorm_ from the TH to Aten (CUDA) #24616
  - put_: Migrate put_ from the TH to Aten (CPU) #24751, Migrate put_ from the TH to Aten (CUDA) #24614
  - take: Migrate take from the TH to Aten (CUDA) #24640, Migrate take from the TH to Aten (CPU) #24772
  - nonzero (CPU): Migrate nonzero from the TH to Aten (CPU) #24745
  - index_fill_: Migrate index_fill_ from the TH to Aten (CUDA) #24577, Migrate index_fill_ from the TH to Aten (CPU) #24714
  - index_copy: Migrate _index_copy_ from the TH to Aten (CUDA) #24523, Migrate _index_copy_ from the TH to Aten (CPU) #24670
  - masked_scatter
  - masked_fill: Migrate masked_fill from TH to ATen (CUDA) #49543
- Complex support for real valued torch.nn loss functions. Tasks listed on #46642 along with a sample PR for torch.nn.L1Loss.

Complex Autograd supported but untested:
- torch.dist
- torch.masked_fill
- torch.Tensor.masked_scatter_
- torch.Tensor.scatter_

Enhance testing
Migrate the complex method tests listed here:
pytorch/test/test_autograd.py
Lines 5099 to 5122 in a9e4bb5
to OpInfo based tests. After this migration is complete, remove the method_test generation logic for complex types:
pytorch/test/test_autograd.py
Lines 5131 to 5313 in a9e4bb5
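As a rough illustration only (not from the issue), an OpInfo based replacement for one of these method_test entries could look like the sketch below. The exact OpInfo constructor fields in torch.testing._internal.common_methods_invocations have changed over time, so treat every argument here as an assumption.

```python
import torch
from torch.testing._internal.common_methods_invocations import OpInfo, SampleInput

# Hypothetical sample-input generator; the signature follows the usual
# sample_inputs_func convention (op_info, device, dtype, requires_grad).
def sample_inputs_sigmoid(op_info, device, dtype, requires_grad, **kwargs):
    t = torch.randn(5, device=device, dtype=dtype, requires_grad=requires_grad)
    return [SampleInput(t)]

# Hypothetical op_db entry: listing complex dtypes is what opts the op into
# the complex test variants (field names are assumptions, not the real entry).
sigmoid_opinfo = OpInfo(
    'sigmoid',
    dtypes=(torch.float32, torch.float64, torch.complex64, torch.complex128),
    sample_inputs_func=sample_inputs_sigmoid,
)
```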
R -> C tests for complex functions (contact @anjali411 if you'd like to work on this task).

Complex Autograd support for the following ops that already support complex:
- torch.sigmoid
- torch.eig
- torch.lu
- torch.cross (also add tests for function correctness)
- torch.cumsum
- torch.cumprod
- torch.linalg.det
- torch.Tensor.index
- torch.lerp
- torch.prod
- torch.Tensor.put_
- rsub
- torch.symeig
- index_put
- torch.unfold
- torch.diag
- torch.Tensor.masked_fill_
- torch.trace
- torch.polar

Linear Algebra Ops
- torch.cholesky_inverse
- torch.chain_matmul
- torch.ger / torch.outer

Other ops:
- torch.isnan and torch.isfinite
- torch.linspace
- torch.logspace

Spectral Op Migration: #42175
complex views
- tensor.real
- tensor.imag
- torch.view_as_real
- torch.view_as_complex
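A minimal usage sketch of these view APIs (illustrative, not from the issue):

```python
import torch

z = torch.tensor([1 + 2j, 3 - 4j], dtype=torch.cfloat)
r = torch.view_as_real(z)      # float32 tensor of shape (2, 2); last dim is (real, imag)
z2 = torch.view_as_complex(r)  # inverse view; shares storage with r
print(z.real, z.imag)          # real and imaginary parts as views
```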
[Factory Functions]
- torch.complex, torch.polar #35312
- torch.rand
- torch.randn
- torch.from_numpy(complex_array)
  blocked on: Migrate set_ from the TH to Aten (CPU) #24759, Migrate set_ from the TH to Aten (CUDA) #24623
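A short illustrative sketch (not from the issue) of the complex factory functions listed above:

```python
import torch

real = torch.tensor([1.0, 2.0])
imag = torch.tensor([3.0, 4.0])
z1 = torch.complex(real, imag)                                 # tensor([1.+3.j, 2.+4.j])
z2 = torch.polar(torch.ones(2), torch.tensor([0.0, 1.5708]))   # abs * exp(i * angle)
z3 = torch.randn(2, dtype=torch.cdouble)                       # complex-valued sampling
```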
Other functions:
Complex Number Support for torch.distributed: #45760

Complex Autograd Guide:
To understand and obtain the formula for complex derivatives, check out: https://pytorch.org/docs/master/notes/autograd.html#what-are-complex-derivatives
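For quick reference, the Wirtinger derivatives the linked note is built on, for f(z) with z = x + iy (standard definitions restated here, not quoted from the issue):

```latex
\frac{\partial f}{\partial z} = \frac{1}{2}\left(\frac{\partial f}{\partial x} - i\,\frac{\partial f}{\partial y}\right),
\qquad
\frac{\partial f}{\partial \bar{z}} = \frac{1}{2}\left(\frac{\partial f}{\partial x} + i\,\frac{\partial f}{\partial y}\right)
```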
The list of operators for which complex autograd is supported and tested can be found here:
pytorch/tools/autograd/gen_variable_type.py
Line 154 in 72bc3d9
If an operator has a backward formula defined in derivatives.yaml, then to enable complex backward for that operator, it has to be added to GRADIENT_IMPLEMENTED_FOR_COMPLEX in gen_variable_type.py.
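Schematically, that allowlist is just a collection of operator-name strings; a hypothetical sketch (the real contents at the line referenced above differ):

```python
# Sketch of the allowlist in tools/autograd/gen_variable_type.py (illustrative
# entries only). The generated backward code errors out on complex inputs
# unless the operator's name appears here.
GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
    't', 'view', 'reshape', 'add', 'mul',  # illustrative existing entries
    'my_new_op',                           # hypothetical op being enabled
}
```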
To run a common_methods test for complex, add an entry here:
pytorch/test/test_autograd.py
Line 4927 in b61671c
To run a common_methods test only for complex, add an entry here:
pytorch/test/test_autograd.py
Line 4920 in b61671c
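For illustration, the two lists referenced above (their names appear in the add_test snippet at the top of this issue) hold method names; a hypothetical sketch:

```python
# Hypothetical contents; the real lists live at the test_autograd.py lines
# referenced above. Entries are the `name` field of a method_tests() entry.
complex_list = ['exp', 'mul', 'sigmoid']   # generates a '_complex' variant alongside the real-dtype test
separate_complex_tests = ['pow']           # the complex run only happens for an explicit 'complex' variant_name
```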
Autograd tasks:
Gradcheck logic for C -> C, C -> R, R -> C functions (Complex gradcheck logic #43208); see the sketch after this list.
Disable complex autograd for operators not tested for complex (Add allowlist for complex backward #45461)
scalar mul backward
CUDA complex acos fails gradcheck on Windows
Update pow_backward
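A minimal sketch of what complex gradcheck looks like from the user side (illustrative only; it assumes the ops used here already have complex backward enabled):

```python
import torch
from torch.autograd import gradcheck

z = torch.randn(4, dtype=torch.cdouble, requires_grad=True)

# C -> C function (assumed enabled for complex backward)
assert gradcheck(lambda t: t * t, (z,))

# C -> R function: a real-valued view of a complex input (assumed enabled)
assert gradcheck(lambda t: t.real.sum(), (z,))
```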
Other tasks:
Discussions:
Numpy parity:
Inconsistency between torch.abs and np.abs #33567
torch.angle is divergent from numpy.angle
torch.abs(complex) is divergent from NumPy on vectorized NaN values
c10::complex tracker: #35284 (comment)
cc @ezyang @anjali411 @dylanbespalko @mruberry