
Add support for array API to RidgeCV, RidgeClassifier and RidgeClassifierCV#27961

Merged
ogrisel merged 63 commits into scikit-learn:main from jeromedockes:ridgecv-arrayapi
Oct 13, 2025

Conversation

@jeromedockes
Contributor

@jeromedockes jeromedockes commented Dec 14, 2023

Reference Issues/PRs

Towards #26024.

This PR extends the one for Ridge (still WIP, #27800) to use the array API in RidgeCV and RidgeClassifierCV (when cv="gcv").

What does this implement/fix? Explain your changes.

This could make these estimators faster, since an important part of their computational cost comes from computing either an eigendecomposition of XX^T or an SVD of X.
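For illustration (this is not code from the PR), the decomposition step that dominates the GCV cost can be sketched with plain NumPy; the shapes and the sanity check are arbitrary:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((20, 5))  # arbitrary illustrative shapes

# _RidgeGCV's leave-one-out path is built on either an eigendecomposition
# of the Gram matrix X @ X.T or an SVD of X; these factorizations are the
# dominant cost, which is what array API / GPU support aims to speed up.
eigvals, Q = np.linalg.eigh(X @ X.T)
U, s, Vt = np.linalg.svd(X, full_matrices=False)

# Sanity check: the squared singular values of X match the nonzero
# eigenvalues of the Gram matrix.
assert np.allclose(np.sort(s**2), np.sort(eigvals)[-5:])
```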

Any other comments?

_RidgeGCV has numerical precision issues when computations are done in float32, which is why, at the moment, it always uses float64 in the main branch.
I'm not sure what should be done for array API inputs on devices that do not support float64.

not handled yet:

  • RidgeClassifierCV

@github-actions

github-actions bot commented Dec 14, 2023

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 8f3f0d9.

@jeromedockes
Contributor Author

I think the test failures for Ridge and RidgeCV arise from r2_score and will be handled in #27904.
For RidgeClassifierCV, we need to support the array API in LabelBinarizer.

@ogrisel
Member

ogrisel commented Mar 14, 2024

While I am thinking about it, please don't forget to update:

```python
if sparse.issparse(X):
    dtype = np.float64
else:
    dtype = [xp.float64, xp.float32]
```
Member

@ogrisel ogrisel May 15, 2024


Contrary to what I said in this morning's meeting, I think we might want to implement the following logic:

  • if the input namespace/device supports xp.float64 upcasting, then do the upcast (as we currently do with NumPy)
  • if not (e.g. pytorch + MPS device combination), accept that we have degraded numerical performance, adjust the tolerance in the tests accordingly and document this limited numerical precision guarantee in our Array API doc.

I think this is the strategy we are leaning towards in the review of #27113. During the review of the r2_score PR, I believe that @adrinjalali preferred that approach.

In a future PR, we might decide to drop the float32 -> float64 upcast in general for this estimator (as it silently triggers a potentially very large and unexpected memory allocation which is a usability problem in itself, even with NumPy) but I would rather make this decision independently of Array API support.
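A minimal sketch of the "upcast when supported" check described above, using a hypothetical helper that mirrors the `_max_precision_float_dtype` utility discussed in #27113 (the inspection API call and the probing fallback are assumptions, not the actual scikit-learn implementation):

```python
import numpy as np

def max_precision_float_dtype(xp, device=None):
    # Hypothetical helper mirroring _max_precision_float_dtype from
    # scikit-learn#27113: return float64 when the namespace/device pair
    # supports it (so we can upcast as with NumPy), otherwise fall back
    # to float32 (e.g. PyTorch on an MPS device) and accept the degraded
    # numerical precision documented for such devices.
    try:
        # Array API inspection namespace (2023.12 spec and later).
        dtypes = xp.__array_namespace_info__().dtypes(device=device)
        return dtypes.get("float64", xp.float32)
    except AttributeError:
        # Older namespaces: probe by allocating a tiny float64 array.
        try:
            kwargs = {} if device is None else {"device": device}
            xp.zeros(1, dtype=xp.float64, **kwargs)
            return xp.float64
        except Exception:
            return xp.float32

# With NumPy as the namespace, float64 is always available.
assert max_precision_float_dtype(np) == np.float64
```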

Contributor Author


How would you recommend I check whether the upcast is possible? Should I temporarily copy the max_precision_float_dtype and supported_float_dtypes changes from #27113 until it is merged? Or is there already a utility in scikit-learn for this check that I missed?

Member


Feel free to copy with a TODO comment to remove redundant code once #27113 is merged to be able to decouple the 2 reviews.

Contributor Author


When we do the upcast, with what precision should we store the coefficients and intercept? I guess for prediction we do not need the extra precision, so we should use X's original dtype?

Member

@adrinjalali adrinjalali left a comment


This is neat! From my point of view LGTM. But I haven't checked the tests or mathematical correctness.

@OmarManzoor OmarManzoor changed the title Add support for array API to RidgeCV Add support for array API to RidgeCV, RidgeClassifier and RidgeClassifierCV Oct 10, 2025
@ogrisel ogrisel removed the Stalled label Oct 10, 2025
@ogrisel
Member

ogrisel commented Oct 10, 2025

I removed the stalled label since @OmarManzoor is pushing new commits to finalize this important PR.

@ogrisel
Member

ogrisel commented Oct 10, 2025

Hum, _atol_for_type is used by the silhouette_samples function itself. I am not sure what to do:

  • we could change the silhouette tests to adjust for the new _atol_for_type semantics;
  • or find a way to change the tols only for the array API tests without changing the other uses of _atol_for_type;
  • or rewrite silhouette_samples to not rely on _atol_for_type and instead do its own internal tol adjustment, reserving _atol_for_type for testing purposes.
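The third option could be sketched as follows, with hypothetical helper names (this is not the actual scikit-learn implementation): silhouette_samples would own its internal tolerance, fully decoupled from the testing helper:

```python
import numpy as np

def _silhouette_eps(dtype):
    # Hypothetical internal tolerance owned by silhouette_samples,
    # decoupled from the testing helper _atol_for_type: denominators
    # below this floor are indistinguishable from zero at this precision.
    return 100 * np.finfo(dtype).eps

def safe_ratio(numer, denom):
    # Zero out ratios whose denominator is below the precision floor,
    # mimicking how silhouette_samples filters near-zero denominators
    # to avoid dividing by numerical noise.
    mask = np.abs(denom) > _silhouette_eps(denom.dtype)
    out = np.zeros_like(numer)
    out[mask] = numer[mask] / denom[mask]
    return out

print(safe_ratio(np.array([1.0, 1.0]), np.array([2.0, 1e-20])))  # [0.5 0. ]
```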

@ogrisel
Member

ogrisel commented Oct 10, 2025

I think I would be in favor of the 3rd option.

@ogrisel
Member

ogrisel commented Oct 10, 2025

@jeremiedbb if you have opinions on tol settings ;)

@jeremiedbb
Member

I also think that option 3 is more appropriate. In silhouette_score it's used as an eps to filter out extremely small values, whereas in testing it's used as a tol (close semantics, but I see a small difference between the two).
By the way, I find that 100 * finfo.eps is quite low; I wouldn't mind increasing it a bit.
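For scale, the factor under discussion evaluates to roughly 2.2e-14 for float64 and 1.2e-5 for float32 (plain NumPy, not PR code):

```python
import numpy as np

# Magnitude of the 100 * finfo.eps tolerance for the two float dtypes:
# ~2.2e-14 for float64 and ~1.2e-5 for float32.
for dtype in (np.float64, np.float32):
    print(dtype.__name__, 100 * np.finfo(dtype).eps)
```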

@OmarManzoor
Contributor

OmarManzoor commented Oct 10, 2025

I also think that we should remove this dependency from silhouette_score and directly compute atol with the original factor within the function, while keeping _atol_for_type as in the latest commit, where we increase the factor by 10.

@jeromedockes
Contributor Author

> @jeromedockes just checking, are you still interested in working on this?

I'm sorry for the late reply, @lucyleeow ! It's not for lack of interest, but unfortunately I really don't have the time at the moment. I should have said so earlier to avoid stalling it. I'm glad to see you picked it up @OmarManzoor , thanks!!

```python
decision = self.decision_function(X)
xp, is_array_api, device_ = get_namespace_and_device(decision)
max_float_dtype = _max_precision_float_dtype(xp, device=device_)
scores = 2.0 * xp.astype(decision > 0, max_float_dtype) - 1.0
```
Contributor


The xp.astype(decision > 0, max_float_dtype) was previously hardcoded to xp.astype(decision > 0, xp.float32), but I decided to use max_float_dtype instead. Let me know if we should revert to xp.float32.
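For illustration, the boolean-to-{-1, +1} mapping in the snippet behaves like this, with NumPy standing in for the array namespace and float64 assumed as the maximum-precision float dtype:

```python
import numpy as np

# NumPy stands in for the array API namespace; float64 plays the role
# of the device's maximum-precision float dtype.
decision = np.array([-0.3, 0.0, 1.2])
max_float_dtype = np.float64
# (decision > 0) is boolean; 2 * bool - 1 maps False -> -1.0 and
# True -> +1.0, so a decision value of exactly 0 gets the negative class.
scores = 2.0 * (decision > 0).astype(max_float_dtype) - 1.0
print(scores)  # [-1. -1.  1.]
```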

Member

@ogrisel ogrisel left a comment


Thanks, @OmarManzoor, for pushing this to the finish line. A final pass of nitpicks but otherwise, LGTM.

@ogrisel ogrisel enabled auto-merge (squash) October 13, 2025 15:00
@ogrisel ogrisel merged commit 5f1491a into scikit-learn:main Oct 13, 2025
36 checks passed
@github-project-automation github-project-automation bot moved this from In Progress to Done in Array API Oct 13, 2025
Tunahanyrd pushed a commit to Tunahanyrd/scikit-learn that referenced this pull request Oct 28, 2025
…fierCV (scikit-learn#27961)

Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Omar Salman <omar.salman@arbisoft.com>


7 participants