Add basic ldexp operator for numpy compatibility #45370
ranman wants to merge 1 commit into pytorch:master
Conversation
💊 CI failures summary and remediations. As of commit 83ab935 (more details on the Dr. CI page): ✅ None of the CI failures appear to be your fault 💚
🚧 2 fixed upstream failures: these were probably caused by upstream breakages that were already fixed.
Please rebase on the
Added @zou3519 for NamedRegistrations.cpp
Hey @ranman, this is cool and it looks pretty good! I have a suggestion for the docs and the test, and a question about whether this should be implemented as a composite op (using existing torch functions) or a custom kernel. Looking forward to hearing your thoughts.
I pushed a new update here based on implementing ldexp as a composite operator. Also, I'm doubtful of the need to include ldexp as an operator. We did some research on GitHub and in a few papers but could not find any recent usage of it. Do we really need this op?
It does appear to be very little used.
This is an excellent question. The request for this particular op was motivated by it being in other frameworks like NumPy (https://numpy.org/doc/1.18/reference/generated/numpy.ldexp.html?highlight=ldexp#numpy.ldexp) and JAX (https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ldexp.html), although this is double counting NumPy since JAX is just copying NumPy. We haven't been especially picky when it comes to which NumPy functions we're implementing as part of our NumPy Compatibility effort (we do avoid those that are deprecated or don't translate well to PyTorch), nor have we carefully reviewed each function for how often it's used. Simpler ops like ldexp also have a useful didactic function. On net I would prefer we have it since the cost of its inclusion is very low and it lets us tell a simpler NumPy Compat story that isn't caveated by our evaluation of how often a function is used.
This is no longer true because the composite ops will handle the type promotion for you.
Tensor result;
If you create the tensor like this it will have self's dtype, but now that the op is implemented as a composite it inherits mul's type promotion, so the result dtype may not be self's.
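For illustration, here is a minimal standalone sketch of that promotion behavior, assuming a libtorch C++ build (the tensor shapes, values, and dtypes below are made up for the example, not taken from the PR):

#include <torch/torch.h>
#include <iostream>

int main() {
  // self is half precision; other holds the integer exponents.
  auto self  = torch::ones({3}, torch::kHalf);
  auto other = torch::tensor({1, 2, 3}, torch::kInt);

  // pow(2.0, other) produces a float32 tensor, so the multiply promotes:
  // the result dtype is Float, not self's Half.
  auto result = self * torch::pow(2.0, other);
  std::cout << result.scalar_type() << std::endl;  // Float
  return 0;
}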
If I don't initialize result at all I get an error like this:
RuntimeError: Expected a Tensor of type Variable but found an undefined Tensor for argument #0 'out' if using at::mul_out
Any idea how to properly create result?
Oh right, sorry, I always forget we have this wonky issue with TensorIterator. This is also why mul and mul_out have divergent implementations:
pytorch/aten/src/ATen/native/BinaryOps.cpp, line 243 at ff0af72
My mistake on that. You can just write this as return self * 2 ** other here, too. Also I think the checks for complex can be dropped now, too, since we support complex multiplication (and if we didn't we could let mul handle the error).
I tried using just self * at::pow(2.0, other) and hit some type casting errors. Want to take a look at what I ended up with and let me know if we're good to go?
Tensor& ldexp_out(Tensor& result, const Tensor& self, const Tensor& other) {
  // out variant: write self * 2**other into result via mul_out
  return at::mul_out(result, self, at::pow(2.0, other));
}
Tensor ldexp(const Tensor& self, const Tensor& other) {
  // functional variant
  return at::mul(self, at::pow(2.0, other));
}
Tensor& ldexp_(Tensor& self, const Tensor& other) {
  // in-place variant: reuse the out variant with self as the output
  return at::ldexp_out(self, self, other);
}
This looks correct to me -- what were the type casting errors?
Oh, it was the implementation with the literal * operator instead of mul. The implementation I wrote above is the one I ended up going with.
Oh, right. Got it. Thanks for explaining that to me again ;)
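For reference, a standalone sketch of the arithmetic the composite is meant to reproduce, using only the C standard library (the values here are chosen arbitrarily for illustration):

#include <cassert>
#include <cmath>

int main() {
  double x = 0.75;
  int n = 4;
  // ldexp(x, n) computes x * 2^n; the composite op uses the multiply form.
  double via_ldexp = std::ldexp(x, n);        // 0.75 * 2^4 = 12.0
  double via_mul   = x * std::pow(2.0, n);    // same value via explicit pow
  assert(via_ldexp == 12.0 && via_mul == 12.0);
  return 0;
}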
Force-pushed from fb0308f to 94593ae
Failures are because you need an ldexp_ entry here.
Ok I'll fix that! I promise my next operator will be less overhead lol.
This is great stuff, don't be discouraged. We've known for a while that adding an operator to PyTorch is complicated, with a lot of bookkeeping. This is one of the simplest ops that can be added, too. Other operations, like the split operations, require modifying several more files and tests.
We should seriously consider investing in better developer documentation for how to add an operator.
Force-pushed from ad61311 to 6b286ff
facebook-github-bot left a comment:
@ranman has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Force-pushed from 8464754 to 1473ce9
breaks pytorch_windows_vs2019_py36_cuda10.1_test2 build?
Summary: Adds ldexp operator for pytorch#38349

I'm not entirely sure the changes to `NamedRegistrations.cpp` were needed, but I saw other operators in there so I added it. Normally the ldexp operator is used along with frexp to construct and deconstruct floating point values. This is useful for performing operations on either the mantissa or the exponent portion of a floating point value.

Sleef, the C standard library's math.h, and CUDA support both ldexp and frexp, but not for all data types. I wasn't able to figure out how to get the iterators to play nicely with a vectorized kernel, so I have left this with just the normal CPU kernel for now.

This is the first operator I'm adding, so please review with an eye for errors.

Pull Request resolved: pytorch#45370
Reviewed By: mruberry
Differential Revision: D24333516
Pulled By: ranman
fbshipit-source-id: 2df78088f00aa9789aae1124eda399771e120d3f
Adds ldexp operator for #38349
I'm not entirely sure the changes to NamedRegistrations.cpp were needed, but I saw other operators in there so I added it.

Normally the ldexp operator is used along with frexp to construct and deconstruct floating point values. This is useful for performing operations on either the mantissa or the exponent portion of a floating point value.
Sleef, the C standard library's math.h, and CUDA support both ldexp and frexp, but not for all data types. I wasn't able to figure out how to get the iterators to play nicely with a vectorized kernel, so I have left this with just the normal CPU kernel for now.
This is the first operator I'm adding so please review with an eye for errors.
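As a small illustration of the frexp/ldexp pairing described above, here is a sketch using only the C standard library (this is separate from the PR's implementation; the input value is arbitrary):

#include <cassert>
#include <cmath>

int main() {
  double value = 6.5;

  // frexp deconstructs: value == mantissa * 2^exponent, with mantissa in [0.5, 1).
  int exponent = 0;
  double mantissa = std::frexp(value, &exponent);  // mantissa = 0.8125, exponent = 3

  // ldexp reconstructs the original value from the two parts.
  double rebuilt = std::ldexp(mantissa, exponent);
  assert(rebuilt == value);
  return 0;
}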