Today in test_ops.py we have a single test for out= behavior:
From pytorch/test/test_ops.py, line 306 at efc0906:

```python
def test_out(self, device, dtype, op):
```
This test:
- only works if the operator has a single tensor output
- only asserts that the out= variant works when passing an `empty_like` of the output
This is a good start, but we can do better.
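Paraphrased, the existing check amounts to roughly the following (a sketch, not the actual test body; `simple_out_check` is a hypothetical name):

```python
import torch

# Rough paraphrase of the existing check: call the out= variant with an
# empty_like of the expected result and compare against the non-out= result.
def simple_out_check(op, inp):
    expected = op(inp)
    out = torch.empty_like(expected)
    op(inp, out=out)
    assert torch.equal(out, expected)

simple_out_check(torch.abs, torch.tensor([-1.0, 2.0]))
```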
Conceptually, operators that support out= should work as follows (see also #41027 (comment) for additional background):
- the operator's computation is independent of the out= argument
- the result of the computation is "safe" copied to the out= argument
"Safe" copying is like regular copying except it requires that:
- the tensors are on the same device (Note: not just same device type, same device)
- when the operation has a "computation type", the "dtype kind" of the copied-to tensor must be at least the "dtype kind" of the copied-from tensor. PyTorch has four "dtype kinds": boolean, integer, float, and complex. So we do not permit computing in complex and copying to float, for example, or computing in an integer type and copying to bool.
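A minimal sketch of the "safe" casting rule, assuming the current out= semantics (the tensor values are illustrative):

```python
import torch

# Computing in complex but requesting a float out= tensor should be rejected:
# complex is a "higher" dtype kind than float.
a = torch.tensor([1 + 1j, 2 + 2j])
float_out = torch.empty(2, dtype=torch.float)
try:
    torch.mul(a, 2, out=float_out)  # complex -> float: unsafe cast
    rejected = False
except RuntimeError:
    rejected = True
print("complex -> float rejected:", rejected)

# Copying "up" a dtype kind is permitted: an integer computation may write
# into a float out= tensor.
b = torch.tensor([1, 2, 3])
upcast_out = torch.empty(3, dtype=torch.float)
torch.mul(b, 2, out=upcast_out)
print(upcast_out)  # tensor([2., 4., 6.])
```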
There is one caveat to this behavior: if the out= tensor has zero elements then it may be resized before the copy. See:
From pytorch/aten/src/ATen/native/Resize.h, line 17 at 6d6e9ab:

```cpp
CAFFE2_API void resize_output(Tensor& output, IntArrayRef shape);
```
If the out= tensor is resized then the operation may also change its strides. The device and dtype of the out= tensor may never change, however.
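The resizing caveat can be illustrated as follows (a sketch; it assumes the current warn-then-resize behavior for non-empty, wrongly-sized out= tensors):

```python
import warnings

import torch

a = torch.ones(2, 3)
b = torch.ones(2, 3)

# A zero-element out= tensor is silently resized to fit the result.
empty_out = torch.empty(0)
torch.add(a, b, out=empty_out)
print(empty_out.shape)  # torch.Size([2, 3])

# A nonzero but wrongly-sized out= tensor triggers a warning before resizing.
wrong_out = torch.empty(5)
with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    torch.add(a, b, out=wrong_out)
print(wrong_out.shape)  # torch.Size([2, 3])
print(len(w) > 0)       # a warning about the resize was recorded
```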
To test out= better, then, PyTorch needs to add multiple tests that do the following:
(updated to reflect #53259)
- work with operators that produce any number of tensor outputs
  - this can probably be detected programmatically and the `supports_tensor_out` metadata removed
- pass tensors full of NaNs as out= tensors to verify the out= tensors are not read and are completely written to
- pass tensors with no elements, tensors of the correct size, discontiguous tensors of the correct size, and tensors of the incorrect size to out= to validate the resizing logic
  - tensors with no elements should be resized, tensors of the correct size should be used, and tensors of the incorrect size should trigger a warning
- verify that functions validate out='s dtype and device properly
- verify the computation with out= is equivalent to the computation without it plus copying to the out= tensor
- pass the input tensor to out= and verify either that the input tensor is correctly modified in place or that a runtime error is thrown
- verify that functions enforce "safe" dtype casting (as described above)
- validate that inplace operations are equivalent to passing the inplace argument as the out= parameter
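One of the proposed checks, the NaN-fill test, might be sketched like this (`check_out_not_read` is a hypothetical helper, not existing test infrastructure):

```python
import torch

# Fill the out= tensor with NaN, run the op, and confirm the result matches
# the out=-free computation. If any NaN survived, or the results differed,
# the op would have read or only partially written the out= tensor.
def check_out_not_read(op, sample_input):
    expected = op(sample_input)
    out = torch.full_like(expected, float('nan'))
    result = op(sample_input, out=out)
    assert torch.equal(out, expected)  # completely overwritten
    assert result is out               # out= variants return the out tensor

check_out_not_read(torch.neg, torch.tensor([1.0, -2.0, 3.0]))
```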
To avoid overburdening the CI machinery, these tests should probably continue to run only on the first sample input and not on every sample input. However, we should also verify that each operator's "first" sample input is a nontrivial and nondegenerate input, and document with a comment that this should be the case. For example, `sample_inputs_unary` returns a 1D tensor as its first sample input:
From pytorch/torch/testing/_internal/common_methods_invocations.py, line 219 at efc0906:

```python
return (SampleInput(make_tensor((L,), device, dtype,
```
That seems correct. If it returned a scalar tensor as its first sample input, that would probably be a less interesting test case.
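The "nontrivial first sample" convention could plausibly be checked programmatically, along these lines (`first_sample_is_nontrivial` is a hypothetical helper, not existing test infrastructure):

```python
import torch

# Hypothetical guard: reject degenerate first sample inputs (scalar tensors
# or tensors with at most one element) before running the out= tests on them.
def first_sample_is_nontrivial(sample_tensor):
    return sample_tensor.dim() >= 1 and sample_tensor.numel() > 1

print(first_sample_is_nontrivial(torch.randn(20)))   # 1D tensor: nontrivial
print(first_sample_is_nontrivial(torch.tensor(3.)))  # scalar tensor: trivial
```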
cc @mruberry @VitalyFedyunin @walterddr