
Add covariance_type option to GMM class #49

@dominik-strutz

Description

Currently, the zuko.flows.mixture.GMM class only supports full covariance matrices. However, in many use cases (especially high-dimensional ones), a full covariance matrix is either unnecessary or infeasible to estimate. This issue proposes adding an option to choose between different covariance types, similar to the covariance_type parameter of sklearn.mixture.GaussianMixture ('full', 'tied', 'diag' and 'spherical').

Here is an example of how different covariance types approximate a mixture of 3 Gaussians with varying covariance matrices.

[Figure: fits of a mixture of 3 Gaussians with varying covariance matrices, one panel per covariance type]
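To make the high-dimensional argument concrete, the number of free covariance parameters for each of the four sklearn covariance types can be counted directly. The function name below is hypothetical; the counts follow from each type's structure, since a symmetric D-by-D matrix has D(D + 1)/2 free entries.

```python
def n_covariance_params(covariance_type: str, n_components: int, n_features: int) -> int:
    """Count free covariance parameters of a K-component GMM in D dimensions (hypothetical helper)."""
    K, D = n_components, n_features
    sym = D * (D + 1) // 2  # free entries of one symmetric D-by-D matrix
    counts = {
        "full": K * sym,    # one full matrix per component
        "tied": sym,        # one full matrix shared by all components
        "diag": K * D,      # one diagonal per component
        "spherical": K,     # one scalar variance per component
    }
    if covariance_type not in counts:
        raise ValueError(f"unknown covariance_type: {covariance_type!r}")
    return counts[covariance_type]


for ct in ("full", "tied", "diag", "spherical"):
    print(ct, n_covariance_params(ct, n_components=3, n_features=100))
```

With D = 100 features and K = 3 components, 'full' needs 15150 covariance parameters, 'tied' 5050, 'diag' 300 and 'spherical' 3.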

Implementation

The current structure of the zuko.flows.mixture.GMM class makes it easy to add the enhancement described above. I have implemented the changes in a fork of the repository and can open a pull request if this change is wanted. So far I have only tested the code in the unconditional case, but I do not see how adding context features could break it.
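A minimal sketch of what such an option could look like, assuming the covariance is parameterised through a lower-triangular Cholesky factor passed to torch.distributions.MultivariateNormal. CovarianceTypeGMM, scale_tril, n_raw_params and _tril_from_vector are hypothetical names; this is not Zuko's GMM implementation or API, and like the fork described above it only covers the unconditional case.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical, MixtureSameFamily, MultivariateNormal


def n_raw_params(covariance_type: str, K: int, D: int) -> int:
    """Number of unconstrained scale parameters per covariance type (hypothetical helper)."""
    tril = D * (D + 1) // 2
    counts = {"full": K * tril, "tied": tril, "diag": K * D, "spherical": K}
    if covariance_type not in counts:
        raise ValueError(f"unknown covariance_type: {covariance_type!r}")
    return counts[covariance_type]


def _tril_from_vector(v: torch.Tensor, D: int, eps: float) -> torch.Tensor:
    """Fill (..., D, D) lower-triangular factors from v of shape (..., D(D+1)/2), diagonal made positive."""
    rows, cols = torch.tril_indices(D, D, device=v.device)
    L = v.new_zeros(tuple(v.shape[:-1]) + (D, D))
    L[..., rows, cols] = v
    diag = F.softplus(torch.diagonal(L, dim1=-2, dim2=-1)) + eps
    return L.tril(-1) + torch.diag_embed(diag)


def scale_tril(raw: torch.Tensor, covariance_type: str, K: int, D: int, eps: float = 1e-5) -> torch.Tensor:
    """Map unconstrained parameters to K Cholesky factors of shape (K, D, D)."""
    if covariance_type == "full":  # one free factor per component
        return _tril_from_vector(raw.view(K, -1), D, eps)
    if covariance_type == "tied":  # a single factor shared by all components
        return _tril_from_vector(raw, D, eps).expand(K, D, D)
    if covariance_type == "diag":  # per-component, per-dimension scales
        return torch.diag_embed(F.softplus(raw.view(K, D)) + eps)
    if covariance_type == "spherical":  # one scale per component
        scale = F.softplus(raw).view(K, 1, 1) + eps
        return scale * torch.eye(D, dtype=raw.dtype, device=raw.device)
    raise ValueError(f"unknown covariance_type: {covariance_type!r}")


class CovarianceTypeGMM(nn.Module):
    """Hypothetical unconditional GMM with a covariance_type option (not Zuko's API)."""

    def __init__(self, features: int, components: int = 2, covariance_type: str = "full"):
        super().__init__()
        self.K, self.D, self.covariance_type = components, features, covariance_type
        self.logits = nn.Parameter(torch.zeros(components))
        self.loc = nn.Parameter(torch.randn(components, features))
        self.raw_scale = nn.Parameter(torch.zeros(n_raw_params(covariance_type, components, features)))

    def forward(self) -> MixtureSameFamily:
        L = scale_tril(self.raw_scale, self.covariance_type, self.K, self.D)
        return MixtureSameFamily(
            Categorical(logits=self.logits),
            MultivariateNormal(self.loc, scale_tril=L),
        )


gmm = CovarianceTypeGMM(features=4, components=3, covariance_type="diag")
x = gmm().sample((5,))
print(gmm().log_prob(x))
```

For 'diag' and 'spherical', this sketch still materialises D-by-D matrices; in genuinely high dimensions an Independent(Normal(loc, scale), 1) component distribution would keep memory linear in D.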

Further improvements

When generating the figure above, I (again) realised how easily mode collapse happens with GMMs. The zuko.flows.mixture.GMM class could therefore also benefit from an initialisation procedure, again similar to sklearn.mixture.GaussianMixture. I fully understand if that goes beyond the scope of what Zuko wants to achieve. However, Zuko is very convenient to use and integrates so well with PyTorch code that having such a procedure here would be nice. The drawback is that it might add another dependency (e.g., sklearn) if existing implementations of initialisation algorithms are used. I have a basic implementation of this (using sklearn) lying around and would be happy to polish it and add another commit if this is wanted.
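A dependency-free alternative to the sklearn-based initialisation could be a few iterations of k-means written directly in PyTorch; for reference, sklearn.mixture.GaussianMixture initialises with k-means by default (init_params="kmeans"). The sketch below is hypothetical (kmeans_init is not part of Zuko or sklearn). It uses plain Lloyd iterations from randomly chosen data points rather than k-means++, so it can itself land in a poor local optimum; restarts or k-means++ seeding would make it more robust.

```python
import torch


def kmeans_init(x: torch.Tensor, K: int, iters: int = 25, seed: int = 0, reg: float = 1e-6):
    """Lloyd's k-means on x of shape (N, D).

    Returns initial mixture weights (K,), means (K, D) and per-dimension
    variances (K, D) computed from the final hard cluster assignments.
    """
    N, D = x.shape
    g = torch.Generator().manual_seed(seed)
    # Start from K distinct data points chosen at random.
    centers = x[torch.randperm(N, generator=g)[:K]].clone()
    for _ in range(iters):
        labels = torch.cdist(x, centers).argmin(dim=1)  # nearest center per sample
        for k in range(K):
            members = x[labels == k]
            if len(members) > 0:  # keep the old center if a cluster is empty
                centers[k] = members.mean(dim=0)
    labels = torch.cdist(x, centers).argmin(dim=1)
    weights = torch.bincount(labels, minlength=K).to(x.dtype) / N
    variances = torch.empty(K, D, dtype=x.dtype)
    for k in range(K):
        members = x[labels == k]
        if len(members) == 0:
            members = x  # fall back to the global spread for an empty cluster
        # Biased per-dimension variance plus a small regulariser (cf. sklearn's reg_covar).
        variances[k] = ((members - members.mean(dim=0)) ** 2).mean(dim=0) + reg
    return weights, centers, variances
```

The outputs could seed mixture logits (weights.log(), clamped if a cluster ends up empty), component means, and diagonal scales (variances.sqrt()) before gradient-based training starts.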

Metadata

Labels: enhancement (New feature or request)
