
Support torch.bitwise_{left/right}_shift and __rlshift__, __rrshift__ #59544

Closed

asi1024 wants to merge 11 commits into pytorch:master from asi1024:bitwise-shift

Conversation


asi1024 (Contributor) commented Jun 7, 2021

Fixes #58121

This PR implements torch.bitwise_left_shift, torch.bitwise_right_shift, and torch.Tensor.{__rlshift__, __rrshift__} for compatibility with the Python array API standard.
(cc: @mruberry, @rgommers, @emcastillo, @kmaehashi)
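As background for the reflected operators this PR adds: Python calls the right-hand operand's __rlshift__ when the left operand's __lshift__ returns NotImplemented. A minimal pure-Python sketch of that protocol (IntBox is an illustrative stand-in for a tensor-like wrapper, not PyTorch code):

```python
# Illustrative sketch of Python's reflected-operator protocol, the
# mechanism behind Tensor.__rlshift__ / Tensor.__rrshift__.
class IntBox:
    def __init__(self, value):
        self.value = value

    def __lshift__(self, other):
        # Handles IntBox(...) << scalar
        return IntBox(self.value << int(other))

    def __rlshift__(self, other):
        # Handles scalar << IntBox(...): int.__lshift__ returns
        # NotImplemented for IntBox, so Python falls back to this.
        return IntBox(int(other) << self.value)

print((IntBox(10) << 1).value)  # 20
print((3 << IntBox(2)).value)   # 12
```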

facebook-github-bot (Contributor) commented Jun 7, 2021

💊 CI failures summary and remediations

As of commit e793a3c (more details on the Dr. CI page and at hud.pytorch.org/pr/59544):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

@mruberry mruberry self-requested a review June 7, 2021 11:19
@ejguan added the triaged label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Jun 7, 2021
@rgommers added the module: python array api label (issues related to the Python Array API) Jun 7, 2021

asi1024 (Contributor, Author):

I'm trying to make the definition structured (67274f0), but the build fails at BinaryShiftOpsKernels.cu.

/home/asi1024/pytorch/aten/src/ATen/native/cuda/BinaryShiftOpsKernels.cu(57): error: no instance of constructor "at::native::<unnamed>::RegisterCUDADispatch<FnPtr, T>::RegisterCUDADispatch [with FnPtr=void (*)(at::TensorIterator &), T=at::native::lshift_stub]" matches the argument list
            argument types are: (at::native::lshift_stub, void (*)(at::TensorIterator &))

1 error detected in the compilation of "/home/asi1024/pytorch/aten/src/ATen/native/cuda/BinaryShiftOpsKernels.cu".
CMake Error at torch_cuda_generated_BinaryShiftOpsKernels.cu.o.Release.cmake:281 (message):
  Error generating file
  /home/asi1024/pytorch/build/caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/./torch_cuda_generated_BinaryShiftOpsKernels.cu.o

Could you please help me fix this error?

Contributor:

lshift_stub needs to take a TensorIteratorBase& not TensorIterator&

@asi1024 asi1024 force-pushed the bitwise-shift branch 2 times, most recently from fdf2490 to 67274f0 Compare June 9, 2021 04:25
mruberry (Collaborator) commented Jun 9, 2021:

This is really cool. We'll also need to add a few tests to test_binary_ufuncs.py:

  • compare bitwise_left_shift and bitwise_right_shift tensor x tensor, tensor x scalar, and scalar x tensor variants against NumPy's left_shift and right_shift as references
  • a sanity check that operator left shift works for tensor x tensor, tensor x scalar, and scalar x tensor, too
  • check that type promotion works as expected, with the result dtype of the operation being the same dtype as the input tensor

This is required because OpInfos don't automatically test that the operation implements the same behavior as a reference except for UnaryUfuncInfos (although we're expanding this functionality soon to more OpInfos), OpInfos don't support operators (although we'll probably add that functionality soon, too!), and OpInfos don't support type promotion testing (although we're working on adding that soon, too!).

Comment thread torch/_torch_docs.py Outdated
Collaborator:

This is interesting because the OpInfo shows that floating point types work, and testing in PyTorch also confirms that left and right shift work with floats. What about adding a check in the code that both input and other have integer types? Then this sentence would be updated to:

"Both :attr:`input` and :attr:`other` must have integer dtypes."

It can be its own paragraph after the first sentence.

Collaborator:

Another important point here, however, is that this function does not participate in type promotion. The returned dtype should be the same as input's dtype. So I guess this could say:

"Both :attr:`input` and :attr:`other` must have integer dtypes. The result will have the same dtype as :attr:`input`."

Comment thread torch/_torch_docs.py Outdated
Collaborator:

For a left shift it doesn't matter whether these are logical or arithmetic shifts (https://en.wikipedia.org/wiki/Arithmetic_shift#Equivalence_of_arithmetic_and_logical_left_shifts_and_multiplication), but for a right shift it does matter.

Is this a logical or an arithmetic shift? We should document it clearly. Note that I think the Python Array API actually underspecifies its "bitwise_right_shift". cc @rgommers

Collaborator:

It looks like the <<, >> operators, NumPy, and PyTorch are all consistent with each other:

>>> 10 << 1
20
>>> 10 >> 1
5
>>> -10 >> 1
-5
>>> -11 >> 1
-6

>>> t = torch.Tensor([10])
>>> t
tensor([10.])
>>> t << 1
tensor([20.])
>>> t >> 1
tensor([5.])
>>> t = torch.Tensor([-10])
>>> t << 1
tensor([-20.])
>>> t >> 1
tensor([-5.])

>>> def f(x, shift=1):
...     x2 = np.right_shift(x, shift)
...     print(f'Original ({x}):  {np.binary_repr(x)}')
...     print(f'Shifted ({x2}): {np.binary_repr(x2)}')
... 
... 
>>> f(10)
Original (10):  1010
Shifted (5): 101
>>> f(-10)
Original (-10):  -1010
Shifted (-5): -101

IIRC this is arithmetic shifting for negative integers. I'll open an issue to double check and clarify this.

Collaborator:

This is resolved in the specification now; it says to use arithmetic shift.
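For reference, Python's built-in right shift is already arithmetic (sign-extending), which matches the behavior the specification settled on:

```python
# Arithmetic right shift: the sign bit is extended, so negative
# values stay negative and round toward negative infinity,
# like floor division by 2**s.
print(10 >> 1)    # 5
print(-10 >> 1)   # -5
print(-11 >> 1)   # -6  (== -11 // 2, floor, not truncation toward 0)
```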

Comment thread torch/_torch_docs.py Outdated
Collaborator:

This is missing the out argument: (input, other, *, out)

mruberry (Collaborator) left a comment:

Hey @asi1024!

Thanks for this PR. I made a few comments inline. We should be very careful to match the Python Array API's specified behavior, and I suggest adding a few more tests. Also, we should be very precise about specifying the kind of shift done (logical or arithmetic), especially for the right shift where these shifts are distinct from each other.

asi1024 (Contributor, Author) commented Jun 11, 2021

@mruberry Thank you for your review!
I fixed the bitwise shift operations to prohibit non-integer arguments and added tests for compatibility with NumPy's behavior. PTAL!

Comment thread aten/src/ATen/native/BinaryOps.cpp Outdated
Collaborator:

Let's talk about this behavior and the current shift functions.

PyTorch defines the following shift functions today: lshift, ilshift, rshift, irshift. And these and the current PR have the same type promotion bug, as shown by this sample program:

t = torch.tensor((1, 2, 3))
t << 3.2
: tensor([ 8, 16, 24])

The current shift functions and these new ones also have divergent dtype support, as shown here:

t = torch.tensor((1., 2, 3))
t << 3.2
: tensor([ 9.1896, 18.3792, 27.5688])

This creates a tricky situation. Here's what I'd like to propose we do:

  • the new functions should be implemented as aliases of the existing shift functions plus the new reflective shifts (they should just call the existing functions)
  • the incorrect type promotion behavior should be preserved, because changing it would be BC-breaking

This will preserve float support, which would also require a BC-breaking change, but it will ensure that PyTorch's shift functionality is consistent.
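If the floating-point behavior shown above is read as scaling by a power of two (so x << s computes x * 2**s for float inputs, an interpretation inferred from the session, not stated in PyTorch's docs), the printed values line up exactly:

```python
# For floating-point inputs the shift behaves like multiplication by
# 2**s.  This reproduces the tensor([ 9.1896, 18.3792, 27.5688])
# output shown above for t = torch.tensor((1., 2, 3)) and t << 3.2.
for x in (1.0, 2.0, 3.0):
    print(round(x * 2 ** 3.2, 4))
```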

asi1024 (Contributor, Author):

@mruberry I agree we should support floating-point arguments to make the behavior consistent with the existing __lshift__, and I will fix the implementation!
However, __lshift__ and __ilshift__ don't have an interface for bitwise_left_shift_out, so the new functions cannot be implemented purely as aliases of the existing functions. I am therefore planning to add a bitwise_left_shift_out that calls lshift_stub. Is that OK?

Collaborator:

Absolutely, you're right. Sorry, I didn't mean to suggest that we not add any new code, just that we be sure to reuse the same behavior (and its natural extension).

Comment thread test/test_binary_ufuncs.py Outdated
Collaborator:

Let's create "test_shift" and do the following:

  • use the @dtypes decorator and select dtypes torch.uint8, torch.int64
  • enumerate pairs (torch_op, numpy_op) that define the functions to test
  • the pairs should be (torch.bitwise_left_shift, np.left_shift), (operator.lshift, operator.lshift), (torch.bitwise_right_shift, np.right_shift), and (operator.rshift, operator.rshift)
  • compare the results on a variety of test cases that use the test's device and dtype
  • don't bother testing the out or inplace variants; those are tested automatically when creating an OpInfo for torch.bitwise_left_shift and torch.bitwise_right_shift

See https://docs.python.org/3/library/operator.html for more on using the operator module.
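A self-contained sketch of that comparison pattern, with plain-Python shift semantics standing in for the NumPy reference (the case list and helper names are illustrative, not the actual test):

```python
import operator

# Sketch of the comparison-test idea: run each shift op against an
# independent reference on a handful of integer inputs.  Plain-Python
# arithmetic stands in for np.left_shift / np.right_shift here.
cases = [(10, 1), (255, 3), (7, 0), (-10, 1)]

def ref_lshift(a, b):
    return a * 2 ** b   # left shift == multiply by 2**b

def ref_rshift(a, b):
    return a // 2 ** b  # arithmetic right shift == floor-divide by 2**b

for op, ref in [(operator.lshift, ref_lshift),
                (operator.rshift, ref_rshift)]:
    for a, b in cases:
        assert op(a, b) == ref(a, b)
print("all shift cases agree")
```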

fyi @pmeier, this highlights the importance of automated binary elementwise testing. And @kshitij12345, I expect we'll increasingly be interested in supporting operators using OpInfos (operators are like aliases in that operator.lshift is effectively an alias for torch.bitwise_left_shift, but there are no method variants for operators, and the inplace operator variants are different because the inplace version of the operator is operator.ilshift, so we'll need to think of the appropriate way to model them -- maybe as totally different ops with their own OpInfos?)

FYIs aside, we're not done with testing yet, @asi1024! We also need to add a test that validates PyTorch's type promotion behavior and support for floating point inputs. For this test we can just "spot check" one input for each case we want to validate. In particular, we should check:

  • int tensor x float scalar == int tensor x trunc(float scalar) type promotion
  • float tensor x int scalar works as expected (it's OK to use a 'golden value' here)
  • float tensor x float scalar works as expected (it's OK to use a 'golden value' here)

Does that make sense, @asi1024? I know that's a lot, but we should be very careful validating the behavior here.
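The first spot check above can be exercised with plain integers, assuming the float scalar is truncated to the tensor's integer dtype before shifting (as the earlier t << 3.2 session showed):

```python
import math

# Spot check: an int shifted by a float scalar behaves as if the
# scalar were first truncated to an integer (3.2 -> 3), reproducing
# the tensor([ 8, 16, 24]) session earlier in the thread.
shift = 3.2
truncated = math.trunc(shift)               # 3
print([v << truncated for v in (1, 2, 3)])  # [8, 16, 24]
```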

asi1024 (Contributor, Author):

I updated the tests in c350273!

Comment thread torch/_tensor_docs.py Outdated
Collaborator:

These entries need to be added to tensors.rst, too, so they appear in the docs

Comment thread torch/_torch_docs.py Outdated
Collaborator:

These entries need to be added to torch.rst to appear in the docs

Comment thread torch/_torch_docs.py Outdated
Collaborator:

This will no longer be true once the current left shift implementation is used. What actually happens with floating point inputs? We should probably document the behavior using cases ("If input has an integer dtype... If input has a floating point dtype..."). We should be sure to document the odd behavior that int tensor x floating point scalar inputs have their floating point scalar values cast to the tensor's dtype.

Comment thread torch/_torch_docs.py Outdated
Collaborator:

input (Tensor or Scalar)
other (Tensor or Scalar)

mruberry (Collaborator) left a comment:

I took another look and made some inline comments. I think the two significant issues are:

  • we should implement the same behavior as PyTorch's current shift operators, think of these as the functional exposures for them
  • we need to be careful with testing to cover all the relevant cases and document PyTorch's weird support for floating dtypes and incorrect type promotion behavior

Let me know if you have any questions, @asi1024

Comment thread torch/_tensor.py Outdated
Collaborator:

Here (and in all other __r*__ functions) you should not handle the torch function protocol twice, in both the wrapper and the function itself. Moreover, the wrapper uses the slowest variant (has_torch_function); you can create a faster wrapper that uses has_torch_function_variadic.

asi1024 (Contributor, Author):

Fixed in e2df19d. Thanks!

@asi1024 asi1024 force-pushed the bitwise-shift branch 3 times, most recently from f124606 to 787cdcb Compare June 17, 2021 22:10
asi1024 (Contributor, Author) commented Jun 18, 2021

@mruberry I updated the implementation and tests! Could you take another look?

mruberry (Collaborator) left a comment:

This is awesome, @asi1024! This PR is very thorough and LGTM! Thank you for being so diligent with all the scalar x tensor and tensor x scalar cases. In the future I hope we can automatically generate those cases to simplify implementing operators.

facebook-github-bot (Contributor):

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

facebook-github-bot (Contributor):

@mruberry merged this pull request in 26cdec6.

asi1024 (Contributor, Author) commented Jun 25, 2021

@mruberry Thanks for your review!

laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 25, 2026
…ift__` (pytorch#59544)

Summary:
Fixes pytorch#58121

This PR implements `torch.bitwise_left_shift`, `torch.bitwise_right_shift`, and `torch.Tensor.{__rlshift__/__rrshift__}` for compatibility with the Python array API standard.
(cc: mruberry, rgommers, emcastillo, kmaehashi)

Pull Request resolved: pytorch#59544

Reviewed By: ngimel

Differential Revision: D29348869

Pulled By: mruberry

fbshipit-source-id: 329aee296cf890735e8a9f858bccfe87c03d06ca

Labels

cla signed, Merged, module: python array api, open source, triaged

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Support __rlshift__ and __rrshift__

8 participants