Add option to change covariance matrix type for GMM class#50
Add option to change covariance matrix type for GMM class#50francois-rozet merged 15 commits intoprobabilists:masterfrom
Conversation
|
Hello @dominik-strutz, I quickly went over the code, and it looks nice! I think we should add some tests however, maybe in a new Are you also planning to improve the initialization as well? |
|
Hi @francois-rozet, Yes, I am happy to write some tests. I'm also happy to try to improve the initialization. Following What is your opinion on how to structure the initialization? I think it would be beneficial to keep the I will give the initialisation a try and let you know how it goes. P.S: I have no idea why the pre-commit hook fails. I used |
I think a good way to handle the conditional case would be to make the weight
I agree that a separate method could be appropriate, similar to the
I pulled your branch and Maybe you were not at the root? My version of ruff is |
|
@dominik-strutz Do you still plan on contributing this PR? |
|
Yes, I still like to contribute but haven't found much free time to do it recently. I have implemented most of the initialization methods for the unconditional case, but it still needs to be polished up and tested. The extension for the conditional case shouldn't take too long afterwards. If you or someone else wants to continue this sooner, I'm happy to push an intermediary commit of everything I have so far. |
|
No problem, take your time! I am currently updating a few things and wanted to know if I should wait for this PR for the next minor release. |
0a176b5 to
dedc708
Compare
|
It seems like I closed this automatically when I rebased my fork. Is there a way to reopen it with the same master branch of my fork or do I need to open a new request? In short, I have added an initialisation method that works with and without context variables and added some tests. |
|
The pull request should be ready to merge now. It adds the following new functionalities:
Let me know if this is up to your standard or if there are any changes necessary :) PS.: I uploaded two jupyter notebooks which I used to test the functionality to a separate branch in my fork of this repository: https://github.com/dominik-strutz/zuko/tree/test_gmm. |
|
Hi @dominik-strutz, sorry for the delay. I finally had the time to go over the PR. Overall, it is very good! The initialization in the conditional case was very smart 🧠 I modified a few things:
I tested the code, and it works really well. I will merge the PR soon. Thanks again! |
This PR adds changes to the
zuko.flows.mixture.GMMclass, which allow the user to change the type of the covariance matrix used for each of the Gaussian components of the mixture.The options added are
covariance_type, which allows to change the type of the covariance matricestieda switch which allows to control if covariance matrices are tied between componentscov_rankthe rank of the low-rank covariance matrix whencovariance_typeis 'lowrank'Since the construction of the shapes got quite long I moved this part in its own function.
Below is an illustration of the effect these different choices have for a mixture of 3 two-dimensional Gaussians.