ENH: Implement take_along_axis as described in #8708#8714
ENH: Implement take_along_axis as described in #8708#8714eric-wieser wants to merge 2 commits intonumpy:masterfrom
Conversation
ece6ee1 to
626ec17
Compare
numpy/lib/shape_base.py
Outdated
There was a problem hiding this comment.
First sentence should fit on a single line. Maybe
Take elements from slices indexed along the given axis.
There was a problem hiding this comment.
The problem is I need to disambiguate this from take, which is
Take elements from an array along an axis.
There was a problem hiding this comment.
Should go to the mailing list in any case, maybe take is not even the most natural word in the end, which would somewhat remove the problem ;). The take description kind of only works for 1-D, and in 1-D the two do the same thing, so its a bit of a twist :).
There was a problem hiding this comment.
It doesn't have to be my example, but I do think it should be a single line.
There was a problem hiding this comment.
You're not wrong - I'm just asking for help in coming up with an unambiguous description under that constraint :)
There was a problem hiding this comment.
The only other idea I have right now is to call it a vectorized take/pick (which bites with the vindex idea, but maybe that name is not great anyway, I think someone had suggest broadcasted index there too, which may be more logical anyway -- though that does not necessarily mean easier I guess, hehe).
numpy/lib/shape_base.py
Outdated
There was a problem hiding this comment.
Unify use of notation for indices (e.g., here lowercase a, b; below upper case; ideally use the standard "integers", i.e., i--n). If possible, do use axis directly.
There was a problem hiding this comment.
Here I'm using the convention that a = range(0, A), ie the capital letters are the shape, and the lowercase ones indices for that dimension
There was a problem hiding this comment.
There's no way I can use axis directly here, short of some ascii art pointing to the middle index on the right
There was a problem hiding this comment.
I could use out[i..., j..., k...] = arr[i..., indices[i..., j..., k...], k...], and then Ni, Nk, Nk further down?
There was a problem hiding this comment.
I like the i,j,k, Ni,Nj,Nk, or perhaps i1, i2, i3, N1,N2,N3.
There was a problem hiding this comment.
I've added a fixup commit to apply this. I'll squash once everything else is approved
numpy/lib/shape_base.py
Outdated
There was a problem hiding this comment.
Could you add an example where one keeps the dimension?
There was a problem hiding this comment.
Again, I think keeping dimensions is out of scope here, and belongs in #8710.
"keeping the dimension" is something that can be done either before or after take_along_axis.
There was a problem hiding this comment.
In principle we could think about making keepdims (well, kind of the inverse) a kwarg here too.
There was a problem hiding this comment.
So you are lazy about a C version, too bad ;P
There was a problem hiding this comment.
@seberg: What would it do though? Dimensions are already kept, in that out.ndim == indices.ndim.
There was a problem hiding this comment.
But there is no problem with argmax needing an expand dims, since we should just add a keepdims?
There was a problem hiding this comment.
Right - it's inconvenient right now to have to call np.expand_dims, but that's really a bug of not having keepdims in argmethods. And as for calling squeeze on the return value (vs what I proposed before) - you're likely in for a bad time if you start squeezing axes in the middle of your multidimensional data anyway.
Having assert(ind.ndim == x.ndim) still gives us plenty of freedom to come up with better semantics in future
There was a problem hiding this comment.
Another example of how this restriction is sufficient (with keepdims) - indexing the brightest pixel:
assert img.shape == (400, 300, 3)
brightness = np.sum(img, axis=2, keepdims=True) # again - always keep your dims!
argbrightrow = np.argmax(brightness, axis=1, keepdims=True)
brightest_by_row = np.take_along_axis(img, argbrightrow, axis=1)
There was a problem hiding this comment.
Having assert(ind.ndim == x.ndim) still gives us plenty of freedom to come up with better semantics in future
We could remove this and add the default broadcasting behaviour, but this risks confusing users who try to do np.take_along_axis(a, a.argmax(axis=1),axis=1),(no keepdim) which would allow this to sometimes silently do the wrong thing
There was a problem hiding this comment.
So for now, I guess lets just say we don't put it in, and put it up for discussion on the mailing list. If anyone can present a good reason/usecase especially for the way you first described it, I am good with allowing it. At least the 0-D case (dim is missing), it seems rather intuitive after all....
numpy/lib/shape_base.py
Outdated
There was a problem hiding this comment.
I would not use ; in example code; just make another line...
There was a problem hiding this comment.
This is pretty par for the course for numpy docstrings - grep for >>> (\w+)\s*=.*;\s*\1
numpy/lib/shape_base.py
Outdated
There was a problem hiding this comment.
This should be a TypeError, I think (at least, [1, 2][1.] raises TypeError).
There was a problem hiding this comment.
Just copying np.array([1, 2])[np.array(1.)] here, which gives IndexError. This is just that error message, but without the bit about booleans.
There was a problem hiding this comment.
Bools are not considered ints here right?
There was a problem hiding this comment.
Correct, np.issubdtype(np.bool_, np.integer) is false. There's a test for this error in this PR.
There was a problem hiding this comment.
OK, should have tried with arrays; not very logical but best to stick with numpy practice here.
numpy/lib/tests/test_shape_base.py
Outdated
There was a problem hiding this comment.
Just to be sure, also add the keepdims versions.
There was a problem hiding this comment.
I'd argue for postponing that to #8710 as well, and then no special handling would be needed.
I have added a test to verify that expand_dims works before and after though, which is sort of the same thing
|
Does this last commit belong in this PR? |
|
I think this is very nice. Generally, I think it is better to keep one PR to one logical commit, so in that sense the Agreed though that this should be passed by the mailing list. |
1f93135 to
1296473
Compare
|
I'll send something out to the mailing list once my repeat confirmation email arrives and I actually remember to click on it during the 3-day window. |
|
☔ The latest upstream changes (presumably #8795) made this pull request unmergeable. Please resolve the merge conflicts. |
1296473 to
9242bce
Compare
|
☔ The latest upstream changes (presumably #8847) made this pull request unmergeable. Please resolve the merge conflicts. |
9242bce to
a2e68cc
Compare
|
☔ The latest upstream changes (presumably #8886) made this pull request unmergeable. Please resolve the merge conflicts. |
|
Could the reviewers involved here either sign off on it, merge it, or make a complaint. |
|
@charris: The ball is in my court, I think - there was indecision about how to deal with broadcasting, with the suggestion of me consulting the mailing list - I have not done so. I think it would be easier to design / put forth a case for this when |
|
OK, I'll punt. Thanks for the update. |
|
☔ The latest upstream changes (presumably #9050) made this pull request unmergeable. Please resolve the merge conflicts. |
|
Still in abeyance. Should I punt this on to 1.15? |
|
There is no much need for a specific milestone is there? I don't remember this, but for broadcasting, possibly we could do a minimal thing first that can be generalized later if it is too tricky to decide? |
|
Removed the milestone. |
3f44cd1 to
8b7c244
Compare
|
I've rebased this just to avoid bitrot down the line, and moved the DOC commit to a separate pr (#9946). If nothing else, getting the doc change to existing functions into 1.14 makes it easier to pitch the new feature for 1.15. |
numpy/lib/shape_base.py
Outdated
There was a problem hiding this comment.
Based on the discussion #9946, I suppose this would be better described as
Ni, Nk = a.shape[:axis], a.shape[axis+1:]
for ii in ndindex(Ni):
for kk in ndindex(Nj):
out[ii + s_[...,] + kk] = a[ii + s_[:,] + kk][indices[ii + s_[...,] + kk]]Edit: updated
Extracted from numpygh-8714 [ci-skip]
|
Hi. I am quite new to I wanted to do similar thing for a tree reduction, and came up with an implementation for The shape of the result is same as shape of the index. take_along_axisdef take_along_axis(arr, ind, axis):
"""Take elements from an array according to an index along axis.
Parameters
----------
arr : np.ndarray
ind : np.ndarray
Indexing of an array along an axis. This cannot be an int.
For a 2D array arr and axis=1, this means: from each column of
arr, select elements along the column by the indices at the
corresponding column of ind. In other words, the ith column of
the result is arr[ind[:, i], i].
axis : int
Returns
-------
result : np.ndarray
Example
-------
>>> import numpy as np
>>> arr = np.flip(np.arange(8).reshape(2, 4), axis=1) # test array
>>> arr
array([[3, 2, 1, 0],
[7, 6, 5, 4]])
>>> ind = np.argsort(arr, axis=1) # indexing along axis 1
>>> result = take_along_axis(arr, ind, axis=1)
>>> result
array([[0, 1, 2, 3],
[4, 5, 6, 7]])
>>> answer = np.sort(arr, axis=1)
>>> np.all(result == answer)
True
"""
# Does not check if axis or the ind are legal
shape = arr.shape
before = reduce(mul, shape[:axis], 1)
at = shape[axis]
after = arr.size // at // before
a = arr.reshape(before, at, after)
idx = [
np.arange(before).reshape(before, 1, 1),
ind.reshape(before, -1, after),
np.arange(after).reshape(1, 1, after),
]
return a[idx].reshape(ind.shape) |
|
You'll need to add my fork of the repository (https://github.com/eric-wieser/numpy.git) as a remote, then you should just be able to checkout the Is your goal to compare my implementation with yours? |
|
Thanks. I found it. I mostly want to learn by looking at another solution. |
crusaderky
left a comment
There was a problem hiding this comment.
Waiting eagerly for this to get in! Any reason why there's no activity?
numpy/lib/shape_base.py
Outdated
There was a problem hiding this comment.
Could you add support for arr.ndim > indices.ndim?
It's a single line here (+ unit test + docs):
indices = indices.reshape((1, ) * (arr.ndim - indices.ndim) + indices.shape)
Use case / example:
x = np.array([[5, 3, 2, 8, 1],
[0, 7, 1, 3, 2]])
# Completely arbitrary y = f(x0, x1, ..., xn), embarassingly parallel along axis=-1
# Here we only have x0, but we could have more.
y = x.sum(axis=0)
# Sort the x's, moving the ones that cause the smallest y's to the left
take_along_axis(x, np.argsort(y))
There was a problem hiding this comment.
I don't think your use-case is well-motivated. A more explicit way to achieve that would be:
y = x.sum(axis=0, keepdims=True)
take_along_axis(x, np.argsort(y, axis=1), axis=1)There was a problem hiding this comment.
The whole point is that with the one-liner addition f(x) can add or remove axes at will (some manual broadcasting required if it replaces or transposes axes, which however happens automatically if you e.g. wrap this in xarray.apply_ufunc).
There was a problem hiding this comment.
Note that in the comments above, we decide that perhaps it's best to not allow any case other than indices.ndim == arr.ndim, since there's no obvious right choice.
take_along_axis only really makes sense if you endeavor to keep all your axes aligned. xarray can probably solve that by axis names alone, but in numpy you need to indicate that by axis position. Therefore, you can't afford to let your axes collapse, and have numpy guess which one you lost: in your case, you're advocating for it to guess the left-most one should be reinserted - but this is only the case because you did sum(axis=0).
numpy/lib/shape_base.py
Outdated
There was a problem hiding this comment.
default to axis=-1 like 99% of other numpy functions?
numpy/lib/shape_base.py
Outdated
There was a problem hiding this comment.
default to axis=-1 like 99% of other numpy functions?
There was a problem hiding this comment.
I wish that were true. concatenate defaults to axis=0, and other functions default to axis=None. Since the axis is a key part of the function, it seems best just to require it.
doc/release/1.14.0-notes.rst
Outdated
This comment was marked as resolved.
This comment was marked as resolved.
Sorry, something went wrong.
numpy/lib/shape_base.py
Outdated
This comment was marked as resolved.
This comment was marked as resolved.
Sorry, something went wrong.
numpy/lib/shape_base.py
Outdated
This comment was marked as resolved.
This comment was marked as resolved.
Sorry, something went wrong.
679a857 to
162ed85
Compare
This is the reduced version that does not allow any insertion of extra dimensions
…h apply_along_axis
a5cc638 to
4eec0ce
Compare
|
I've split up the commits into one supporting restricted and obvious broadcasting, and one that adds the less obvious behavior that matches The first commit stands alone at #11105, and I think we should focus on getting minimal functionality in before trying to come up with non-obvious extensions. |
|
superceded by #11105, which was merged |
Edit: Superceded by the simpler #11105
See #8708 and earlier issues linked there for discussion of the need for this function.
Let's keep discussion here to the implementation.
In future, it would be nice to implement this with npyiter in C code for speed, but this is a good starting point, and likely just as fast as what is currently being used in the wild.