Add array API support for _weighted_percentile#29431
Add array API support for _weighted_percentile#29431betatim merged 49 commits intoscikit-learn:mainfrom
Conversation
|
cc @lesteve as well. |
There was a problem hiding this comment.
Hey @EmilyXinyi! Thank you so much for your work. ✨
I went through it and commented to the best of my knowledge. I will probably err on some things, since I am also only learning about Array API, and I hope this is not too confusing. Surely a maintainer can help more than me. But maybe my comments can help us into some kind of productive discussion.
I have yet to check _take_along_axis(), which I will do later, when I have a clearer head for this.
Edit: I have done it now.
sklearn/utils/stats.py
Outdated
| xp, _ = get_namespace(array) | ||
| n_dim = array.ndim | ||
| if n_dim == 0: | ||
| return array[()] |
There was a problem hiding this comment.
As far as I have understood, we still would want to return a scalar value from the same namespace (and same device) as the original array in the case when the dimension of the original array is 0. I'm not sure about the correct syntax though.
There was a problem hiding this comment.
I think this code will keep working and do the right thing for array API inputs. At least the following works:
In [12]: x_ = array_api_strict.asarray(1)
In [13]: x_.ndim
Out[13]: 0
In [14]: x[()]
Out[14]: np.int64(1)(but maybe this only works because array api strict uses numpy in the background?
There was a problem hiding this comment.
Thank you @betatim.
I have a question: what do you mean when you say, the code works? I am not sure what the distinction criterium for working and not working is, since we won't get any error either way. Do we need to measure performance?
If you are looking for the data type: Why did you expect np.int64 here?
There was a problem hiding this comment.
I think what he meant is the indexing a 0d-array with an empty tuple preserves the namespace and the dtype (and presumably the device) of the original array.
There was a problem hiding this comment.
Actually, I misread the output of the snippet. Indeed, the namespace is not preserved since this returns a NumPy scalar instance and not a 0d array of the original namespace.
But this is somewhat but not exactly in line with the return type described in the docstring of the function (int).
>>> import numpy as np
>>> np.isscalar(np.asarray(1))
False
>>> np.isscalar(np.asarray(1)[()])
True
>>> np.isscalar(np.int64(1))
True
>>> import numbers
>>> isinstance(np.int64(1), numbers.Integral)
True
>>> isinstance(np.int64(1), int)
FalseThere was a problem hiding this comment.
Actually I cannot reproduce Tim's snippet with array-api-strict 2.3:
>>> import array_api_strict as xp
>>> xp.asarray(1)[()]
Array(1, dtype=array_api_strict.int64)
>>> xp.__version__
'2.3'So this code is actually returning a 0d array of the original namespace (and device) rather than a NumPy scalar instance.
There was a problem hiding this comment.
And torch does the same:
>>> import array_api_compat.torch as xp
>>> xp.asarray(1, device="mps")[()]
tensor(1, device='mps:0')There was a problem hiding this comment.
I think the caller will be responsible for calling float(output) if we want to convert a NumPy scalar or a 0d (non-NumPy) array value into a Python float scalar in the end.
For 0D NumPy inputs, returning a NumPy scalar or a 0D NumPy array should not matter much. For other kinds of 0D input arrays, we would always return 0D output arrays (because the array API spec does not support namespace-specific scalar values, but only 0D arrays).
We could probably remove the [()] in the 0D input case. I don't think it matters, neither for NumPy inputs nor other kinds of inputs. I tried to remove it on main and all the tests still pass.
There was a problem hiding this comment.
@ogrisel to clarify, API spec does not support namespace-specific scalar values, so:
array_api_strict.asarray(1)[()]array_api_strict.asarray(1)
both result in the same thing, a 0D array ?
Whereas numpy does have scalars so
np.array(1)[()]np.array(1)
result in different things, 0D array vs scalar.
So I think I agree that we should remove the [()] for 0D input case? (so numpy output is the same as other API arrays?)
There was a problem hiding this comment.
General comment: most of the time when we add array API support to a function in scikit-learn, we do not touch the existing (numpy-only) tests to make sure that the PR does not change the default behavior of scikit-learn on traditional inputs when array API is not enabled.
Instead we add a few new test functions that:
- generate some random test data with numpy or
sklearn.datasets.make_*; - call the function once on the numpy inputs without enabling array API dispatch;
- convert the inputs to a namespace / device combo passed as parameter to the test;
- call the function with array API dispatching enabled (under a
with sklearn.config_context(array_api_dispatch=True)block - check that the results are on the same namespace and device as the input
- convert back the output to a numpy array using
_convert_to_numpy - compare the original / reference numpy results and the
xpcomputation results converted back to numpy usingassert_allcloseor similar.
Those tests should have array_api somewhere in their name to makes sure that we can run all the array API compliance tests with a keyword search in the pytest command line, e.g.:
pytest -k array_api sklearn/some/subpackage
In particular, for cost reasons, our CUDA GPU CI only runs pytest -k array_api sklearn. So it's very important to respect this naming conventions, otherwise we will not tests all what we are supposed to test on CUDA.
Awesome! This implies that we don't test Array API down every control flow branch (like codecov), but rather have one new test per function. This is also how it is meant, correct? |
Ideally we don't even need one test per function when the existing common tests are enough as we did in many past PRs for:
scikit-learn/sklearn/utils/estimator_checks.py Lines 143 to 145 in d9deffe e.g. for PCA: scikit-learn/sklearn/decomposition/_pca.py Lines 850 to 854 in d9deffe
But we can also write custom test functions (with scikit-learn/sklearn/decomposition/tests/test_pca.py Lines 1007 to 1036 in d9deffe Note that this particular wrapper of a common estimator check as a custom check will probably be refactored once #29820 is merged, but the main idea is that it's ok to add a few extra tests when the existing common tests / estimator checks are not enough to test parameter combinations we would like to cover. The point is that the existing (pre-array API) tests should be enough to check that all branches works as expected with numpy and then on top of that array compliance checks check that numpy and any other array API library find the same results (up to rounding errors) on random data (with various combinations of extra parameters that should support array API). |
|
I updated the description of the meta issue to add testing guidelines. |
|
Thanks for the explanations about testing and updating the issue description, @ogrisel.
This is something I am still wrapping my head around. Testing if the (old) numpy outputs match the array api outputs for the default params seems incomplete to me. In my mind, the default is nothing special compared to other branches of a function. This doubt was infact why I wrote this down so explicitly to get your confirmation. I will keep thinking about it. |
ogrisel
left a comment
There was a problem hiding this comment.
Another pass of feedback w.r.t. missing array API functions.
sklearn/utils/_array_api.py
Outdated
| if _is_numpy_namespace(xp): | ||
| return numpy.nextafter(x1, x2) |
There was a problem hiding this comment.
| if _is_numpy_namespace(xp): | |
| return numpy.nextafter(x1, x2) | |
| if hasattr(xp, "nextafter"): | |
| return xp.nextafter(x1, x2) |
That's the expected behaviour. Maybe we could change it though. |
|
Ok I will stop pushing to your PR for today @lucyleeow ;) Feel free to resume addressing the remaining open thread in the review above. |
|
FYI I've opened a PR to fix the array-api-strict bug (data-apis/array-api-strict#139), but I guess we'd have to wait for a release to remove the xfail? Note it looks codecov is complaining about pytest.xfail(f"xp.nextafter is broken on {device}") |
Yes. And a weekly lock file update. |
ogrisel
left a comment
There was a problem hiding this comment.
LGTM. Thank you very much @lucyleeow @EmilyXinyi @StefanieSenger.
This is ready for another second review (ping @OmarManzoor @betatim).
|
FYI data-apis/array-api-strict#139 is now merged, I'll keep an eye on the new release. Thanks for your review and fixes @ogrisel ! |
|
For information, I opened scipy/scipy#22794 so that in the long term, we could consolidate quantile implementation with array API + NaN + weight support all in SciPy instead of scikit-learn. SciPy already has array API and NaN support but lacks:
|
betatim
left a comment
There was a problem hiding this comment.
This looks nice! I've not followed the whole conversation in the ~125 comments, only looked at the code as it is now.
I had two questions, but otherwise I think this is nice and ready to go!
| pass | ||
| else: | ||
| if device == array_api_strict.Device("device1"): | ||
| # See https://github.com/data-apis/array-api-strict/issues/134 |
There was a problem hiding this comment.
It looks like data-apis/array-api-strict#139 fixed the bug, so do we still need this xfail? Asked differently, when can we remove this xfail? Maybe we can write it down in the comment to help us-from-the-future
There was a problem hiding this comment.
Can you tell I didn't read the conversation :D This is answered in #29431 (comment) - maybe still worth a comment here or do you think the release will happen soon enough that you still remember @lucyleeow ?
There was a problem hiding this comment.
I agree, I should add a comment here, and I will add this to my to do list to follow up
|
In the last monthly meeting we discussed "why do people not merge PRs once they have enough +1s?" - so in the spirit of merging things a bit more promptly: I'll merge this now. It looks like the comment threads not marked as "resolved" are either out of date or questions. If there is one that leads to some changes, lets do that in a new PR. Thanks for the persistence in this long running PR |
|
@betatim you beat me by minutes! I guess if I don't add that comment, we will have to rely solely on me remembering to do it 😱 I've added it to my todo. |
|
Thanks for the review and speedy merge :D |
TO DO:
_weight_percentileto support array APItest_stats.pycc: @StefanieSenger