Skip to content

topk overhaul and argtopk#3405

Merged
jcrist merged 11 commits intodask:masterfrom
crusaderky:topk
Apr 23, 2018
Merged

topk overhaul and argtopk#3405
jcrist merged 11 commits intodask:masterfrom
crusaderky:topk

Conversation

@crusaderky
Copy link
Collaborator

topk changes:

Copy link
Member

@jcrist jcrist left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR. Overall this looks good to me. I left some comments on style and slight improvements, but this is pretty close to being ready to merge.

return a


def topk_postproc(a, k, axis):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These names show up in the graphs, and are only marginally shorter than the non-abbreviated versions. I'd prefer topk_postprocess and argtopk_preprocess.

return topk(self, k, axis=axis, split_every=split_every)

def argtopk(self, k, axis=-1, split_every=None):
"""The indexes of the top k elements of an array.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indexes -> indices

array([1, 3])
"""
if isinstance(a, int) and isinstance(k, Array):
warnings.warn("DeprecationWarning: topk(k, x) has been replaced with topk(a, k)")
Copy link
Member

@jcrist jcrist Apr 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be topk(k, a) has been replaced with topk(a, k)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, the previous API really was topk(k, x). I renamed the array to a to be coherent with all other reduction functions.

Copy link
Member

@jcrist jcrist Apr 22, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but since the argument was positional, as a reader all I really care about is the swapping of the parameter order (array and k), not the name of the array parameter. As it currently reads it's not immediately clear that the order of the array and k parameter have swapped.



def argtopk(a, k, axis=-1, split_every=None):
"""Extract the indexes of the k largest elements from a on the given axis,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indexes -> indices throughout this docstring.

if isinstance(a, int) and isinstance(k, Array):
warnings.warn("DeprecationWarning: topk(k, x) has been replaced with topk(a, k)")
a, k = k, a

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the rest of this code assumes axis >= 0, you should use validate_axis here.

axis = validate_axis(a.ndim, axis)

This will handle negative axis, and throw a nice error for invalid values.

"""
# Convert a to a recarray that contains its index
if axis < 0:
axis += a.ndim
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I recommend using validate_axis here instead.

axis += a.ndim
idx = arange(a.shape[axis], chunks=a.chunks[axis])
idx = idx[tuple(slice(None) if i == axis else np.newaxis for i in range(a.ndim))]
a_rec = a.map_blocks(chunk.argtopk_preproc, idx, dtype=[('values', a.dtype), ('idx', int)])
Copy link
Member

@jcrist jcrist Apr 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On windows int will result in int32. It's not clear to me if that's incorrect or not, just thought I'd bring it up. It's important that the metadata matches the result (meaning x.dtype == x.compute().dtype, so if argtopk_preprocess produces int32 on windows than this is correct.

chunks=1)

# Support for deprecated API for topk
np.testing.assert_array_equal(da.topk(a, 5), da.topk(5, a))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should check that the warning is thrown.

with pytest.warns(UserWarning):
    assert_eq(da.topk(a, 5), da.topk(5, a))


# As Array methods
a.topk(5)
np.testing.assert_array_equal(da.topk(a, 5), da.topk(5, a))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In dask we make use of assert_eq to test equality. This checks not only values, but also the metadata on the dask arrays, which helps ensure correctness. You should be able to replace all your calls to assert_array_equal with this.

np.testing.assert_array_equal(npf(b, axis=1)[:, :k, :],
daskf(b, -k, axis=1, split_every=se))
np.testing.assert_array_equal(npf(b, axis=2)[:, :, :k],
daskf(b, -k, axis=2, split_every=se))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice tests. Can I also ask that you add one with a negative axis?

@crusaderky
Copy link
Collaborator Author

crusaderky commented Apr 23, 2018

@jcrist I integrated your code review and overhauled the unit test.
Travis fails on something completely unrelated to what I touched - could you look into it please?

E   ModuleNotFoundError: No module named 'toolz'

@jcrist
Copy link
Member

jcrist commented Apr 23, 2018

Travis failure was due to an OOM error (if you scroll up you can see the originating exception). Restarted build and everything passes. Thanks for the contribution, merging.

@jcrist jcrist merged commit 0c443fd into dask:master Apr 23, 2018
for multiple axes, recursive aggregation, and an option to pick the bottom k elements instead.
(:pr:`3395`) `Guido Imperiale`_
- The ``topk`` API has changed from topk(k, array) to the more conventional topk(array, k).
The legacy API still works but is now deprecated. (:pr:`2965`) `Guido Imperiale`_
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was this suppose to be 2965 or was it meant to be 3395 as well?

@crusaderky crusaderky deleted the topk branch April 23, 2018 21:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Risk attribution: a[argsort(b)[:k]] Recursive topk Corner cases for topk of array

4 participants