[MRG+1] fix plot_partial_dependence not taking target into account when multiclass #14393
glemaitre merged 8 commits into scikit-learn:master from GuillemGSubies:bugfix_partial_dependence_plot_multiclass
cc @NicolasHug
If |
You can in fact add a test for it (in another PR). Roughly, the test should be something like:
from sklearn.datasets import fetch_openml
from sklearn.tree import DecisionTreeClassifier
from sklearn.inspection import plot_partial_dependence
iris = fetch_openml('iris', as_frame=True, version=1)
df, y = iris.data, iris.target.to_numpy()
clf = DecisionTreeClassifier().fit(df, y)
# the class labels are strings, so classes_ has object dtype
assert clf.classes_.dtype == object
# check that the pdp with str and int give the same results
# pick up the last class
# implement the assert as in this PR
plot_partial_dependence(clf, df, [0], target='Iris-virginica')
plot_partial_dependence(clf, df, [0], target=2)
Actually I asked because there is already a test about that
Oh perfect then, so no need for an additional test ;)
NicolasHug left a comment
Small comments but LGTM anyway.
Thanks for the fix, @GuillemGSubies!
# check that the pd plots are the same for 0 and "setosa"
assert all(axs[0].lines[0]._y == axs2[0].lines[0]._y)
# check that the pd plots are different for another target
clf = GradientBoostingClassifier(n_estimators=10, random_state=1)
I think you can remove a few lines, namely the clf definition and fitting, as well as the grid_resolution.
OK, I will. I did not change those because I did not know whether they followed some standard you use when testing.
Co-Authored-By: Nicolas Hug <contact@nicolas-hug.com>
…ithub.com/GuillemGSubies/scikit-learn into bugfix_partial_dependence_plot_multiclass
# check that the pd plots are the same for 0 and "setosa"
assert all(axs[0].lines[0]._y == axs2[0].lines[0]._y)
# check that the pd plots are different for another target
clf.fit(iris.data, iris.target)
you can remove this line too ;)
Looks like I shouldn't have removed it. Without that line, the estimator was fit on string targets, and in that case I cannot pass an int as target to plot_partial_dependence.
I don't know whether that is the expected behavior or not.
Oh OK, my bad, I didn't realize it was fit on something different before
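To make the observation above concrete, here is a minimal sketch of why an int target is rejected once the estimator has been fit on string labels. It assumes the target is resolved by matching its value against the fitted class labels rather than being treated as a position. `resolve_target_idx` is a hypothetical helper written for illustration, not scikit-learn's actual code.

```python
def resolve_target_idx(classes, target):
    """Hypothetical helper: position of `target` among the fitted class labels.

    Assumption: the lookup matches label values, so `target` must be one of
    the labels the estimator was fit on.
    """
    classes = list(classes)
    if target not in classes:
        raise ValueError("target not in classes, got {!r}".format(target))
    return classes.index(target)


int_classes = [0, 1, 2]                             # fit on integer targets
str_classes = ["setosa", "versicolor", "virginica"]  # fit on string targets

idx_from_int = resolve_target_idx(int_classes, 2)            # label 2 exists
idx_from_str = resolve_target_idx(str_classes, "virginica")  # label exists

try:
    resolve_target_idx(str_classes, 2)  # 2 is not one of the string labels
    int_on_str_failed = False
except ValueError:
    int_on_str_failed = True
```

Under this assumption, `target=0` and `target="setosa"` can only be compared in a test if the estimator is refit on the matching label type in between, which is what the `clf.fit(iris.data, iris.target)` line does.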
Thanks @GuillemGSubies
Reference Issues/PRs
Fixes #14301
What does this implement/fix? Explain your changes.
I just took out an else so target_idx does not get overwritten.
Any other comments?
I didn't know the best way to test it. Right now I check the y-axis values of the plots for two different targets and make sure that they are not the same (the bug made them equal all the time).
Also, I have a question: Here it should be int or str, shouldn't it?
scikit-learn/sklearn/inspection/partial_dependence.py
Line 404 in c0c5313
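The fix described above (removing an else so that target_idx is not overwritten) can be illustrated with a small sketch of the pattern. This is not the code from partial_dependence.py; the function names, the `is_multioutput_regressor` flag, and the control flow are assumptions chosen only to show how a trailing `else` can reset an index that was already computed.

```python
def target_idx_buggy(classes, target, is_multioutput_regressor=False):
    """Illustrative only: a later if/else clobbers the multiclass index."""
    target_idx = 0
    if len(classes) > 2:
        # multiclass: find the position of the requested class label
        target_idx = list(classes).index(target)
    if is_multioutput_regressor:
        target_idx = target
    else:
        # this branch also runs for multiclass classifiers and resets the index
        target_idx = 0
    return target_idx


def target_idx_fixed(classes, target, is_multioutput_regressor=False):
    """Illustrative only: the default is set once and never overwritten."""
    if len(classes) > 2:
        target_idx = list(classes).index(target)
    else:
        # regression and binary classification
        target_idx = 0
    if is_multioutput_regressor:
        target_idx = target
    return target_idx


labels = ["setosa", "versicolor", "virginica"]
buggy = [target_idx_buggy(labels, name) for name in labels]  # [0, 0, 0]
fixed = [target_idx_fixed(labels, name) for name in labels]  # [0, 1, 2]
```

In the buggy variant every target maps to index 0, which matches the symptom reported in the description: the plots were identical regardless of the requested target.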