Skip to content

ENH: Implement take_along_axis as described in #8708#8714

Closed
eric-wieser wants to merge 2 commits intonumpy:masterfrom
eric-wieser:take_along_axis
Closed

ENH: Implement take_along_axis as described in #8708#8714
eric-wieser wants to merge 2 commits intonumpy:masterfrom
eric-wieser:take_along_axis

Conversation

@eric-wieser
Copy link
Member

@eric-wieser eric-wieser commented Feb 28, 2017

Edit: Superceded by the simpler #11105

See #8708 and earlier issues linked there for discussion of the need for this function.

Let's keep discussion here to the implementation.

In future, it would be nice to implement this with npyiter in C code for speed, but this is a good starting point, and likely just as fast as what is currently being used in the wild.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First sentence should fit on a single line. Maybe

Take elements from slices indexed along the given axis.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem is I need to disambiguate this from take, which is

Take elements from an array along an axis.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should go to the mailing list in any case, maybe take is not even the most natural word in the end, which would somewhat remove the problem ;). The take description kind of only works for 1-D, and in 1-D the two do the same thing, so its a bit of a twist :).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't have to be my example, but I do think it should be a single line.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're not wrong - I'm just asking for help in coming up with an unambiguous description under that constraint :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only other idea I have right now is to call it a vectorized take/pick (which bites with the vindex idea, but maybe that name is not great anyway, I think someone had suggest broadcasted index there too, which may be more logical anyway -- though that does not necessarily mean easier I guess, hehe).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unify use of notation for indices (e.g., here lowercase a, b; below upper case; ideally use the standard "integers", i.e., i--n). If possible, do use axis directly.

Copy link
Member Author

@eric-wieser eric-wieser Feb 28, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here I'm using the convention that a = range(0, A), ie the capital letters are the shape, and the lowercase ones indices for that dimension

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no way I can use axis directly here, short of some ascii art pointing to the middle index on the right

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could use out[i..., j..., k...] = arr[i..., indices[i..., j..., k...], k...], and then Ni, Nk, Nk further down?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the i,j,k, Ni,Nj,Nk, or perhaps i1, i2, i3, N1,N2,N3.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added a fixup commit to apply this. I'll squash once everything else is approved

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add an example where one keeps the dimension?

Copy link
Member Author

@eric-wieser eric-wieser Feb 28, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, I think keeping dimensions is out of scope here, and belongs in #8710.

"keeping the dimension" is something that can be done either before or after take_along_axis.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In principle we could think about making keepdims (well, kind of the inverse) a kwarg here too.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So you are lazy about a C version, too bad ;P

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@seberg: What would it do though? Dimensions are already kept, in that out.ndim == indices.ndim.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But there is no problem with argmax needing an expand dims, since we should just add a keepdims?

Copy link
Member Author

@eric-wieser eric-wieser Mar 3, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right - it's inconvenient right now to have to call np.expand_dims, but that's really a bug of not having keepdims in argmethods. And as for calling squeeze on the return value (vs what I proposed before) - you're likely in for a bad time if you start squeezing axes in the middle of your multidimensional data anyway.

Having assert(ind.ndim == x.ndim) still gives us plenty of freedom to come up with better semantics in future

Copy link
Member Author

@eric-wieser eric-wieser Mar 3, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another example of how this restriction is sufficient (with keepdims) - indexing the brightest pixel:

assert img.shape == (400, 300, 3)
brightness = np.sum(img, axis=2, keepdims=True)  # again - always keep your dims!
argbrightrow = np.argmax(brightness, axis=1, keepdims=True)
brightest_by_row = np.take_along_axis(img, argbrightrow, axis=1)

Copy link
Member Author

@eric-wieser eric-wieser Mar 3, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having assert(ind.ndim == x.ndim) still gives us plenty of freedom to come up with better semantics in future

We could remove this and add the default broadcasting behaviour, but this risks confusing users who try to do np.take_along_axis(a, a.argmax(axis=1),axis=1),(no keepdim) which would allow this to sometimes silently do the wrong thing

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So for now, I guess lets just say we don't put it in, and put it up for discussion on the mailing list. If anyone can present a good reason/usecase especially for the way you first described it, I am good with allowing it. At least the 0-D case (dim is missing), it seems rather intuitive after all....

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would not use ; in example code; just make another line...

Copy link
Member Author

@eric-wieser eric-wieser Feb 28, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is pretty par for the course for numpy docstrings - grep for >>> (\w+)\s*=.*;\s*\1

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bad practice anyway.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be a TypeError, I think (at least, [1, 2][1.] raises TypeError).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just copying np.array([1, 2])[np.array(1.)] here, which gives IndexError. This is just that error message, but without the bit about booleans.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bools are not considered ints here right?

