
Add basic ldexp operator for numpy compatibility #45370

Closed
ranman wants to merge 1 commit into pytorch:master from ranman:feat/ldexp

Conversation


@ranman ranman commented Sep 26, 2020

Adds ldexp operator for #38349

I'm not entirely sure the changes to NamedRegistrations.cpp were needed, but I saw other operators in there, so I added an entry.

Normally the ldexp operator is used together with frexp to construct and deconstruct floating-point values. This is useful for operating on either the mantissa or the exponent portion of a floating-point value.
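For reference, a minimal sketch of this round trip using Python's math module, which exposes the same underlying C library functions (the example value is just illustrative):

import math

x = 6.5

# frexp decomposes x into a mantissa m in [0.5, 1) and an integer exponent e
# such that x == m * 2**e.
m, e = math.frexp(x)   # m == 0.8125, e == 3

# ldexp is the inverse: it reassembles the value from mantissa and exponent.
assert math.ldexp(m, e) == x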

SLEEF, the standard math.h, and CUDA all provide ldexp and frexp, though not for every data type. I wasn't able to figure out how to get the iterators to play nicely with a vectorized kernel, so I have left this with just the plain CPU kernel for now.

This is the first operator I'm adding so please review with an eye for errors.

@ranman ranman requested review from colesbury and ezyang September 26, 2020 01:17

dr-ci Bot commented Sep 26, 2020

💊 CI failures summary and remediations

As of commit 83ab935 (more details on the Dr. CI page):


None of the CI failures appear to be your fault 💚



🚧 2 fixed upstream failures:

These were probably caused by upstream breakages that were already fixed.

Please rebase on the viable/strict branch.

Since your merge base is older than viable/strict, run these commands:

git fetch https://github.com/pytorch/pytorch viable/strict
git rebase FETCH_HEAD

Check out the recency history of this "viable master" tracking branch.



@ezyang ezyang requested a review from zou3519 September 28, 2020 15:46
Contributor

ezyang commented Sep 28, 2020

Added @zou3519 for NamedRegistrations.cpp

@ezyang ezyang requested a review from mruberry September 28, 2020 15:55
Comment thread tools/autograd/derivatives.yaml Outdated
Comment thread aten/src/ATen/native/BinaryOps.cpp Outdated
@ranman ranman requested review from zou3519 and removed request for ezyang September 28, 2020 18:59
Comment thread aten/src/ATen/native/BinaryOps.cpp Outdated
Comment thread aten/src/ATen/native/BinaryOps.cpp Outdated
Comment thread aten/src/ATen/native/BinaryOps.cpp Outdated
Comment thread aten/src/ATen/native/native_functions.yaml Outdated
Comment thread docs/source/tensors.rst Outdated
Comment thread aten/src/ATen/native/cpu/BinaryOpsKernel.cpp Outdated
Comment thread test/test_torch.py Outdated
Comment thread test/test_torch.py Outdated
Comment thread torch/_tensor_docs.py Outdated
Comment thread torch/_torch_docs.py Outdated
@mruberry
Collaborator

Hey @ranman, this is cool and it looks pretty good!

I have a suggestion for the docs and the test, and a question about whether this should be implemented as a composite op (using existing torch functions) or a custom kernel. Looking forward to hearing your thoughts.

Author

ranman commented Oct 13, 2020

I pushed a new update here based on implementing ldexp as a composite operator.

Also, I'm doubtful of the need to include ldexp as an operator. We did some research on GitHub and in a few papers but could not find any meaningful recent usage of it.

Do we really need this op?

@mruberry
Collaborator

Also, I'm doubtful of the need to include ldexp as an operator. We did some research on GitHub and in a few papers but could not find any meaningful recent usage of it.

It does appear to be very little used.

Do we really need this op?

This is an excellent question. The request for this particular op was motivated by it being in other frameworks like NumPy (https://numpy.org/doc/1.18/reference/generated/numpy.ldexp.html?highlight=ldexp#numpy.ldexp) and JAX (https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ldexp.html), although this is double counting NumPy since JAX is just copying NumPy. We haven't been especially picky when it comes to which NumPy functions we're implementing as part of our NumPy Compatibility effort (we do avoid those that are deprecated or don't translate well to PyTorch), nor have we carefully reviewed each function for how often it's used. Simpler ops like ldexp also have a useful didactic function.

On net I would prefer we have it since the cost of its inclusion is very low and it lets us tell a simpler NumPy Compat story that isn't caveated by our evaluation of how often a function is used.
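For comparison, the NumPy function referenced above computes x1 * 2**x2 elementwise:

import numpy as np

# np.ldexp(x1, x2) == x1 * 2**x2, applied elementwise
print(np.ldexp(np.array([1.0, 2.0]), np.array([3, 3])))  # [ 8. 16.]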

Comment thread aten/src/ATen/native/BinaryOps.cpp Outdated
Collaborator

This is no longer true because the composite ops will handle the type promotion for you.

Comment thread aten/src/ATen/native/BinaryOps.cpp Outdated
Collaborator

Tensor result;

If you create the tensor like this it will have self's dtype, but now that the op is implemented as a composite it inherits mul's type promotion, so the result dtype may not be self's.
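To make the promotion concrete, a small sketch using the Python API (the values here are hypothetical):

import torch

t = torch.tensor([1.0], dtype=torch.float16)
other = torch.tensor([3])  # int64

# The composite computes t * 2**other. pow(2.0, <integer tensor>) produces a
# float32 tensor, and float16 * float32 promotes to float32, so the result
# dtype is float32 rather than t's float16.
result = t * torch.pow(2.0, other)
print(result.dtype)  # torch.float32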

Author

If I don't initialize result at all, I get an error like this when using at::mul_out:
RuntimeError: Expected a Tensor of type Variable but found an undefined Tensor for argument #0 'out'

Any idea how to properly create result?

Collaborator

@mruberry mruberry Oct 15, 2020

Oh right, sorry, I always forget we have this wonky issue with TensorIterator. This is also why mul and mul_out have divergent implementations:

Tensor& mul_out(Tensor& result, const Tensor& self, const Tensor& other) {

My mistake on that. You can just write this as return self * 2 ** other here, too. Also, I think the checks for complex can be dropped now, since we support complex multiplication (and if we didn't, we could let mul handle the error).

Author

I tried using just self * at::pow(2.0, other) and hit some type-casting errors. Want to take a look at what I ended up with and let me know if we're good to go?

// Out variant: writes self * 2^other into result.
Tensor& ldexp_out(Tensor& result, const Tensor& self, const Tensor& other) {
  return at::mul_out(result, self, at::pow(2.0, other));
}

// Functional variant: allocates and returns a new tensor.
Tensor ldexp(const Tensor& self, const Tensor& other) {
  return at::mul(self, at::pow(2.0, other));
}

// In-place variant: reuses self as the output.
Tensor& ldexp_(Tensor& self, const Tensor& other) {
  return at::ldexp_out(self, self, other);
}
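Assuming the composite above, the user-facing behavior is simply self * 2**other:

import torch

x = torch.tensor([1.0, 2.0, 3.0])
print(torch.ldexp(x, torch.tensor([2, 2, 2])))  # tensor([ 4.,  8., 12.])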

Collaborator

This looks correct to me. What were the type-casting errors?

Author

@ranman ranman Oct 15, 2020

Oh, it was the implementation with a literal * instead of mul. The implementation I wrote above is the one I ended up going with.

Collaborator

Oh, right. Got it. Thanks for explaining that to me again ;)

Comment thread aten/src/ATen/native/BinaryOps.cpp Outdated
Comment thread test/test_torch.py Outdated
Comment thread torch/_torch_docs.py Outdated
Comment thread torch/_torch_docs.py Outdated
Comment thread torch/_torch_docs.py Outdated
Collaborator

@mruberry mruberry left a comment

Hey @ranman!

This looks very good. I made a few additional comments, and then I think this will be good to go! I'm looking forward to having ldexp. I forgot Python has it, too. Even TorchScript has it!

@ranman ranman force-pushed the feat/ldexp branch 2 times, most recently from fb0308f to 94593ae on October 15, 2020 02:25
Comment thread torch/_tensor_docs.py Outdated
Collaborator

Failures are because you need an ldexp_ entry here.
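For illustration, the kind of entry meant here, following the add_docstr_all pattern used elsewhere in torch/_tensor_docs.py (the exact wording is a sketch, not necessarily the merged text):

add_docstr_all('ldexp_',
               r"""
ldexp_(other) -> Tensor

In-place version of :meth:`~Tensor.ldexp`
""")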

Author

OK, I'll fix that! I promise my next operator will be less overhead, lol.

Collaborator

This is great stuff, don't be discouraged. We've known for a while that adding an operator to PyTorch is complicated, with a lot of bookkeeping. This is one of the simplest ops that can be added, too. Other operations, like the split operations, require modifying several more files and tests.

We should seriously consider investing in better developer documentation for how to add an operator.

@ranman ranman force-pushed the feat/ldexp branch 2 times, most recently from ad61311 to 6b286ff on October 15, 2020 14:41
Contributor

@facebook-github-bot facebook-github-bot left a comment

@ranman has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ranman ranman force-pushed the feat/ldexp branch 2 times, most recently from 8464754 to 1473ce9 on November 5, 2020 23:56
Contributor

@facebook-github-bot facebook-github-bot left a comment

@ranman has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.


@mruberry mruberry self-requested a review November 20, 2020 09:53
Collaborator

@mruberry mruberry left a comment

Cool! Thanks @ranman!

@facebook-github-bot
Contributor

@ranman merged this pull request in 562d4c3.

@ilia-cher
Contributor

breaks pytorch_windows_vs2019_py36_cuda10.1_test2 build?
https://ezyang.github.io/pytorch-ci-hud/build/pytorch-master

Contributor

malfet commented Nov 20, 2020

I have one PR that disables the test on Windows: #48334, and another that attempts to fix it: #48335.

laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
Summary:
Adds ldexp operator for pytorch#38349


Pull Request resolved: pytorch#45370

Reviewed By: mruberry

Differential Revision: D24333516

Pulled By: ranman

fbshipit-source-id: 2df78088f00aa9789aae1124eda399771e120d3f
