Conversation
jcrist
left a comment
There was a problem hiding this comment.
Thanks for the PR. Overall this looks good to me. I left some comments on style and slight improvements, but this is pretty close to being ready to merge.
dask/array/chunk.py
Outdated
| return a | ||
|
|
||
|
|
||
| def topk_postproc(a, k, axis): |
There was a problem hiding this comment.
These names show up in the graphs, and are only marginally shorter than the non-abbreviated versions. I'd prefer topk_postprocess and argtopk_preprocess.
dask/array/core.py
Outdated
| return topk(self, k, axis=axis, split_every=split_every) | ||
|
|
||
| def argtopk(self, k, axis=-1, split_every=None): | ||
| """The indexes of the top k elements of an array. |
dask/array/reductions.py
Outdated
| array([1, 3]) | ||
| """ | ||
| if isinstance(a, int) and isinstance(k, Array): | ||
| warnings.warn("DeprecationWarning: topk(k, x) has been replaced with topk(a, k)") |
There was a problem hiding this comment.
Should be topk(k, a) has been replaced with topk(a, k)
There was a problem hiding this comment.
No, the previous API really was topk(k, x). I renamed the array to a to be coherent with all other reduction functions.
There was a problem hiding this comment.
Yes, but since the argument was positional, as a reader all I really care about is the swapping of the parameter order (array and k), not the name of the array parameter. As it currently reads it's not immediately clear that the order of the array and k parameter have swapped.
dask/array/reductions.py
Outdated
|
|
||
|
|
||
| def argtopk(a, k, axis=-1, split_every=None): | ||
| """Extract the indexes of the k largest elements from a on the given axis, |
There was a problem hiding this comment.
indexes -> indices throughout this docstring.
| if isinstance(a, int) and isinstance(k, Array): | ||
| warnings.warn("DeprecationWarning: topk(k, x) has been replaced with topk(a, k)") | ||
| a, k = k, a | ||
|
|
There was a problem hiding this comment.
Since the rest of this code assumes axis >= 0, you should use validate_axis here.
axis = validate_axis(a.ndim, axis)This will handle negative axis, and throw a nice error for invalid values.
dask/array/reductions.py
Outdated
| """ | ||
| # Convert a to a recarray that contains its index | ||
| if axis < 0: | ||
| axis += a.ndim |
There was a problem hiding this comment.
I recommend using validate_axis here instead.
dask/array/reductions.py
Outdated
| axis += a.ndim | ||
| idx = arange(a.shape[axis], chunks=a.chunks[axis]) | ||
| idx = idx[tuple(slice(None) if i == axis else np.newaxis for i in range(a.ndim))] | ||
| a_rec = a.map_blocks(chunk.argtopk_preproc, idx, dtype=[('values', a.dtype), ('idx', int)]) |
There was a problem hiding this comment.
On windows int will result in int32. It's not clear to me if that's incorrect or not, just thought I'd bring it up. It's important that the metadata matches the result (meaning x.dtype == x.compute().dtype, so if argtopk_preprocess produces int32 on windows than this is correct.
dask/array/tests/test_reductions.py
Outdated
| chunks=1) | ||
|
|
||
| # Support for deprecated API for topk | ||
| np.testing.assert_array_equal(da.topk(a, 5), da.topk(5, a)) |
There was a problem hiding this comment.
We should check that the warning is thrown.
with pytest.warns(UserWarning):
assert_eq(da.topk(a, 5), da.topk(5, a))
dask/array/tests/test_reductions.py
Outdated
|
|
||
| # As Array methods | ||
| a.topk(5) | ||
| np.testing.assert_array_equal(da.topk(a, 5), da.topk(5, a)) |
There was a problem hiding this comment.
In dask we make use of assert_eq to test equality. This checks not only values, but also the metadata on the dask arrays, which helps ensure correctness. You should be able to replace all your calls to assert_array_equal with this.
dask/array/tests/test_reductions.py
Outdated
| np.testing.assert_array_equal(npf(b, axis=1)[:, :k, :], | ||
| daskf(b, -k, axis=1, split_every=se)) | ||
| np.testing.assert_array_equal(npf(b, axis=2)[:, :, :k], | ||
| daskf(b, -k, axis=2, split_every=se)) |
There was a problem hiding this comment.
Nice tests. Can I also ask that you add one with a negative axis?
|
@jcrist I integrated your code review and overhauled the unit test. |
|
Travis failure was due to an OOM error (if you scroll up you can see the originating exception). Restarted build and everything passes. Thanks for the contribution, merging. |
| for multiple axes, recursive aggregation, and an option to pick the bottom k elements instead. | ||
| (:pr:`3395`) `Guido Imperiale`_ | ||
| - The ``topk`` API has changed from topk(k, array) to the more conventional topk(array, k). | ||
| The legacy API still works but is now deprecated. (:pr:`2965`) `Guido Imperiale`_ |
There was a problem hiding this comment.
Was this suppose to be 2965 or was it meant to be 3395 as well?
topk changes: