[MRG+2] Fix LDA predict_proba() #11796
agramfort merged 26 commits into scikit-learn:master from agamemnonc:lda_predict_proba_fix
Conversation
jnothman
left a comment
Yes, this needs non-regression tests
|
@jnothman do you mean a numerical test that checks that the output probabilities of a toy dataset are as expected (e.g. something similar to |
|
I suppose so. Something that fails at master, works in this PR, and is
illustrative of what we think correct behaviour should be.
|
|
Could you please let me know if the non-regression test looks OK? |
|
Can I suggest that this PR be prioritised, given that it fixes a bug (#6848) that yields wrong prediction outcomes for a fairly popular classifier? |
|
Hello, I would just like to note that I am running into this bug in practice and would really appreciate a fix. |
|
Thanks for the pings, @agamemnonc, @taalexander. I'll look soon. |
jnothman
left a comment
This LGTM. I only wonder whether this should be encapsulated in a logistic utility function.
sklearn/discriminant_analysis.py
Outdated
# up to a multiplicative constant.
likelihood = np.exp(prob - prob.max(axis=1)[:, np.newaxis])
# compute posterior probabilities
return likelihood / likelihood.sum(axis=1)[:, np.newaxis]
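For context, the snippet above is the standard numerically stable softmax: subtracting the row-wise maximum before exponentiating leaves the resulting ratios unchanged but keeps `exp()` from overflowing. A minimal standalone sketch of the same technique (the function name `stable_softmax` is mine, not from the PR):

```python
import numpy as np

def stable_softmax(decision):
    """Row-wise softmax, shifted by the row max for numerical stability."""
    # Subtracting the max changes each exp() by a common factor per row,
    # which cancels in the normalisation, so the result is identical.
    shifted = decision - decision.max(axis=1)[:, np.newaxis]
    likelihood = np.exp(shifted)
    # Normalise each row so the probabilities sum to 1.
    return likelihood / likelihood.sum(axis=1)[:, np.newaxis]

# Huge scores would overflow a naive np.exp(); the shifted version is fine,
# and gives the same answer as the equivalent small scores in the second row.
scores = np.array([[1000.0, 1001.0], [0.0, 1.0]])
probs = stable_softmax(scores)
```

A naive `np.exp(scores)` would produce `inf` for the first row; the shifted version returns the same finite probabilities for both rows.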
Why not continue to do this inplace (/=)?
|
Thanks @jnothman. Yes, you are right. Hopefully the most recent commit is much cleaner; it is also consistent with the I have also fixed a typo in |
|
Nice! |
|
Given the already merged #12931, there is now a conflict. (BTW, I believe this could have been avoided if this PR had been merged in a timely manner: this PR was submitted on December 5th, whereas #12931 was submitted on January 6th.) Anyway, I suggest that the changes in the most recent commit are overwritten by the current PR, since the code in this PR inherits the method from the parent class ( |
|
Apologies about the poor management of related pull requests on our part. Please resolve conflicts with master so we can see the benefits of this PR more clearly. |
|
OK, no problem. I have now provided a fix, since the test introduced in this PR ( Moreover, the suggested fix reuses code by inheriting from the parent class rather than re-implementing the method when
|
|
Please add a |
Done, thanks for the instructions. |
|
It would be good to get this in 0.21. Thanks @agamemnonc if you can get to it. |
|
Apologies for the delay, I have been very busy with a submission recently, will try to deal with this by the end of this week. |
|
Thank you :)
|
|
OK folks, I think I have now implemented everything we have agreed on. If tests pass, @agramfort, @jnothman, @glemaitre could you please have a final look and merge if happy or let me know otherwise. |
    n_samples=90000, centers=blob_centers, covariances=blob_stds,
    random_state=42
)
lda = LinearDiscriminantAnalysis(solver='lsqr').fit(X, y)
Do we want to test this for the other solvers as well? How long does the test take, given the number of samples?
Good point.
Including the two other solvers adds only 0.09 s.
The test passes for solver=svd, but fails when solver=eigen.
This is probably related to #11727.
@amueller Shall we only include svd and lsqr for now in the tests and take a note in that other PR to update the tests to also include eigen when a fix is submitted?
I will try to also provide a fix for the eigen solver in this PR.
OK, I think I have now fixed that other issue (#11727), which was due to bad normalisation of the eigenvectors and was causing issues with probabilities for the eigen solver. I have updated the rst file accordingly.
Now all three solvers are tested in the non-regression test.
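The PR's actual non-regression test builds blobs with explicitly specified covariances via a custom helper; as a rough, hedged sketch of the same idea using plain `make_blobs` (the dataset sizes, tolerances, and structure here are mine, not the merged test's):

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

# Well-separated Gaussian blobs, so all solvers should agree closely.
X, y = make_blobs(n_samples=1000, centers=3, cluster_std=1.0, random_state=42)

probas = {}
for solver in ('svd', 'lsqr', 'eigen'):
    lda = LinearDiscriminantAnalysis(solver=solver).fit(X, y)
    p = lda.predict_proba(X)
    # Each row must be a valid probability distribution.
    assert np.all(p >= 0) and np.all(p <= 1)
    assert np.allclose(p.sum(axis=1), 1.0)
    # predict() must be consistent with the argmax of predict_proba().
    assert np.array_equal(lda.classes_[p.argmax(axis=1)], lda.predict(X))
    probas[solver] = p

# After the fix, all three solvers should give (near-)identical probabilities.
assert np.allclose(probas['svd'], probas['lsqr'], atol=1e-4)
assert np.allclose(probas['svd'], probas['eigen'], atol=1e-4)
```

Before the eigen-solver normalisation fix discussed above, the last assertion is the one that would fail.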
|
@glemaitre @jnothman you might need to re-approve, as I have modified the code in |
|
thx @agamemnonc |
|
A bit late but thanks a lot @agamemnonc |
My pleasure, and thank you all for all your help and feedback. |
* fix LDA predict_proba() to handle binary and multi-class case
* test_lda_predict_proba non-regression test
* pep8 fix
* lda predict_proba refactoring
* Typo fix
* flake8 fix
* predict_proba check_is_fitted check
* update what's new rst file
* rename prob to decision
* include additional tests for predict_proba
* use allcose vs. assert_array_almost_equal
* fix indent
* replace len with size
* explicit computation for binary case
* fix style whats_new rst
* predict_proba new regression test
* give credit for regression test
* fix bug for eigen solution
* include all three solvers in predict_proba regression test
* update whats_new rst file
* fix minor formatting issue
* use scipy.linalg instead of np.linalg
This reverts commit deea1e8.
Reference Issues/PRs
Fixes #6848
closes #11727
closes #5149
What does this implement/fix? Explain your changes.
Fixes the predict_proba() method of LinearDiscriminantAnalysis. An if statement is used to differentiate between the binary and multi-class case, due to the different output format of the decision_function method implemented in the LinearClassifierMixin class.
Any other comments?
Copying from #6848:
Do we perhaps want to include additional tests checking the output of predict_proba for LDA and QDA both for the binary and multi-class cases?
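To illustrate the binary vs. multi-class branching described above: for two classes, `decision_function` returns a 1-D array of scores for the positive class, so a logistic sigmoid applies, while for K > 2 classes it returns an (n_samples, K) array that goes through a softmax. A hedged sketch of that idea (the function name and exact structure are mine; this mirrors the concept, not the merged code verbatim):

```python
import numpy as np
from scipy.special import expit, softmax

def proba_from_decision(decision):
    """Turn LinearClassifierMixin-style decision values into probabilities."""
    if decision.ndim == 1:
        # Binary case: one score per sample (for the positive class);
        # the logistic sigmoid maps it to P(y = classes_[1] | x).
        prob_pos = expit(decision)
        return np.column_stack([1.0 - prob_pos, prob_pos])
    # Multi-class case: one score per class; normalise with a softmax.
    return softmax(decision, axis=1)

binary = proba_from_decision(np.array([-2.0, 0.0, 2.0]))
multi = proba_from_decision(np.array([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]]))
```

The point of the if statement is exactly this shape difference: without it, feeding the binary 1-D scores through the multi-class path produces wrong probabilities, which is the bug #6848 describes.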