ENH Gaussian mixture bypassing unnecessary initialization computing #26021
OmarManzoor merged 21 commits into scikit-learn:main from
Conversation
thomasjpfan
left a comment
Thank you for the PR @jiawei-zhang-a !
Please add an entry to the change log at doc/whats_new/v1.3.rst with tag |Efficiency|. Like the other entries there, please reference this pull request with :pr: and credit yourself (and other contributors if applicable) with :user:.
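For reference, a sketch of what such an entry could look like (the wording and exact placement inside doc/whats_new/v1.3.rst are illustrative, not the entry that was actually merged):

```rst
- |Efficiency| :class:`mixture.GaussianMixture` now skips the unneeded
  initialization computation when ``weights_init``, ``means_init`` and
  ``precisions_init`` are all user-provided.
  :pr:`26021` by :user:`jiawei-zhang-a`.
```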
Thank you, Mr. Fan @thomasjpfan. I have removed the new state and added a new changelog entry. :)
thomasjpfan
left a comment
Thank you for the PR @jiawei-zhang-a! We still need a test to make sure that the parameters are not estimated during initialization. I think a simple way is to use monkeypatching:
def test_gaussian_mixture_all_init_does_not_estimate_gaussian_parameters(monkeypatch):
    """When all init are provided, the Gaussian parameters are not estimated.

    Non-regression test for gh-26015.
    """
    mock = Mock(side_effect=_estimate_gaussian_parameters)
    monkeypatch.setattr(
        sklearn.mixture._gaussian_mixture, "_estimate_gaussian_parameters", mock
    )

    rng = np.random.RandomState(0)
    rand_data = RandomData(rng)
    gm = GaussianMixture(
        n_components=rand_data.n_components,
        weights_init=rand_data.weights,
        means_init=rand_data.means,
        precisions_init=rand_data.precisions["full"],
        random_state=rng,
    )
    gm.fit(rand_data.X["full"])
    # The initial gaussian parameters are not estimated. They are estimated for
    # every m_step.
    assert mock.call_count == gm.n_iter_

Mock is from Python's unittest.mock module. On main, the test would fail because mock.call_count would be gm.n_iter_ + 1 due to the extra call during initialization.
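For context (not part of the PR): `Mock(side_effect=...)` from `unittest.mock` forwards calls to the wrapped callable while recording how often it was invoked, which is what the test above relies on. A minimal, self-contained illustration:

```python
from unittest.mock import Mock

def add(a, b):
    return a + b

wrapped = Mock(side_effect=add)  # calls are forwarded to add(), and recorded
assert wrapped(1, 2) == 3        # the real return value comes back
assert wrapped.call_count == 1   # the mock counted the call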
    :user:`Jérémie du Boisberranger <jeremiedbb>`,
    :user:`Guillaume Lemaitre <glemaitre>`.
For git blame purposes, can you revert this?
Dear Mr. Fan, thank you so much for all the advice! I will follow your suggestions.
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
thomasjpfan
left a comment
A few minor comments, otherwise LGTM
    from sklearn.utils._testing import assert_array_equal
    from sklearn.utils._testing import ignore_warnings

    from unittest.mock import Mock
Nit: Can you move this import to line 9, below `import warnings`? This way the "first-party Python modules" are at the top of the file (see the sketch after this exchange).
Sure! I will do that immediately.
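For illustration, the reordered import block could look roughly like this; the surrounding imports are assumptions about the test module, not its exact contents:

```python
# Sketch only: neighbouring imports are assumed, not the exact file contents.
import warnings
from unittest.mock import Mock  # standard-library ("first party") imports first

import numpy as np

from sklearn.mixture import GaussianMixture
from sklearn.utils._testing import assert_array_equal
from sklearn.utils._testing import ignore_warnings
```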
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
…cikit-learn into GaussianMixture
Good job! Waiting for this to merge.
OmarManzoor
left a comment
Thanks for the PR @jiawei-zhang-a. Could you kindly resolve the conflicts by merging main and have a look at these few comments?
Co-authored-by: Omar Salman <omar.salman@arbisoft.com>
Co-authored-by: Omar Salman <omar.salman@arbisoft.com>
@OmarManzoor Thank you so much for your review! I have committed your suggestions and fixed the conflicts with the main branch.
OmarManzoor
left a comment
Thanks for the updates. I added a few more comments, otherwise this looks good now!
Thank you so much!
Co-authored-by: Omar Salman <omar.salman@arbisoft.com>
Co-authored-by: Omar Salman <omar.salman@arbisoft.com>
Co-authored-by: Omar Salman <omar.salman@arbisoft.com>
Co-authored-by: Omar Salman <omar.salman@arbisoft.com>
Sure! I will check that now.
Reference Issues/PRs
Fixes #26015
What does this implement/fix? Explain your changes.
I add a private variable `_init_weights_means_precisions_skipped` in `_base.py`. If a user passes initial values for the weights, means, and precisions, then there is no need to run the initialization (via K-means or random) to estimate the Gaussian parameters.
These two steps are now skipped if `_init_weights_means_precisions_skipped` is `True`.
Any other comments?
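To make the gating concrete, here is a minimal standalone sketch of the check being described; the helper name and signature are hypothetical, not scikit-learn's actual private API:

```python
# Hypothetical helper, for illustration only (not scikit-learn internals).
def _needs_init_estimation(weights_init, means_init, precisions_init):
    """Return True when the k-means/random responsibility step is still required."""
    return weights_init is None or means_init is None or precisions_init is None

# With all three init values supplied by the user, the costly step is bypassed.
assert _needs_init_estimation(None, None, None) is True
assert _needs_init_estimation([0.6, 0.4], [[0.0], [3.0]], [[1.0], [1.0]]) is False
```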