ENH: allow start-stop array for indices in reduceat #25476
mhvk wants to merge 1 commit into numpy:main
|
@mhvk this seems to have slipped between the cracks, and now has a small merge conflict. Would you like to continue with it? |
Also introduces an initial argument, though it can only be used with the start-stop format.
|
Yes, I still think this would be useful. Unfortunately, the mailing list thread didn't generate much discussion... Anyway, I rebased and solved the merge conflict. |
|
I suppose I never actually reviewed it because the mailing list discussion was a bit slow/inconclusive. But overall, I think we should do something, I just never made up my mind what precisely :/. I was talking with @ikrommyd recently about this from an Awkward Array perspective. I am not actually sure it matches exactly, but yeah, we should maybe just do this. E.g. start-stop indices with non-inclusive logic would also be very interesting (i.e. split-points), or using the word "segmented" for a new method (but it seems that JAX |
|
Thanks for the ping @seberg. Since @mhvk appears to still be working on this it doesn't look like I need to pick this up. Happy to discuss the API of course. |
|
Thanks for that reference, that definition of I actually am inclined to a |
|
FWIW I also think a new |
|
Thanks, both! I think at the moment the main question is actually one of API. Personally, I remain somewhat in favour of just using or expanding Are you aware of other similar routines that might inform API? The cuda |
|
I am aware of this https://docs.pytorch.org/docs/stable/generated/torch.segment_reduce.html and https://www.tensorflow.org/api_docs/python/tf/math/segment_sum.
In [57]: offsets = np.array([0, 3, 3, 5, 6])
In [58]: content = np.array([1, 2, 3, 4, 5, 6])
In [59]: parents = np.array([0, 0, 0, 2, 2, 3])
In [60]: jax.ops.segment_sum(content, parents)
Out[60]: Array([6, 0, 9, 6], dtype=int32)
In [61]: num_segments = parents.max() + 1
In [62]: result = np.zeros(num_segments, dtype=content.dtype)
In [63]: np.add.at(result, parents, content)
In [64]: result
Out[64]: array([6, 0, 9, 6])
I think there's certain optimizations you can do if your segment ids are sorted (or you use offsets) given by jax's arg |
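For comparison, here is a small sketch of what happens when the same offsets are fed to the existing np.add.reduceat. It reuses the content and offsets arrays from the session above; the masking step is only one possible workaround, not something this PR prescribes.

```python
import numpy as np

content = np.array([1, 2, 3, 4, 5, 6])
offsets = np.array([0, 3, 3, 5, 6])  # segment i is content[offsets[i]:offsets[i+1]]

# Naive use of the current reduceat: pass only the segment starts.
naive = np.add.reduceat(content, offsets[:-1])
# The empty segment 3:3 hits the "indices[i] >= indices[i+1]" exception
# and returns content[3] == 4 rather than an empty sum of 0.

# Possible workaround: overwrite empty segments with the identity of np.add.
fixed = naive.copy()
fixed[np.diff(offsets) == 0] = 0
```

Here naive is [6, 4, 9, 6] while fixed matches the segment_sum result [6, 0, 9, 6]. Note that the workaround still fails for a trailing empty segment, since a start index equal to len(content) makes reduceat raise an error.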
|
Yeah, I guess the question is really whether we like the tuple way of passing this and what we like as a name. (Yes, that means that the current I still think |
EDIT (2024-11-23): added tests and documentation, and moved out of draft. I stuck to having start and stop be separate rows, so that one can easily pass (start, stop) (as in the examples). I feel that is most logical given the present structure, where one can see a single array as a start array that implies a default stop array. This can of course still be changed.
Rationale
ufuncs have a .reduceat method that allows having piecewise reductions, but using an array of indices that is rather convoluted. From https://numpy.org/doc/stable/reference/generated/numpy.ufunc.reduceat.html, the indices are interpreted as follows:
The exceptions are the main issue I have with the current definition: really, the current setup is only natural for contiguous pieces; for anything else, it requires contortion. For instance, the documentation describes how to get a running sum as follows:
Note the slice at the end to remove the unwanted elements! And note that this omits the last set of 4 elements -- to get this, one has to add a solitary index 4 at the end - one cannot get slices that include the last element except as the last one.
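As a concrete illustration of the current behaviour, the sketch below runs the running-sum example from the reduceat documentation and restates the documented index rules in pure Python. The helper reduceat_reference is hypothetical (1-d only, not part of NumPy) and exists only to make the rules explicit.

```python
import numpy as np

a = np.arange(8)

# Running sum over windows of 4, as in the reduceat documentation.
# Each (start, stop) pair is flattened into the index list; every odd
# output entry is an unwanted single element and is sliced away.
idx = [0, 4, 1, 5, 2, 6, 3, 7]
running = np.add.reduceat(a, idx)[::2]          # array([ 6, 10, 14, 18])

# The last window 4:8 only appears after appending a solitary index 4.
with_last = np.add.reduceat(a, idx + [4])[::2]  # array([ 6, 10, 14, 18, 22])


def reduceat_reference(ufunc, array, indices):
    """Hypothetical pure-Python restatement of the documented rules (1-d)."""
    out = []
    for i, start in enumerate(indices):
        # The last index implicitly runs to the end of the array.
        stop = len(array) if i == len(indices) - 1 else indices[i + 1]
        if start >= stop:
            out.append(array[start])            # exception: single element
        else:
            out.append(ufunc.reduce(array[start:stop]))
    return np.array(out)
```

Comparing reduceat_reference(np.add, a, idx) with np.add.reduceat(a, idx) gives the same eight values, including the four single elements that the [::2] slice discards.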
The PR arose from this unnatural way to describe slices: Why can one not just pass in the start and stop values directly? With no exceptions, but just interpreted as slices should be. I.e., get a running sum as,
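Since the start-stop form is not available in released NumPy, the following is only an emulation of the intended semantics using ordinary slicing. reduce_slices is a hypothetical helper, not the PR's implementation; it assumes 1-d input and that each (start, stop) pair behaves like array[start:stop], with an optional initial as mentioned in the commit message.

```python
import numpy as np


def reduce_slices(ufunc, array, start, stop, initial=None):
    """Hypothetical helper: reduce array[start[i]:stop[i]] for every i.

    Plain Python slicing, so stop <= start gives an empty reduction.
    """
    kwargs = {} if initial is None else {"initial": initial}
    return np.array([ufunc.reduce(array[i:j], **kwargs)
                     for i, j in zip(start, stop)])


a = np.arange(8)
start = np.arange(5)   # 0, 1, 2, 3, 4
stop = start + 4       # 4, 5, 6, 7, 8

# Spelling proposed in this PR (unreleased): np.add.reduceat(a, (start, stop))
running = reduce_slices(np.add, a, start, stop)
# array([ 6, 10, 14, 18, 22])

# An empty slice needs an initial value when the ufunc has no identity.
lowest = reduce_slices(np.minimum, a, [1, 5], [4, 5], initial=100)
# array([  1, 100])
```

No trailing slice or extra index is needed, and the last window is included like any other. Without initial, the empty 5:5 slice would make np.minimum.reduce raise a ValueError, which is the case the new keyword is meant to cover.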
Currently, the updated docstring explains the new mode as follows:
The PR also adds a new initial keyword argument. The reason for this is that with the new layout I did not want to have the exception currently present, where if stop < start, one gets the value at start. I felt it was more logical to treat this case as an empty reduction, but then it becomes necessary to be able to pass in an initial value for reductions that do not have an identity, like np.minimum (which of course just helps make reduceat more similar to reduce).
Note that I considered requiring slice(start, stop), which might be clearer. I only did not do that since implementation-wise just having a tuple or an array with 2 columns was super easy. I also liked that with this implementation the old way could at least in principle be described in terms of the new one, as having a default stop that just takes the next element of start (with the same exceptions as above). I ended up not describing it as such in the docstring, though.
Anyway, if in principle it is thought a good idea to make reduceat more flexible, the API is up for discussion. It could require indices=slice(start, stop) (possibly step too), or one could allow not passing in indices if start and stop are present, or just add a stop keyword argument, whose defaults are interpreted as before.
Links
(where I suggested adding a slice argument to reduce instead; also an option...)
Old text
Triggered by #834 seeing some comments again, a draft just to see how it would look to allow reduceat to take a set of start, stop indices (treated as slices), to make the interface a bit more easily comprehensible without making a truly new method. It also allows passing in an initial to deal with empty slices.
Fixes #834
Mostly to discuss whether we want this at all, and, if so, what the API should be. So probably best not to worry too much about implementation (the duplication of code, both in reduceat itself and with reduce is large).
Sample use:
Writing it out like this, I think a different order may be useful, i.e., np.add(a, [(1, 2), (3, -1), (5, 0)]). The reason I picked the other one was that I liked the idea of triggering it by using slice(start, stop), with both start and stop possibly arrays and a tuple of two lists was closer to that (although internally it just turns it into an array). The list of tuples suggests more a structured array with start and stop (and step?) entries.
p.s. Fairly trivially extensible to start, stop, step.