Skip to content

Add move_to function to convert array namespace and device to namespace and device#31829

Merged
ogrisel merged 55 commits intoscikit-learn:mainfrom
lucyleeow:aapi_input_conversion
Nov 17, 2025
Merged

Add move_to function to convert array namespace and device to namespace and device#31829
ogrisel merged 55 commits intoscikit-learn:mainfrom
lucyleeow:aapi_input_conversion

Conversation

@lucyleeow
Copy link
Copy Markdown
Member

@lucyleeow lucyleeow commented Jul 24, 2025

Reference Issues/PRs

Towards #28668 and #31274

What does this implement/fix? Explain your changes.

Adds a function that converts arrays to the namespace and device of the reference array.

Tries DLPack first, and if either array does not support it, tries to convert manually.

Any other comments?

This is an initial attempt, and what it would look like in a simple metric. Feedback welcome. (Tests to come)

I thought about also outputting the namespace and device of the reference array, to avoid the second call to get_namespace_and_device, but I thought it would make the outputs too messy.

cc @ogrisel @betatim @StefanieSenger @virchan @lesteve

@github-actions
Copy link
Copy Markdown

github-actions bot commented Jul 24, 2025

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 71aaad2. Link to the linter CI: here

try:
# Note will copy if required
array_converted = xp_ref.from_dlpack(array, device=device_ref)
except AttributeError:
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I decided to only except AttributeError, which I think occurs if input or output namespace does not support dlpack.

from_dlpack can give 2 (or more) other errors

  • BufferError - The dlpack and dlpack_device methods on the input array may raise BufferError when the data cannot be exported as DLPack (e.g., incompatible dtype, strides, or device). It may also raise other errors when export fails for other reasons (e.g., not enough memory available to materialize the data). from_dlpack must propagate such exceptions.
    • I thought that if dlpack fails to convert due to one of the above errors, it would not make sense to try ourselves manually.
  • ValueError - If data exchange is possible via an explicit copy but copy is set to False.
    • I've left copy=None, allowing it to copy if need be, so this error is not relevant. I am not sure about the copy=None setting though, it is a lot of memory usage.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Talking to Evgeni about what could possibly cause a BufferError - here are some ideas, but but I don't know if any of these will cause an error:

Copy link
Copy Markdown
Member

@virchan virchan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect we could simplify _convert_to_reference a bit when xp_ref is NumPy:

if _is_numpy_namespace(xp_ref):
   return tuple([_convert_to_numpy(array, get_namespace(array)) for array in arrays])

However, I'm not sure this offers much benefit beyond readability, since in most cases there are only two arrays to convert: y_true and sample_weight.

I'll have to give it some more thought.

@lucyleeow
Copy link
Copy Markdown
Member Author

I suspect we could simplify _convert_to_reference a bit when xp_ref is NumPy:

At the moment that would be simpler, but I think _convert_to_numpy is considered a bit of a 'hack'. It's got numpy conversions for specific array namespaces hard coded (needs maintenance, though I do doubt any of the APIs will change) and it doesn't necessarily work for all array API arrays. I think DLPack is more future proof and it has a copy parameter (we could specify to avoid copying if we wanted to) - at least this was the thinking when I decided to just try DLPack first...

since in most cases there are only two arrays to convert: y_true and sample_weight.

For metrics I think this would mostly be it, but for estimators there could be other arrays to convert.

@jeremiedbb
Copy link
Copy Markdown
Member

What's the difference with sklearn.utils._array_api.ensure_common_namespace_device ? Looks like both are trying to do the same thing.

@betatim
Copy link
Copy Markdown
Member

betatim commented Jul 25, 2025

What's the difference with sklearn.utils._array_api.ensure_common_namespace_device ? Looks like both are trying to do the same thing.

I think the goal of both is the same. I also think that ensure_common_namespace_device is broken right now, at least for the application that Lucy has in mind. Maybe the thing to do is to replace the code of ensure_common_namespace_device with that of the new function?

Copy link
Copy Markdown
Member

@ogrisel ogrisel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I would be in favor of renaming this method to move_to(arrays, namespace, device) since the device and namespace will already be inspected in the caller most of the time for other purposes.

@betatim
Copy link
Copy Markdown
Member

betatim commented Aug 27, 2025

What's the difference with sklearn.utils._array_api.ensure_common_namespace_device ? Looks like both are trying to do the same thing.

I think we should remove/edit ensure_common_namespace_device as part of this PR and use move_to instead.

