Fix _safe_indexing with non integer arrays on array API inputs#32840
Fix _safe_indexing with non integer arrays on array API inputs#32840ogrisel merged 2 commits intoscikit-learn:mainfrom
_safe_indexing with non integer arrays on array API inputs#32840Conversation
|
I tried on google colab with scikit-learn 1.7.2 but the error did not occur using the original snippet in the issue. |
|
Since this bug did not exist in a previous release, I don't think we need to document this fix in the changelog. |
OmarManzoor
left a comment
There was a problem hiding this comment.
LGTM. Thank you @ogrisel
sklearn/utils/_indexing.py
Outdated
| else: | ||
| return array[key, ...] if axis == 0 else array[:, key] |
There was a problem hiding this comment.
Do we need this here considering that this is the default return statement anyways at the end?
There was a problem hiding this comment.
I ran the CI locally with the else block commented out, and it passed with CUDA. Even the code mentioned in #32837 compiled without any issues. So I think it's safe to remove it here.
sklearn/utils/_indexing.py
Outdated
| else: | ||
| return array[key, ...] if axis == 0 else array[:, key] |
There was a problem hiding this comment.
I ran the CI locally with the else block commented out, and it passed with CUDA. Even the code mentioned in #32837 compiled without any issues. So I think it's safe to remove it here.
|
I will open a follow-up PR to try to add a higher level estimator common test to discover other similar bugs in the future but that might take more time to get in and don't want to delay this fix and the release of 1.8. |
PR: scikit-learn#32846 Issue: scikit-learn#32840 Base commit: 8061a39 Changed lines: 210
This is a fix for #32837.
While investigating the issue above, I realized that we needed unittest for array API support for
_safe_indexing.I am not yet sure if this problem already existing in 1.7 or not. If it was I will add a changelog entry.