
TST Add common test for mixed array API inputs for metrics#32755

Merged
ogrisel merged 75 commits into scikit-learn:main from lucyleeow:tst_mixed_input_metric
Mar 6, 2026
Conversation

@lucyleeow
Member

Reference Issues/PRs

Follow up to #31829
Related to #31274 (and loosely #28668)

What does this implement/fix? Explain your changes.

Adds a common test to check that a metric supports mixed array API inputs. The idea is that we can simply add metrics to this common test as we add support.

I had many questions during implementation, but thought it would still be useful to have a draft PR to look at.

  • I don't check that the output value is as expected (i.e., the same as if I had used NumPy inputs). I thought this wasn't necessary, as we check accuracy elsewhere. The test checks that:
    • for float outputs, no error is raised
    • for array outputs, no error is raised and the output array is in the correct namespace/device
  • It checks different array namespaces for y_true and y_pred - sample_weight is not used, and other kwargs (e.g., for continuous metrics, cycling through "average": ("micro", "macro", "weighted")) are not checked
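A minimal sketch of the check's shape (all names here are illustrative toys, not scikit-learn's actual helpers):

```python
import numpy as np

# Toy stand-in for a metric; `check_mixed_namespace_output` below is a
# hypothetical sketch of the common test's logic: it never compares
# values against a NumPy reference, it only asserts that the call
# succeeds and that array outputs land in the expected namespace.
def toy_metric(y_true, y_pred):
    return float(np.mean(np.asarray(y_true) == np.asarray(y_pred)))

def check_mixed_namespace_output(metric, y1, y2, xp_ref):
    result = metric(y1, y2)
    if isinstance(result, float):
        return result  # float output: not raising is all we require
    # array output: must belong to the reference namespace
    assert isinstance(result, xp_ref.ndarray), "wrong output namespace"
    return result

out = check_mixed_namespace_output(
    toy_metric, np.asarray([0, 1, 1, 1]), np.asarray([0, 1, 0, 1]), np
)
print(out)  # 0.75
```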

I'll write other items as comments on the code.

The other implementation I thought about was to include it in check_array_api_metric by adding an optional check_mixed_input kwarg. This avoids the 'messy' conditionals at the start of the current test_mixed_namespace_input_compliance where we work out what type of input should go in. But this approach had other problems:

  • we'd have to chain the parametrizations yield_namespace_device_dtype_combinations and the new mixed array combinations in test_array_api_compliance, and somehow skip the mixed input for metrics that don't yet support it
  • to turn on mixed input checking, we'd have to use partial(check_array_api_*, check_mixed_input=True)
  • we'd need to add kwargs to check_array_api_metric so it can accept two namespace/device pairs
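To illustrate the second point, the opt-in would look something like this (a hypothetical signature; the real check_array_api_metric keyword list differs):

```python
from functools import partial

# Hypothetical stand-in for scikit-learn's checker, with the extra
# opt-in keyword the rejected design would have required.
def check_array_api_metric(metric_name, check_mixed_input=False):
    return ("mixed" if check_mixed_input else "single", metric_name)

# Every metric opting in would need to be registered via `partial`:
check_mixed = partial(check_array_api_metric, check_mixed_input=True)
print(check_mixed("accuracy_score"))  # ('mixed', 'accuracy_score')
```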

Any other comments?

cc @ogrisel @betatim @lesteve

@github-actions

github-actions bot commented Nov 21, 2025

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: b54509d. Link to the linter CI: here

Comment on lines +2460 to +2478
if metric_name in CLASSIFICATION_METRICS:
    # These should all accept binary label input as there are no
    # `CLASSIFICATION_METRICS` that are in `METRIC_UNDEFINED_BINARY` and are NOT
    # `partial`s (which we do not test for in array API)
    y1 = xp_input.asarray([0, 0, 1, 1], device=array_input.device)
    y2 = xp_ref.asarray([0, 1, 0, 1], device=reference.device)
elif metric_name in {**CONTINUOUS_CLASSIFICATION_METRICS, **CURVE_METRICS}:
    if metric_name not in METRIC_UNDEFINED_BINARY:
        # Continuous binary input
        y1 = xp_input.asarray([0.5, 0.2, 0.7, 0.6], device=array_input.device)
        y2 = xp_ref.asarray([1, 0, 1, 0], device=reference.device)
    else:
        # Continuous but shape (n_samples, n_labels)
        y1 = xp_input.asarray([[0.5, 0.2, 0.7, 0.6]], device=array_input.device)
        y2 = xp_ref.asarray([[1, 0, 1, 0]], device=reference.device)
elif metric_name in REGRESSION_METRICS:
    y1 = xp_input.asarray([2.0, 0.1, 1.0, 4.0], device=array_input.device)
    y2 = xp_ref.asarray([0.5, 0.5, 2, 2], device=reference.device)
elif metric_name in PAIRWISE_METRICS:
Member Author

Not familiar with metric types, I think I got this right but open to improvements!

@lucyleeow
Member Author

lucyleeow commented Nov 27, 2025

Gentle ping on this. We have already started adding support for mixed arrays in some metrics (#32422) - and we test mixed array inputs (by checking string NumPy y_true inputs against array API y_pred inputs from yield_namespace_device_dtype_combinations). I think it is probably better to separate testing of array API support from testing of mixed array inputs (points listed here #32793 (comment)).

Note also that it would be useful to get #32793 in before we add more continuous metrics, but that hinges on this so it would be nice to push this forward!

@betatim
Member

betatim commented Nov 27, 2025

I don't check that the value output is as expected (i.e. the same as if I had used numpy inputs). I thought it wasn't necessary, as we check the accuracy elsewhere.

Sounds like a good plan. For estimators we do something similar: the basic test just checks that the namespace, device, and shapes match, and the values are compared in the estimator-specific tests.

@lucyleeow
Member Author

lucyleeow commented Dec 1, 2025

Note: I have realised that there are certain metrics that allow string NumPy y_true when y_pred is not NumPy (e.g., those from #32422), because we never need to use y_true and y_pred together (e.g., y_true == y_pred or y_true - y_pred) - thus we never move y_true to y_pred's namespace - which is why the test test_probabilistic_metrics_array_api works with string labels.
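The reason string y_true can stay a NumPy array can be sketched in plain NumPy: such metrics only use y_true to build an integer indexer, never in arithmetic with y_pred (the variable names below are illustrative, not scikit-learn internals):

```python
import numpy as np

# Illustrative sketch: a metric that supports string y_true never
# computes e.g. `y_true - y_pred`; it first encodes y_true to integer
# positions, and only that integer array would ever need to reach
# y_pred's namespace/device.
y_true = np.asarray(["spam", "ham", "spam"])
classes, y_idx = np.unique(y_true, return_inverse=True)
print(classes.tolist(), y_idx.tolist())  # ['ham', 'spam'] [1, 0, 1]
```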

Since I want to add testing for string y_true, I will make (yet another) list of SUPPORT_Y_TRUE_STRING metrics. I think we should also mention this in array_api.rst, as it would be useful to know which metrics also support string NumPy input in the mixed input setting. (edit: is maintaining these lists going to be a burden? I think it's okay?)

Related: https://github.com/scikit-learn/scikit-learn/pull/32422/files#r2575666004

cc @betatim @ogrisel

@ogrisel
Member

ogrisel commented Mar 3, 2026

Did we change anything in our CIs (in case you know off the top of your head @lesteve @ogrisel ?)

We did change the channels used to download CUDA related stuff in #33212. Not 100% sure this is related.

EDIT: sorry for the noise, it's not related because we have a move_to failure on the MPS CI (unrelated to the lockfile of the CUDA CI).

@ogrisel
Member

ogrisel commented Mar 3, 2026

BufferError – The dlpack and dlpack_device methods on the input array may raise BufferError when the data cannot be exported as DLPack (e.g., incompatible dtype, strides, or device).

I read the meaning as: incompatible dtype, incompatible strides, or incompatible device - which would suggest a BufferError should be raised here?

I agree with your analysis, but I am not sure which project this issue should be reported to: NumPy or PyTorch, as the traceback is not explicit about which library raises the RuntimeError (I suppose it's all happening in compiled code).

@ogrisel
Member

ogrisel commented Mar 3, 2026

Actually, for the torch MPS to NumPy conversion we could leverage DLPack, but only if we explicitly pass the CPU device, which is understood by both NumPy and PyTorch:

>>> import torch, numpy as np
>>> x_torch = torch.arange(5, device='mps')
>>> np.from_dlpack(x_torch, device=None)
Traceback (most recent call last):
  Cell In[16], line 1
    np.from_dlpack(x_torch, device=None)
RuntimeError: Unsupported device in DLTensor.
>>> np.from_dlpack(x_torch, device="cpu")
array([0, 1, 2, 3, 4])

@lucyleeow
Member Author

lucyleeow commented Mar 4, 2026

I agree with your analysis but I am not sure which project should this issue be reported to: NumPy or PyTorch as the traceback is not explicit about which library raises the RuntimeError

I am going to say NumPy, because I cannot find "Unsupported device in DLTensor" in the PyTorch codebase, but I can in the NumPy codebase. Also, I get the same error when converting from both CuPy and PyTorch MPS to NumPy.

for the torch MPS to NumPy conversion we could leverage DLPACK but only if we explicitly pass the CPU device

Reading the from_dlpack docs, this does seem to be correct:

device on which to place the created array. If device is None and x supports DLPack, the output array must be on the same device as x. Default: None.

If device=None, it tries to put the output array on torch MPS.

I wonder if we should specify NumPy device to be "cpu" instead of None? Relevant PRs: #29119 , #30454

Edit: I guess if we are careful in specifying device to be the output of device(), it would be fine.

@StefanieSenger
Member

If the device=None it tries to put the output array on torch MPS.

I wonder if we should specify NumPy device to be "cpu" instead of None? Relevant PRs: #29119 , #30454

Edit: I guess if we are careful in specifying device to be the output of device(), it would be fine.

The device() function from array_api_compat does a

if is_numpy_array(x):
    return "cpu"

For this PR I think it's fine to use it and we should not let this question block it.
Independently of this PR, I think it's worth checking whether specifying the NumPy device as "cpu" instead of None has any bad side effects.

There are only two open points for me, otherwise this PR looks good:

  1. TST Add common test for mixed array API inputs for metrics #32755 (comment)
  2. TST Add common test for mixed array API inputs for metrics #32755 (comment)

Member

@ogrisel ogrisel left a comment

Another pass of feedback.

@lucyleeow
Member Author

Independent from this PR, I think it's beneficial to check if specifying NumPy device to be "cpu" instead of None has any bad side effects.

Yes, I did not mean to do this within this PR, but I am not sure whether this is a problem; reading #29119 and #30454, I can see why we have it as None...

Member

@StefanieSenger StefanieSenger left a comment

With the newest simplifications to the test case building, this PR looks fine to me.
Thanks for the work, @lucyleeow! Approving.

Member

@ogrisel ogrisel left a comment

LGTM, thank you very much @lucyleeow!

@ogrisel ogrisel merged commit 81e7a3e into scikit-learn:main Mar 6, 2026
41 of 42 checks passed
@github-project-automation github-project-automation bot moved this from In Progress to Done in Array API Mar 6, 2026
@github-project-automation github-project-automation bot moved this from In progress to Done in Labs Mar 6, 2026
@lucyleeow
Member Author

Thanks all!