@betatim
Copy link
Copy Markdown
Member

betatim commented Aug 28, 2025

We should probably create a common test that checks that "everything follows X/y_pred" works for estimators and functions. But it will require quite a bit of renovation work in existing code. I think this means we should tackle this in a new PR. And even there I wonder if we should add the test and make all the changes or if we should have several smaller PRs that make the change and then add the test. WDYT?

@lucyleeow
Copy link
Copy Markdown
Member Author

We should probably create a common test that checks that "everything follows X/y_pred" works for estimators and functions.

Why don't I add a unit test for move_to in this PR and add a common "everything follows X/y_pred" for estimators and metrics in a separate PR. I feel like that will take some iterating and thought...

@lucyleeow
Copy link
Copy Markdown
Member Author

Looking into adding a unit test to test different array combinations, I have a stupid question - is it worth testing arrays on different devices but in the same namespace? Namely torch. I think it would be testing torch cpu -> torch cuda but something like torch cuda -> torch mps seems unlikely to occur in real life?

Copy link
Copy Markdown
Member

@StefanieSenger StefanieSenger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for your work, @lucyleeow. I only have two nit comments and a question: Why don't we test for all possible combinations here? I have read the whole discussion and have seen how the test evolved from being joined, then split, then joined again. This is a central function of array API and I feel it is necessary to test for all possible combinations.

This is only a partial review and I only intend to comment here.

# methods are not present on the input array
# `TypeError` and `NotImplementedError` for packages that do not
# yet support dlpack 1.0
# (i.e. the `device`/`copy` kwargs, e.g., torch <= 2.8.0)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line and L518 talk about copy= as a keyword argument. But I can't see it being used anywhere. So I'm confused. Is the comment out of date? Or is it on purpose that the comment mentions it but it isn't in the code?

Copy link
Copy Markdown
Member Author

@lucyleeow lucyleeow Nov 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes, sorry this assumes knowledge!

The two relevant PRs are data-apis/array-api#741 and dmlc/dlpack#136 (which references the first PR). DLPack v1 added both the copy and device kwargs at the same time;

  1. Explicitly enable/disable data copies (through the new copy argument). The default is still None ("may copy") to preserve backward compatibility and align with other constructors (ex: asarray).
  2. Allow copying to CPU (through the new device argument, plus setting copy=True)

thus I speak of them together, as when array libraries (e.g., pytorch) add support for DLPack v1, they (generally) add support for both kwargs together.

The default value for copy=None, which means only copy if needed. (this kwarg can also be True or False - https://data-apis.org/array-api/latest/API_specification/generated/array_api.from_dlpack.html#from-dlpack)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I can add reference to data-apis/array-api#741 ?

@betatim
Copy link
Copy Markdown
Member

betatim commented Nov 14, 2025

I'd propose the following in an attempt to wrap up this PR (it has already gone down some rabbit holes and generated a lot of comments :D).

Address the open comments (that have come since Olivier's review), do not add more new review comments (that aren't responses to existing ones), merge #32705 and then merge this. Then find out what we forgot and address it.

WDYT?

@ogrisel
Copy link
Copy Markdown
Member

ogrisel commented Nov 14, 2025

I merged #32705, let me sync with main and retrigger CUDA CI.

@ogrisel
Copy link
Copy Markdown
Member

ogrisel commented Nov 14, 2025

@lucyleeow there are a few suggestions/comments left to address in the above reviews, but otherwise LGTM for merge once addressed.

@lucyleeow
Copy link
Copy Markdown
Member Author

Reviews addressed and I removed the RunTimeError from the except, which will check that #32705 did what it was supposed to! 🤞

@ogrisel ogrisel enabled auto-merge (squash) November 17, 2025 08:48
@github-actions github-actions bot removed the CUDA CI label Nov 17, 2025
@ogrisel
Copy link
Copy Markdown
Member

ogrisel commented Nov 17, 2025

I expanded the inline comment based on review feedback, synced with main and marked as auto-merge. Thanks @lucyleeow and reviewers!

@ogrisel ogrisel merged commit 93311ba into scikit-learn:main Nov 17, 2025
42 checks passed
@github-project-automation github-project-automation bot moved this from In Progress to Done in Array API Nov 17, 2025
@lucyleeow lucyleeow deleted the aapi_input_conversion branch November 18, 2025 01:48
@lucyleeow
Copy link
Copy Markdown
Member Author

Thanks all for your time and your thoughts!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

Archived in project

Development

Successfully merging this pull request may close these issues.

7 participants