Support torch.bitwise_{left/right}_shift and __rlshift__, __rrshift__ #59544
asi1024 wants to merge 11 commits into pytorch:master
Conversation
💊 CI failures summary and remediations
As of commit e793a3c (more details on the Dr. CI page and at hud.pytorch.org/pr/59544): 💚 Looks good so far! There are no failures yet. 💚
This comment was automatically generated by Dr. CI.
Consider making it structured! https://github.com/pytorch/rfcs/blob/rfc-0005/RFC-0005-structured-kernel-definitions.md
I'm trying to make the definition structured (67274f0), but the build fails at BinaryShiftOpsKernels.cu.
```
/home/asi1024/pytorch/aten/src/ATen/native/cuda/BinaryShiftOpsKernels.cu(57): error: no instance of constructor "at::native::<unnamed>::RegisterCUDADispatch<FnPtr, T>::RegisterCUDADispatch [with FnPtr=void (*)(at::TensorIterator &), T=at::native::lshift_stub]" matches the argument list
            argument types are: (at::native::lshift_stub, void (*)(at::TensorIterator &))
1 error detected in the compilation of "/home/asi1024/pytorch/aten/src/ATen/native/cuda/BinaryShiftOpsKernels.cu".
CMake Error at torch_cuda_generated_BinaryShiftOpsKernels.cu.o.Release.cmake:281 (message):
  Error generating file
  /home/asi1024/pytorch/build/caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/./torch_cuda_generated_BinaryShiftOpsKernels.cu.o
```
Could you please help me fix this error?
`lshift_stub` needs to take a `TensorIteratorBase&`, not a `TensorIterator&`.
Force-pushed fdf2490 to 67274f0
This is really cool. We'll also need to add a few tests to test_binary_ufuncs.py:
- compare bitwise_left_shift and bitwise_right_shift tensor x tensor, tensor x scalar, and scalar x tensor variants against NumPy's left_shift and right_shift as references
- a sanity check that operator left shift works for tensor x tensor, tensor x scalar, and scalar x tensor, too
- check that type promotion works as expected, with the result dtype of the operation being the same dtype as the input tensor
This is required because OpInfos don't automatically test that an operation matches a reference implementation, except for UnaryUfuncInfos (although we're expanding this functionality to more OpInfos soon); OpInfos don't support operators (although we'll probably add that functionality soon, too!); and OpInfos don't support type promotion testing (although we're working on adding that soon, too!).
This is interesting because the OpInfo shows that floating point types work, and testing in PyTorch also shows that left and right shift work with floats. What about adding a check in the code that both input and other have integer types? Then this sentence would be updated to:
"Both :attr:`input` and :attr:`other` must have integer dtypes."
It can be its own paragraph after the first sentence.
Another important point here, however, is that this function does not participate in type promotion. The returned dtype should be the same as input's dtype. So I guess this could say:
"Both :attr:`input` and :attr:`other` must have integer dtypes. The result will have the same dtype as :attr:`input`."
For a left shift it doesn't matter whether these are logical or arithmetic shifts (https://en.wikipedia.org/wiki/Arithmetic_shift#Equivalence_of_arithmetic_and_logical_left_shifts_and_multiplication), but for a right shift it does matter.
Is this a logical or an arithmetic shift? We should document it clearly. Note that I think the Python Array API actually underspecifies its "bitwise_right_shift". cc @rgommers
The `<<` and `>>` operators, NumPy, and PyTorch all look consistent with each other:
```python
>>> 10 << 1
20
>>> 10 >> 1
5
>>> -10 >> 1
-5
>>> -11 >> 1
-6
>>> t = torch.Tensor([10])
>>> t
tensor([10.])
>>> t << 1
tensor([20.])
>>> t >> 1
tensor([5.])
>>> t = torch.Tensor([-10])
>>> t << 1
tensor([-20.])
>>> t >> 1
tensor([-5.])
>>> def f(x, shift=1):
...     x2 = np.right_shift(x, shift)
...     print(f'Original ({x}): {np.binary_repr(x)}')
...     print(f'Shifted ({x2}): {np.binary_repr(x2)}')
...
>>> f(10)
Original (10): 1010
Shifted (5): 101
>>> f(-10)
Original (-10): -1010
Shifted (-5): -101
```
IIRC this is arithmetic shifting for negative integers. I'll open an issue to double check and clarify this.
This is resolved in the specification now; it says to use an arithmetic shift.
This is missing the `out` argument: `(input, other, *, out)`
mruberry left a comment
Hey @asi1024!
Thanks for this PR. I made a few comments inline. We should be very careful to match the Python Array API's specified behavior, and I suggest adding a few more tests. Also, we should be very precise about specifying the kind of shift done (logical or arithmetic), especially for the right shift where these shifts are distinct from each other.
Force-pushed 1a94e55 to de0f6f1
@mruberry Thank you for your review!
Let's talk about this behavior and the current shift functions.
PyTorch defines the following shift functions today: lshift, ilshift, rshift, irshift. Both these functions and the ones in the current PR have the same type promotion bug, as shown by this sample program:
```python
t = torch.tensor((1, 2, 3))
t << 3.2
: tensor([ 8, 16, 24])
```
The current shift functions and these new ones also have divergent dtype support, as shown here:
```python
t = torch.tensor((1., 2, 3))
t << 3.2
: tensor([ 9.1896, 18.3792, 27.5688])
```
This creates a tricky situation. Here's what I'd like to propose we do:
- the new functions should be implemented as aliases of the existing shift functions plus the new reflective shifts (they should just call the existing functions)
- the incorrect type promotion behavior should be preserved, because changing it would be BC-breaking
This will also preserve float support (removing it would likewise be a BC-breaking change), and it will ensure that PyTorch's shift functionality is consistent.
@mruberry I agree to support floating-point type arguments to make the behavior consistent with existing __lshift__, and will fix the implementation!
However, __lshift__ and __ilshift__ don't have an interface for bitwise_left_shift_out, so the new functions cannot be implemented as aliases of the existing functions. I am therefore planning to add a bitwise_left_shift_out that calls lshift_stub. Is that OK?
Absolutely, you're right. Sorry, I didn't mean to suggest that we not add any new code, just that we be sure to reuse the same behavior (and its natural extension)
Let's create "test_shift" and do the following:
- use the @dtypes decorator and select dtypes torch.uint8, torch.int64
- enumerate pairs (torch_op, numpy_op) that define the functions to test
- the pairs should be (torch.bitwise_left_shift, np.left_shift), (operator.lshift, operator.lshift), (torch.bitwise_right_shift, np.right_shift), (operator.rshift, operator.rshift)
- compare the results on a variety of test cases that use the test's device and dtype
- don't bother testing the out or inplace variants; those are tested automatically when creating an OpInfo for torch.bitwise_left_shift and torch.bitwise_right_shift
See https://docs.python.org/3/library/operator.html for more on using the operator module.
fyi @pmeier, this highlights the importance of automated binary elementwise testing. And @kshitij12345, I expect we'll increasingly be interested in supporting operators using OpInfos. Operators are like aliases in that operator.lshift is an alias for torch.bitwise_left_shift, but there are no method variants for operators, and the inplace operator variants are different (the inplace version of the operator is operator.ilshift), so we'll need to think of the appropriate way to model them -- maybe as totally different ops with their own OpInfos?
FYIs aside, we're not done with testing yet, @asi1024! We also need to add a test that validates PyTorch's type promotion behavior and support for floating point inputs. For this test we can just "spot check" one input for each case we want to validate. In particular, we should check:
- int tensor x float scalar == int tensor x trunc(float scalar) type promotion
- float tensor x int scalar works as expected (it's OK to use a 'golden value' here)
- float tensor x float scalar works as expected (it's OK to use a 'golden value' here)
Does that make sense, @asi1024? I know that's a lot, but we should be very careful validating the behavior here.
These entries need to be added to tensors.rst, too, so they appear in the docs
These entries need to be added to torch.rst to appear in the docs
This will no longer be true once the current left shift implementation is used. What actually happens with floating point inputs? We should probably document the behavior using cases ("If input has an integer dtype... If input has a floating point dtype..."). We should be sure to document the odd behavior that int tensor x floating point scalar inputs have their floating point scalar values cast to the tensor's dtype.
input (Tensor or Scalar)
other (Tensor or Scalar)
mruberry left a comment
I took another look and made some inline comments. I think the two significant issues are:
- we should implement the same behavior as PyTorch's current shift operators; think of these as the functional exposures for them
- we need to be careful with testing to cover all the relevant cases and document PyTorch's weird support for floating dtypes and incorrect type promotion behavior
Let me know if you have any questions, @asi1024
Here (and in all other __r*__ functions) you should not handle torch function twice, in the wrapper and in the function itself. Moreover, the wrapper is using the slowest variant (has_torch_function); you can create a faster wrapper that uses has_torch_function_variadic.
Force-pushed f124606 to 787cdcb
@mruberry I updated the implementation and tests! Could you take another look?
mruberry left a comment
This is awesome, @asi1024! This PR is very thorough and LGTM! Thank you for being so diligent with all the scalar x tensor and tensor x scalar cases. In the future I hope we can automatically generate those cases to simplify implementing operators.
@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@mruberry Thanks for your review!
Support torch.bitwise_{left/right}_shift and __rlshift__, __rrshift__ (pytorch#59544)
Summary: Fixes pytorch#58121. This PR implements torch.bitwise_left_shift, torch.bitwise_right_shift, and torch.Tensor.{__rlshift__/__rrshift__} for compatibility with the Python array API standard. (cc: mruberry, rgommers, emcastillo, kmaehashi)
Pull Request resolved: pytorch#59544
Reviewed By: ngimel
Differential Revision: D29348869
Pulled By: mruberry
fbshipit-source-id: 329aee296cf890735e8a9f858bccfe87c03d06ca
Fixes #58121
This PR implements `torch.bitwise_left_shift`, `torch.bitwise_right_shift`, and `torch.Tensor.{__rlshift__/__rrshift__}` for compatibility with the Python array API standard. (cc: @mruberry, @rgommers, @emcastillo, @kmaehashi)