Copy link
Member Author

@eric-wieser eric-wieser Feb 28, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct, np.issubdtype(np.bool_, np.integer) is false. There's a test for this error in this PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, should have tried with arrays; not very logical but best to stick with numpy practice here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to be sure, also add the keepdims versions.

Copy link
Member Author

@eric-wieser eric-wieser Feb 28, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd argue for postponing that to #8710 as well, and then no special handling would be needed.
I have added a test to verify that expand_dims works before and after though, which is sort of the same thing

@eric-wieser
Copy link
Member Author

Does this last commit belong in this PR?

@mhvk
Copy link
Contributor

mhvk commented Feb 28, 2017

I think this is very nice. Generally, I think it is better to keep one PR to one logical commit, so in that sense the np.ma.median one should not be here, but it is nice to see an immediate use here. (In other words, either way is fine to me.)

Agreed though that this should be passed by the mailing list.

@eric-wieser eric-wieser force-pushed the take_along_axis branch 3 times, most recently from 1f93135 to 1296473 Compare March 9, 2017 19:55
@eric-wieser
Copy link
Member Author

eric-wieser commented Mar 9, 2017

put_along_axis is now in too (with the original semantics, not the proposed ones in above comments). One thing I noticed while doing that is that np.put is not a very good dual to np.take, as it has no axis argument (#8765), and has a peculiar "repeat as necessary" behaviour.

I'll send something out to the mailing list once my repeat confirmation email arrives and I actually remember to click on it during the 3-day window.

@homu
Copy link
Contributor

homu commented Mar 23, 2017

☔ The latest upstream changes (presumably #8795) made this pull request unmergeable. Please resolve the merge conflicts.

@homu
Copy link
Contributor

homu commented Mar 27, 2017

☔ The latest upstream changes (presumably #8847) made this pull request unmergeable. Please resolve the merge conflicts.

@homu
Copy link
Contributor

homu commented Apr 21, 2017

☔ The latest upstream changes (presumably #8886) made this pull request unmergeable. Please resolve the merge conflicts.

@charris
Copy link
Member

charris commented Apr 26, 2017

Could the reviewers involved here either sign off on it, merge it, or make a complaint.

@eric-wieser
Copy link
Member Author

@charris: The ball is in my court, I think - there was indecision about how to deal with broadcasting, with the suggestion of me consulting the mailing list - I have not done so.

I think it would be easier to design / put forth a case for this when argmin and argmax acquire keepdims arguments - so perhaps we should punt it to 1.14

@charris
Copy link
Member

charris commented Apr 26, 2017

OK, I'll punt. Thanks for the update.

@charris charris modified the milestones: 1.14.0 release, 1.13.0 release Apr 26, 2017
@homu
Copy link
Contributor

homu commented May 10, 2017

☔ The latest upstream changes (presumably #9050) made this pull request unmergeable. Please resolve the merge conflicts.

@charris
Copy link
Member

charris commented Oct 17, 2017

Still in abeyance. Should I punt this on to 1.15?

@seberg
Copy link
Member

seberg commented Oct 18, 2017

There is no much need for a specific milestone is there? I don't remember this, but for broadcasting, possibly we could do a minimal thing first that can be generalized later if it is too tricky to decide?

@charris
Copy link
Member

charris commented Oct 22, 2017

Removed the milestone.

@eric-wieser
Copy link
Member Author

I've rebased this just to avoid bitrot down the line, and moved the DOC commit to a separate pr (#9946). If nothing else, getting the doc change to existing functions into 1.14 makes it easier to pitch the new feature for 1.15.

Copy link
Member Author

@eric-wieser eric-wieser Nov 21, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on the discussion #9946, I suppose this would be better described as

Ni, Nk = a.shape[:axis], a.shape[axis+1:]
for ii in ndindex(Ni):
    for kk in ndindex(Nj):
        out[ii + s_[...,] + kk] = a[ii + s_[:,] + kk][indices[ii + s_[...,] + kk]]

Edit: updated

eric-wieser added a commit to eric-wieser/numpy that referenced this pull request Nov 22, 2017
@rzu512
Copy link

rzu512 commented Dec 5, 2017

Hi. I am quite new to git. How to see your implementation for take_along axis, after I git clone this repository?

I wanted to do similar thing for a tree reduction, and came up with an implementation for take_along_axis.

The shape of the result is same as shape of the index.

take_along_axis
def take_along_axis(arr, ind, axis):
    """Take elements from an array according to an index along axis.
    Parameters
    ----------
        arr : np.ndarray
        ind : np.ndarray
            Indexing of an array along an axis. This cannot be an int.
            For a 2D array arr and axis=1, this means: from each column of
            arr, select elements along the column by the indices at the
            corresponding column of ind. In other words, the ith column of
            the result is arr[ind[:, i], i].
        axis : int
    Returns
    -------
        result : np.ndarray
    Example
    -------
        >>> import numpy as np
        >>> arr = np.flip(np.arange(8).reshape(2, 4), axis=1)  # test array
        >>> arr
        array([[3, 2, 1, 0],
               [7, 6, 5, 4]])
        >>> ind = np.argsort(arr, axis=1)  # indexing along axis 1
        >>> result = take_along_axis(arr, ind, axis=1)
        >>> result
        array([[0, 1, 2, 3],
               [4, 5, 6, 7]])
        >>> answer = np.sort(arr, axis=1)
        >>> np.all(result == answer)
        True
    """
    # Does not check if axis or the ind are legal
    shape = arr.shape
    before = reduce(mul, shape[:axis], 1)
    at = shape[axis]
    after = arr.size // at // before
    a = arr.reshape(before, at, after)
    idx = [
        np.arange(before).reshape(before, 1, 1),
        ind.reshape(before, -1, after),
        np.arange(after).reshape(1, 1, after),
    ]
    return a[idx].reshape(ind.shape)

@eric-wieser
Copy link
Member Author

You'll need to add my fork of the repository (https://github.com/eric-wieser/numpy.git) as a remote, then you should just be able to checkout the take_along_axis.

Is your goal to compare my implementation with yours?

@rzu512
Copy link

rzu512 commented Dec 5, 2017

Thanks. I found it.

I mostly want to learn by looking at another solution.

Copy link
Contributor

@crusaderky crusaderky left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Waiting eagerly for this to get in! Any reason why there's no activity?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add support for arr.ndim > indices.ndim?
It's a single line here (+ unit test + docs):

indices = indices.reshape((1, ) * (arr.ndim - indices.ndim) + indices.shape)

Use case / example:

x = np.array([[5, 3, 2, 8, 1],
              [0, 7, 1, 3, 2]])
# Completely arbitrary y = f(x0, x1, ..., xn), embarassingly parallel along axis=-1
# Here we only have x0, but we could have more.
y = x.sum(axis=0)
# Sort the x's, moving the ones that cause the smallest y's to the left
take_along_axis(x, np.argsort(y))

Copy link
Member Author

@eric-wieser eric-wieser May 16, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think your use-case is well-motivated. A more explicit way to achieve that would be:

y = x.sum(axis=0, keepdims=True)
take_along_axis(x, np.argsort(y, axis=1), axis=1)

Copy link
Contributor

@crusaderky crusaderky May 16, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The whole point is that with the one-liner addition f(x) can add or remove axes at will (some manual broadcasting required if it replaces or transposes axes, which however happens automatically if you e.g. wrap this in xarray.apply_ufunc).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that in the comments above, we decide that perhaps it's best to not allow any case other than indices.ndim == arr.ndim, since there's no obvious right choice.

take_along_axis only really makes sense if you endeavor to keep all your axes aligned. xarray can probably solve that by axis names alone, but in numpy you need to indicate that by axis position. Therefore, you can't afford to let your axes collapse, and have numpy guess which one you lost: in your case, you're advocating for it to guess the left-most one should be reinserted - but this is only the case because you did sum(axis=0).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

default to axis=-1 like 99% of other numpy functions?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

default to axis=-1 like 99% of other numpy functions?

Copy link
Member Author

@eric-wieser eric-wieser May 16, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wish that were true. concatenate defaults to axis=0, and other functions default to axis=None. Since the axis is a key part of the function, it seems best just to require it.

This comment was marked as resolved.

This comment was marked as resolved.

This comment was marked as resolved.

@eric-wieser eric-wieser force-pushed the take_along_axis branch 2 times, most recently from 679a857 to 162ed85 Compare May 16, 2018 07:35
This is the reduced version that does not allow any insertion of extra dimensions
@eric-wieser
Copy link
Member Author

I've split up the commits into one supporting restricted and obvious broadcasting, and one that adds the less obvious behavior that matches apply_along_axis.

The first commit stands alone at #11105, and I think we should focus on getting minimal functionality in before trying to come up with non-obvious extensions.

@mhvk
Copy link
Contributor

mhvk commented May 29, 2018

superceded by #11105, which was merged

@mhvk mhvk closed this May 29, 2018
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants