Add tests for train_test_split with Array API input#26855
Add tests for train_test_split with Array API input#26855thomasjpfan merged 7 commits intoscikit-learn:mainfrom
Conversation
ogrisel
left a comment
There was a problem hiding this comment.
Great that it works out of the box.
| def __eq__(self, other): | ||
| return self._namespace == other._namespace |
There was a problem hiding this comment.
Do we want this? It is convenient in the test to be able to compare (wrapped) namespaces for equivalence.
There was a problem hiding this comment.
I think i prefer explicit namespace assertions 8n tests. We could have a helper to assert same namespace in tests.
There was a problem hiding this comment.
What do you mean with explicit? Getting a string representation that we can compare to the array_namespace passed in to the test?
In the test itself I use get_namespace(input)[0] == get_namespace(output)[0] to check that input and output are in the same namespace. This works when the namespace is one from the array compat library, but not for the few namespaces that we wrap in this wrapper.
There was a problem hiding this comment.
I am okay with overriding __eq__ like this.
|
@thomasjpfan and @ogrisel - if you want to look at a PR that mostly adds new tests, this is one :D |
| def __eq__(self, other): | ||
| return self._namespace == other._namespace |
There was a problem hiding this comment.
I think i prefer explicit namespace assertions 8n tests. We could have a helper to assert same namespace in tests.
| def __eq__(self, other): | ||
| return self._namespace == other._namespace |
There was a problem hiding this comment.
I am okay with overriding __eq__ like this.
fcb0edf to
8a7814d
Compare
|
Should we list this kind of thing (functions, not estimators) in the "estimators with support" section of |
I like a new section. I think it's good to keep track of all the Array API supported estimators & functions in |
|
What do you think of the current patch? I added subsections, one called Estimators and one called Tools. |
That is okay with me. |
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
Reference Issues/PRs
(need to find one)
What does this implement/fix? Explain your changes.
This mostly adds some tests that use
train_test_splitwith Array API input and compare to using a pure Numpy array as input.Any other comments?
First attempt of seeing what happens when you feed cupy/pytorch/array api arrays to
train_test_split. Need to explore more of the different parameters to see if they all "just work".