
Add multivariate normal distribution #651

Merged
fritzo merged 29 commits into pyro-ppl:dev from dwd31415:dev
Dec 29, 2017

Conversation

@dwd31415
Contributor

This pull request adds a (basic) implementation of a multivariate normal distribution.

Since torch.potrf is differentiable as of v0.3.0, sampling can now be done in a manner that is differentiable via autograd. This is at least true for the samples. As the derivative of torch.potri is not implemented yet, the results of log_pdf cannot be differentiated for now when they are computed using the Cholesky decomposition. This pull request therefore also includes an alternative implementation of the log pdf that uses torch.inverse and is differentiable, although obviously much slower.
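For illustration, the two paths look roughly like this (a standalone sketch using the 0.3-era API, not the PR's actual code; sample_mvn and log_pdf_via_inverse are hypothetical names):

import math
import torch
from torch.autograd import Variable

def sample_mvn(loc, sigma):
    # Reparameterized sampling: gradients flow into loc and sigma via potrf.
    u = torch.potrf(sigma)                       # upper-triangular, Sigma = U^T U
    eps = Variable(torch.randn(1, loc.size(0)))  # standard normal noise
    return loc + torch.mm(eps, u).squeeze(0)

def log_pdf_via_inverse(x, loc, sigma):
    # Differentiable but slow: uses torch.inverse instead of potri.
    delta = (x - loc).unsqueeze(1)
    maha = torch.mm(delta.t(), torch.mm(torch.inverse(sigma), delta)).squeeze()
    log_det = 2 * torch.log(torch.potrf(sigma).diag()).sum()
    return -0.5 * (maha + log_det + loc.size(0) * math.log(2 * math.pi))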

The pull request already includes some tests; however, a workaround is used to pass the parameters to the scipy implementation, because they were not passed correctly by get_scipy_batch_logpdf.

@CLAassistant

CLAassistant commented Dec 23, 2017

CLA assistant check
All committers have signed the CLA.

@fritzo
Member

fritzo commented Dec 26, 2017

cc @karalets @tbrx

Member

@fritzo fritzo left a comment


This is great! We've been wanting a MultivariateNormal for a long time.

FYI, over the next month we're moving all Pyro distribution implementations upstream to torch.distributions (progress at https://github.com/probtorch/pytorch/projects/1). @tbrx has been working on a torch.distributions.MultivariateNormal (probtorch/pytorch#1). After your PR is merged we should coordinate to get this implementation moved upstream.

Comment thread: tests/distributions/conftest.py (Outdated)
Fixture(pyro_dist=(dist.multivariate_normal, MultivariateNormal),
scipy_dist=sp.multivariate_normal,
examples=[
{
Member


Could you please also add a Cholesky example?


reparameterized = True

def __init__(self, mu, sigma, batch_size=None, is_cholesky=False, use_inverse_for_batch_log=False, *args, **kwargs):
Member

@fritzo fritzo Dec 26, 2017


Two comments on the interface:

  1. It would be clearer to provide two alternative arguments sigma=None and chol_sigma=None and derive one from the other, rather than a single sigma plus a flag is_cholesky=False. This is also more consistent with other distribution classes like Bernoulli and Categorical, which can take either probs=None or logits=None and derive one from the other internally. I suppose the more common one should come first, but I don't know which is more common.
  2. Going forward we're trying to maximize compatibility with TensorFlow Distributions, so it would be preferable to rename mu -> loc, sigma -> covariance_matrix, and sigma_cholesky -> scale_tril. For reference, see https://www.tensorflow.org/versions/r1.0/api_docs/python/tf/contrib/distributions/MultivariateNormalCholesky and https://www.tensorflow.org/versions/r1.0/api_docs/python/tf/contrib/distributions/MultivariateNormalFull

transformed_sample = self.mu + torch.mm(uncorrelated_standard_sample, self.sigma_cholesky)
return transformed_sample if self.reparameterized else transformed_sample.detach()

def batch_log_pdf(self, x, normalized=True):
Member


This normalized flag is nice to have. Right now we don't have the plumbing to support it as an argument to batch_log_pdf(), but we can support it as an argument to the __init__ method. Could you move it there and make an instance attribute .normalized?

batch_log_pdf_shape = self.batch_shape(x) + (1,)
x = x.view(batch_size, *self.mu.size())
normalization_factor = torch.log(
self.sigma_cholesky.diag()).sum() + (self.mu.size()[0] / 2) * np.log(2 * np.pi) if normalized else 0
Member

@fritzo fritzo Dec 26, 2017


I guess the gradients will be incorrect if sigma_cholesky.requires_grad=True and normalized=False. Do you think it would be reasonable to either (1) raise a NotImplementedError in this case (to ensure users set sigma_cholesky.requires_grad=False) or (2) compute a numerically-stable quantity whose value is zero but whose gradient agrees with the gradient of the normalization factor (this is a research question 😉 )?

Contributor Author


I see the problem there; of course the gradient of the normalization coefficient is not taken into account at the moment. I am, however, not completely sure how much of a problem this is, as I would presume that the normalization might not be of real importance for many applications. I must admit that I have not thought much about this problem until now, but I will look into it. Approach (2) seems better to me, though, as (1) would also disable differentiation for those who don't care about normalization.

Member


If (2) is difficult, I think it's fine to merge (1) for now and implement (2) later once we know how. I'd like to merge your PR soon since @fehiepsi is already building a Gaussian Process tutorial #650 and it would be nice if he could base that on your work.

Contributor Author


Regarding (1) I have added a warning about the incorrect gradients to batch_log_pdf. This seems to be the best option to me, as not being able to differentiate if normalized is set to false would seriously limit the usefulness of this feature in my opinion.

Contributor Author


Approach (2) might also not be as difficult as it seems. The gradient of the log-determinant of a symmetric X apparently is 2*X.inverse() - torch.diag(torch.diag(X.inverse())), a rather nice expression (see (141) in http://www.math.uwaterloo.ca/~hwolkowi//matrixcookbook.pdf), especially since the inverse is already being computed. I am, however, not sure how to inject that gradient into autograd; do you have any suggestions as to how that might be achieved efficiently?

Member

@fritzo fritzo Dec 27, 2017


I think the easiest way is to create an autograd function:

import torch
from torch.autograd import Function

class _NonnormalizedDeterminant(Function):
    @staticmethod
    def forward(ctx, matrix):
        ctx.save_for_backward(matrix)
        return matrix.new([1.0])  # A bogus value; log(1.0) == 0.

    @staticmethod
    def backward(ctx, grad_output):
        matrix, = ctx.saved_variables
        inv = matrix.inverse()
        # Gradient of log det(X) for symmetric X.
        grad = 2 * inv - torch.diag(torch.diag(inv))
        return grad_output * grad

You can then use this as

normalization_factor = torch.log(_NonnormalizedDeterminant.apply(self.sigma)) + ...

Member


Also feel free to open an issue on this so we don't lose this extended discussion.

return Variable(torch.potri(var.data)) if torch.__version__ < '0.3.0' else torch.potri(var)


class MultivariateNormal(Distribution):
Member


nit: You can decorate with @copy_docs_from(Distribution) and then omit trivial method docstrings. See other distributions for usage.
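For example (a sketch; the exact import path is an assumption):

from pyro.distributions.distribution import Distribution
from pyro.distributions.util import copy_docs_from  # import path assumed

@copy_docs_from(Distribution)
class MultivariateNormal(Distribution):
    def sample(self):  # docstring inherited from Distribution.sample
        ...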

@fritzo
Member

fritzo commented Dec 26, 2017

cc @fehiepsi

@dwd31415
Contributor Author

Thank you for the helpful comments. I made the changes you suggested to the interface and have added a test for directly passing the Cholesky decomposition of the covariance matrix.

Member

@fritzo fritzo left a comment


Thanks for updating so quickly!


reparameterized = True

def __init__(self, loc, covariance_matrix, scale_tril=None, batch_size=None, use_inverse_for_batch_log=False,
Member


I think the easiest interface to use would be to require one of the args to be specified (as with Bernoulli):

def __init__(self, loc, covariance_matrix=None, scale_tril=None, ...):
    if covariance_matrix is None and scale_tril is None:
        raise ValueError('At least one of covariance_matrix or scale_tril must be specified')
    if scale_tril is None:
        covariance_matrix = ...
    elif covariance_matrix is None:
        scale_tril = ...
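Filling in those branches might look like this (a sketch; potrf_compat is the compat helper discussed below, and Sigma = U^T U matches the current code's convention):

    if scale_tril is None:
        scale_tril = potrf_compat(covariance_matrix)
    elif covariance_matrix is None:
        covariance_matrix = torch.mm(scale_tril.transpose(0, 1), scale_tril)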

Contributor Author


Okay, I am going to update that. That will also make the explicit Nones in the tests obsolete.

Comment thread: tests/distributions/conftest.py (Outdated)
Fixture(pyro_dist=(dist.multivariate_normal, MultivariateNormal),
scipy_dist=sp.multivariate_normal,
examples=[
{'loc': [2.0, 1.0], 'covariance_matrix': [[1.0, 0.5], [0.5, 1.0]], 'scale_tril':None,
Member


Could you omit the None arguments below:

{'loc': [2.0, 1.0], 'covariance_matrix': [[1.0, 0.5], [0.5, 1.0]],
 'test_data': [[2.0, 1.0], [9.0, 3.4]]},
{'loc': [2.0, 1.0], 'scale_tril': [[1.0, 0.5], [0, 3900231685776981/4503599627370496]],
 'test_data': [[2.0, 1.0], [9.0, 3.4]]},

loc.size()))
self.sigma = covariance_matrix
self.sigma_cholesky = Variable(
torch.potrf(covariance_matrix.data)) if torch.__version__ < '0.3.0' else torch.potrf(covariance_matrix)
Member


I like your potri_compat() function above. Could you factor this line out as a similar potrf_compat()?
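Something like this, mirroring potri_compat above (sketch):

def potrf_compat(var):
    # torch.potrf is only differentiable from v0.3.0 on.
    return Variable(torch.potrf(var.data)) if torch.__version__ < '0.3.0' else torch.potrf(var)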

Contributor Author


No problem, I am going to include that in the next commit.

x = x.view(batch_size, *self.mu.size())
normalization_factor = torch.log(
self.sigma_cholesky.diag()).sum() + (self.mu.size()[0] / 2) * np.log(2 * np.pi) if self.normalized else 0
sigma_inverse = torch.inverse(self.sigma) if self.use_inverse_for_batch_log else \
Member

@fritzo fritzo Dec 26, 2017


As I understand it, it's generally a good idea to avoid computing an inverse when solving an equation: if you have a Cholesky factor you should use a Cholesky solver, and if you have a general matrix you should use a general matrix solver. From what you know of PyTorch, is this not yet possible?

Member


It looks like you could do this with torch.gesv(): http://pytorch.org/docs/master/torch.html#torch.gesv

Member


I was thinking that it is easier to use torch.inverse() because torch.gesv() is not a batched solver.

Contributor Author


Yeah, that really is a problem. It looks like using torch.gesv() only brings performance gains if the batch size is considerably smaller than the dimension of the covariance matrix.

@fritzo
Member

fritzo commented Dec 26, 2017

What do you think of adding a matrix inverse helper for use in batch_log_pdf(), something like

def linear_solve_compat(matrix, matrix_chol, x):
    """Solves the equation ``torch.mm(matrix, y) = x`` for y."""
    assert matrix.requires_grad == matrix_chol.requires_grad
    if matrix.requires_grad or x.requires_grad:
        # If derivatives are required, use the more expensive gesv.
        # torch.gesv returns (solution, LU); only the solution is needed.
        return torch.gesv(x, matrix)[0]
    else:
        # Use the cheaper Cholesky solver.
        return torch.potrs(x, matrix_chol)

That would clean up the logic in batch_log_pdf() and make it easier to update once a derivative is available.

(EDIT renamed matrix_solve_compat to linear_solve_compat)

@dwd31415
Contributor Author

I think the matrix_inverse_compat function is a very good idea. I was not yet aware that torch.gesv() supports differentiation too. I would, however, change the name to something like linear_solve_compat, since it does not really compute the inverse.

@dwd31415
Contributor Author

dwd31415 commented Dec 27, 2017

So I have added a matrix_inverse_compat function that uses inverse and potri for now. In the future it might be useful to switch between this and a linear solver based on the batch_size and the size of the covariance matrix. I am not sure the performance gains are really worth the effort, but I will look into it.

fritzo previously approved these changes Dec 27, 2017
Member

@fritzo fritzo left a comment


Looks great. This looks ready to merge after a couple cosmetic changes.

A distribution over vectors in which all the elements have a joint
Gaussian density.

:param torch.autograd.Variable loc: Mean. Must be a vector (Variable containing a 1d Tensor).
Member


nit: Could you ensure that docstrings are at most 80 characters wide (here and in .sample() below)? This is important for users of Jupyter Notebooks so that help(MultivariateNormal) prints correctly.

x = x.view(batch_size, *self.mu.size())
normalization_factor = torch.log(
self.sigma_cholesky.diag()).sum() + (self.mu.size()[0] / 2) * np.log(2 * np.pi) if self.normalized else 0
sigma_inverse = matrix_inverse_compat(self.sigma, self.sigma_cholesky)
Member


Your review comment was helpful in understanding how we'll update this code in the future. Could you insert that comment here as a # TODO It may be useful to switch between matrix_inverse_compat() and linear_solve_compat() ...?
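The inserted comment might read (wording is only a suggestion):

# TODO It may be useful to switch between matrix_inverse_compat() and
# linear_solve_compat() based on the batch size and the size of the
# covariance matrix.
sigma_inverse = matrix_inverse_compat(self.sigma, self.sigma_cholesky)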


def batch_log_pdf(self, x):
if not self.normalized and self.sigma_cholesky.requires_grad:
warnings.warn("Gradients will not take normalization into account if normalized=False.")
Member


I'm inclined to raise an error rather than warn because I do not understand what is being computed in this case. Do you have a clear understanding of what is being computed and whether there are any statistical applications for computing this new quantity?

Contributor Author


Of course the gradients would only be a rather bad estimate of the true gradient, but it looked to me like it was good enough for some things, like fitting the parameters of a kernel for a Gaussian process. But I did not investigate that very thoroughly, so I will change the code to raise an error for the moment. It looks like (2) can be implemented very soon anyway.

Member


Thanks, raising an error seems like the right thing to do, and you've convinced me that a correct solution should be fairly cheap.
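Concretely, the guard could become something like this (hypothetical wording):

def batch_log_pdf(self, x):
    if not self.normalized and self.sigma_cholesky.requires_grad:
        raise NotImplementedError("Gradients are incorrect when normalized=False; "
                                  "set sigma_cholesky.requires_grad=False or use normalized=True.")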

@dwd31415
Contributor Author

I made the cosmetic changes, so if this gets merged I am going to open an issue, as you suggested, to further discuss the remaining details. I will try to implement the backward pass for the normalization constant over the next couple of days.

@fritzo
Member

fritzo commented Dec 29, 2017

Great, I'll merge as soon as tests pass and we can start implementing that cheap nonnormalized gradient.

@fritzo
Member

fritzo commented Jan 12, 2018

Hi @dwd31415 could you help me understand a bit of the MultivariateNormal code? I'm starting to refactor MultivariateNormal and I noticed it defines the opposite of what I would expect for the covariance:

self.sigma = torch.mm(scale_tril.transpose(0, 1), scale_tril)

(I would expect Sigma = L L'). Is this because of a PyTorch convention, or is this a bug? Thanks for any help!

@fritzo
Member

fritzo commented Jan 12, 2018

Oh I see, we named our variable scale_tril to indicate a lower-triangular matrix, but it's actually an upper-triangular matrix. So the code is correct but the variable is misnamed 😄

@dwd31415
Contributor Author

Yeah, the name really is misleading. I must confess that I did not think about that when I renamed things to fit the TensorFlow API. I think the best way forward would be to change the name, as torch.potrf returns an upper-triangular matrix by default, so that would fit better into PyTorch as a whole.
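A quick check with the 0.3-era API illustrates the convention (sketch):

import torch

sigma = torch.Tensor([[1.0, 0.5], [0.5, 1.0]])
u = torch.potrf(sigma)                  # upper-triangular by default
assert torch.equal(u, torch.triu(u))
assert (torch.mm(u.t(), u) - sigma).abs().max() < 1e-6  # Sigma == U^T U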

@fritzo
Member

fritzo commented Jan 12, 2018

Hmm, it would be nicer to maintain compatibility with TensorFlow and pass upper=False to torch.potrf. I'll lump that change into some upcoming changes we're doing to support more batching (since batched torch.gesv is on the way).

@dwd31415
Contributor Author

Yeah, these changes should be fairly easy to implement. Support for batched parameters would also be really nice to have. But although batched torch.gesv is great, I think we would also need batched torch.potrf. Or do you know of any way to get around that?

@tbrx

tbrx commented Jan 13, 2018

I was wondering about batched torch.potrf and torch.potrs as well.

I think that batched torch.gesv will be enough for computing .log_prob(); if it is like the current torch.gesv, it should return not just the solutions but also an LU decomposition, which we can use to quickly compute the log determinant.
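For instance (a sketch; torch.gesv returns (solution, LU), the diagonal of the packed LU factor is the diagonal of U, and the permutation's sign is dropped, which is fine for a positive-definite sigma; b and sigma are hypothetical inputs):

solution, lu = torch.gesv(b, sigma)
log_det = torch.log(torch.diag(lu).abs()).sum()  # log|det(sigma)|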

But I'm actually not sure what we should do about sampling, without batched torch.potrf. Something like torch.symeig or similar would also work fine, but I don't think any of these are batched either…

@fritzo
Member

fritzo commented Jan 13, 2018

How about we require the scale_tril argument if batching is required, and raise NotImplementedError in case a user provides a batched covariance_matrix?
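Something like this (sketch):

if covariance_matrix is not None and covariance_matrix.dim() > 2:
    raise NotImplementedError("Batched covariance_matrix is not supported; pass scale_tril instead.")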

@dwd31415
Contributor Author

I guess that would be a possibility, but the functionality would be seriously limited.

@dwd31415
Contributor Author

At least on the CPU it would be possible to fall back to numpy.linalg.cholesky or numpy.linalg.eig, since they both support batched inputs. Of course, that would make differentiation via autograd impossible and would completely lack support for CUDA tensors, so it's not a really good solution, but at least it could be a temporary workaround. I also don't really see a good alternative here: a distribution that does not support sampling is not really useful either, and this would at least allow us to implement most things right away and then switch to a pure PyTorch implementation as soon as the necessary functions become available.
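A rough sketch of such a fallback (hypothetical helper; plain tensors only, so no autograd graph and no CUDA support):

import numpy as np
import torch

def batched_cholesky_fallback(sigma):
    # numpy.linalg.cholesky broadcasts over leading batch dimensions and
    # returns lower-triangular factors L with Sigma = L L^T.
    return torch.from_numpy(np.linalg.cholesky(sigma.numpy()))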

