
ENH: allow start-stop array for indices in reduceat#25476

Open
mhvk wants to merge 1 commit into numpy:main from mhvk:reduceat-start-stop

Conversation

@mhvk
Contributor

@mhvk mhvk commented Dec 22, 2023

EDIT (2024-11-23): added tests and documentation, and moved out of draft. I stuck to having start and stop be separate rows, so that one can easily pass (start, stop) (as in the examples). I feel that is most logical given the present structure, where one can see a single array as a start array that implies a default stop array. This can of course still be changed.

Rationale

ufuncs have a .reduceat method that allows piecewise reductions, but it uses an array of indices in a way that is rather convoluted. Per
https://numpy.org/doc/stable/reference/generated/numpy.ufunc.reduceat.html, the indices are interpreted as follows:

For i in range(len(indices)), reduceat computes
ufunc.reduce(array[indices[i]:indices[i+1]]), which becomes the i-th
generalized "row" parallel to `axis` in the final result (i.e., <snip>).
There are three exceptions to this:

* when i = len(indices) - 1 (so for the last index), indices[i+1] = array.shape[axis].
* if indices[i] >= indices[i + 1], the i-th generalized "row" is simply array[indices[i]].
* if indices[i] >= len(array) or indices[i] < 0, an error is raised.
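To make the current behavior concrete, here is a minimal session (plain NumPy, current semantics; the array and index values are just illustration) showing both the contiguous case and the start >= stop exception:

```python
import numpy as np

a = np.arange(8)

# Contiguous pieces: reduces a[0:4], a[4:6], and a[6:8].
print(np.add.reduceat(a, [0, 4, 6]))  # [ 6  9 13]

# Exception in action: indices[0]=5 >= indices[1]=2, so the first
# "row" is just a[5]; the last index reduces a[6:8].
print(np.add.reduceat(a, [5, 2, 6]))  # [ 5 14 13]
```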

The exceptions are the main issue I have with the current definition: really, the current setup is only natural for contiguous pieces; for anything else, it requires contortion. For instance, the documentation describes how to get a
running sum as follows:

np.add.reduceat(np.arange(8), [0,4, 1,5, 2,6, 3,7])[::2]

Note the slice at the end to remove the unwanted elements! Note also that this omits the last set of 4 elements -- to get it, one has to append a solitary index 4, since one cannot get slices that include the last element except as the final entry.
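Running the documented workaround (current NumPy semantics; `full` is just a local name for illustration) shows both the discarded in-between entries and the missing final window:

```python
import numpy as np

a = np.arange(8)
full = np.add.reduceat(a, [0, 4, 1, 5, 2, 6, 3, 7])
print(full)       # [ 6  4 10  5 14  6 18  7] -- every other entry is unwanted
print(full[::2])  # [ 6 10 14 18] -- sums of a[0:4], a[1:5], a[2:6], a[3:7]
# The window a[4:8] is missing: the last index 7 only yields a[7:8].
```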

The PR arose from this unnatural way to describe slices: why can one not just pass in the start and stop values directly, with no exceptions, interpreted just as slices would be? I.e., get a running sum as:

np.add.reduceat(np.arange(8), ((start := np.arange(0, 8//2+1)), start+8//2))
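Since the proposed (start, stop) form is not in released NumPy, the intended result can be sketched with an explicit loop over slices (equivalent semantics, just slower; `out` is a local illustration name):

```python
import numpy as np

a = np.arange(8)
n = 8 // 2
start = np.arange(0, n + 1)  # [0 1 2 3 4]
stop = start + n             # [4 5 6 7 8]
# What np.add.reduceat(a, (start, stop)) would compute under this PR:
out = np.array([np.add.reduce(a[i:j]) for i, j in zip(start, stop)])
print(out)  # [ 6 10 14 18 22] -- note the final window a[4:8] is included
```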

Currently, the updated docstring explains the new mode as follows:

    There are two modes for how `indices` is interpreted. If it is a tuple of
    2 arrays (or an array with two rows), then these are interpreted as start
    and stop values of slices over which to compute reductions, i.e., for each
    row i, ``ufunc.reduce(array[indices[0, i]:indices[1, i]])`` is computed,
    which becomes the i-th element along `axis` in the final result (e.g., in
    a 2-D array, if ``axis=0``, it becomes the i-th row, but if ``axis=1``,
    it becomes the i-th column). Like for slices, negative indices are allowed
    for both start and stop, and the values are clipped to be between 0 and
    the shape of the array along `axis`.

The PR also adds a new initial keyword argument. The reason is that with the new layout I did not want to keep the exception currently present, where stop < start yields the value at start. I felt it was more logical to treat this case as an empty reduction, but then it becomes necessary to be able to pass in an initial value for reductions that do not have an identity, like np.minimum (which of course just helps make reduceat more similar to reduce).

Note that I considered requiring slice(start, stop), which might be clearer. I did not do that only because, implementation-wise, just having a tuple or an array with 2 rows was super easy. I also liked that with this implementation the old way could at least in principle be described in terms of the new one, as having a default stop that just takes the next element of start (with the same exceptions as above). I ended up not describing it as such in the docstring, though.

Anyway, if in principle it is thought a good idea to make reduceat more flexible, the API is up for discussion. It could require indices=slice(start, stop) (possibly step too), or one could allow not passing in indices if start and stop are present, or just add a stop keyword argument, whose defaults are interpreted as before.

Links

Old text

Triggered by seeing some comments on #834 again, this is a draft just to see how it would look to allow reduceat to take a set of start, stop indices (treated as slices), to make the interface a bit more easily comprehensible without making a truly new method. It also allows passing in an initial to deal with empty slices.

Fixes #834

Mostly to discuss whether we want this at all, and, if so, what the API should be. So probably best not to worry too much about the implementation (the duplication of code, both within reduceat itself and with reduce, is large).

Sample use:

a = np.arange(12)
np.add.reduceat(a, ([1, 3, 5], [2, -1, 0]))
# array([ 1, 52,  0])
np.minimum.reduceat(a, ([1, 3, 5], [2, -1, 0]), initial=10)
# array([ 1,  3, 10])
np.minimum.reduceat(a, ([1, 3, 5], [2, -1, 0]))
# ValueError: empty slice encountered with reduceat operation for 'minimum', which does not have an identity. Specify 'initial'.
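The sample above can be reproduced today with a small helper that emulates the proposed semantics in pure Python (the name reduceat_start_stop and the exact error message are hypothetical, not part of the PR):

```python
import numpy as np

def reduceat_start_stop(ufunc, a, start, stop, initial=None):
    """Hypothetical helper mimicking the proposed semantics: each
    (start, stop) pair is treated exactly like a Python slice, and an
    empty slice falls back to `initial` (or the ufunc identity)."""
    out = []
    for i, j in zip(start, stop):
        seg = a[i:j]
        if seg.size == 0 and initial is None and ufunc.identity is None:
            raise ValueError(
                f"empty slice encountered for '{ufunc.__name__}', "
                "which does not have an identity. Specify 'initial'.")
        out.append(ufunc.reduce(seg) if initial is None
                   else ufunc.reduce(seg, initial=initial))
    return np.array(out)

a = np.arange(12)
print(reduceat_start_stop(np.add, a, [1, 3, 5], [2, -1, 0]))
# [ 1 52  0]
print(reduceat_start_stop(np.minimum, a, [1, 3, 5], [2, -1, 0], initial=10))
# [ 1  3 10]
```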

Writing it out like this, I think a different order may be useful, i.e., np.add.reduceat(a, [(1, 2), (3, -1), (5, 0)]). The reason I picked the other one was that I liked the idea of triggering it by using slice(start, stop), with both start and stop possibly arrays, and a tuple of two lists was closer to that (although internally it just turns into an array). The list of tuples suggests more a structured array with start and stop (and step?) entries.

p.s. Fairly trivially extensible to start, stop, step.

@mhvk mhvk added this to the 2.0.0 release milestone Dec 22, 2023
@seberg seberg modified the milestones: 2.0.0 release, 2.1.0 release Jan 31, 2024
@mhvk mhvk removed this from the 2.1.0 release milestone Jan 31, 2024
@mhvk mhvk force-pushed the reduceat-start-stop branch 2 times, most recently from 2a6548b to 393ac30 Compare November 23, 2024 14:11
@mhvk mhvk marked this pull request as ready for review November 24, 2024 01:02
@mattip
Member

mattip commented Jan 26, 2026

@mhvk this seems to have slipped through the cracks, and now has a small merge conflict. Would you like to continue with it?

Also introduces an initial argument, though it can only be used
with the start-stop format.
@mhvk mhvk force-pushed the reduceat-start-stop branch from 393ac30 to 9e46997 Compare January 27, 2026 00:43
@mhvk
Contributor Author

mhvk commented Jan 27, 2026

Yes, I still think this would be useful. Unfortunately, the mailing list thread didn't generate much discussion... Anyway, I rebased and solved the merge conflict.

@seberg
Member

seberg commented Jan 27, 2026

I suppose I never actually reviewed it because the mailing list discussion was a bit slow/inconclusive. But overall, I think we should do something, I just never made up my mind what precisely :/.

I was talking with @ikrommyd recently about this from an awkward array perspective. I am not actually sure it matches exactly, but yeah, we should maybe just do this.
The only thing would be if we have thoughts on the actual API.

E.g., start-stop indices with non-inclusive logic would also be very interesting (i.e., split points), or using the word "segmented" for a new method (but it seems that JAX's segment_sum etc. are more like an add.at/bincount behavior).

@ikrommyd
Contributor

ikrommyd commented Jan 27, 2026

Thanks for the ping @seberg. Since @mhvk appears to still be working on this, it doesn't look like I need to pick this up. Happy to discuss the API, of course.
FYI, take a look at this from NVIDIA CCCL too: https://nvidia.github.io/cccl/python/compute_api.html#cuda.compute.algorithms.segmented_reduce

@seberg
Member

seberg commented Jan 27, 2026

Thanks for that reference; that definition of segmented_reduce matches exactly what is implemented here. I do like that as a name, even if it clashes with JAX's usage.
I wouldn't mind also having a split_points= kwarg alternative that reduces exactly the chunks that np.split would return (but this can be a utility or a follow-up).

I am actually inclined towards a ufunc.segmented_reduce(), since IMO new ufunc attributes are comparably cheap.
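The split-points idea can be illustrated with np.split as it exists today (the split_points= kwarg itself is only a suggestion, not an existing API):

```python
import numpy as np

a = np.arange(8)
split_points = [3, 5]
# np.split yields a[0:3], a[3:5], a[5:8]; a hypothetical split_points=
# kwarg would reduce exactly these chunks:
chunks = np.split(a, split_points)
print([int(np.add.reduce(c)) for c in chunks])  # [3, 7, 18]
```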

@ikrommyd
Contributor

FWIW I also think a new segmented_reduce method feels better.

@mhvk
Contributor Author

mhvk commented Jan 28, 2026

Thanks, both! I think at the moment the main question is actually one of API. Personally, I remain somewhat in favour of just using or expanding .reduceat -- it is not an illogical counterpart to .at and .reduce and what I have is, as the PR shows, easily done and explained as an extension. That said, I can see the advantage of introducing new keyword arguments instead of re-interpreting indices. I guess in that case start and end should be broadcast to determine the shape of the output? (Might have fun uses, though likely inefficient...)

Are you aware of other similar routines that might inform the API? The CUDA segmented_reduce is indeed very similar (though it misses an axis argument), while JAX segment_sum is more like .at() (or bincount).

@ikrommyd
Contributor

I am aware of https://docs.pytorch.org/docs/stable/generated/torch.segment_reduce.html and https://www.tensorflow.org/api_docs/python/tf/math/segment_sum.
PyTorch's looks like it can accept both offsets and segment IDs.
Segment IDs indeed look mostly like .at:

In [57]: offsets = np.array([0, 3, 3, 5, 6])

In [58]: content = np.array([1, 2, 3, 4, 5, 6])

In [59]: parents = np.array([0, 0, 0, 2, 2, 3])

In [60]: jax.ops.segment_sum(content, parents)
Out[60]: Array([6, 0, 9, 6], dtype=int32)

In [61]: num_segments = parents.max() + 1

In [62]: result = np.zeros(num_segments, dtype=content.dtype)

In [63]: np.add.at(result, parents, content)

In [64]: result
Out[64]: array([6, 0, 9, 6])

I think there are certain optimizations you can do if your segment IDs are sorted (or you use offsets), as exposed by JAX's indices_are_sorted argument. I honestly like the start-stop offsets like CCCL has the most (which looks to me exactly like what you have here).

@seberg
Member

seberg commented Jan 28, 2026

Yeah, I guess the question is really whether we like the tuple way of passing this and what we like as a name.
FWIW, we could also just go with ufunc.reduceat(x, start=, stop=). Broadcasting start and stop is OK, but I am not sure we should care too much; for now it would be fine to enforce 1-D.

(Yes, that means that the current reduceat uses an optional positional argument, which is no problem.)

I still think segmented_reduce or reduce_segment(ed|s) may be worth it, considering that reduceat's API is normally not what anyone needs. But we shouldn't stall the addition on that...

Successfully merging this pull request may close these issues: reduceat cornercase (Trac #236)