Add multivariate normal distribution #651
Conversation
This is great! We've been wanting a MultivariateNormal for a long time.
FYI, over the next month we're moving all Pyro distribution implementations upstream to torch.distributions (progress at https://github.com/probtorch/pytorch/projects/1). @tbrx has been working on a `torch.distributions.MultivariateNormal` in probtorch/pytorch#1. After your PR is merged we should coordinate to get this implementation moved upstream.
```python
Fixture(pyro_dist=(dist.multivariate_normal, MultivariateNormal),
        scipy_dist=sp.multivariate_normal,
        examples=[
            {
```
Could you please also add a Cholesky example?
```python
reparameterized = True

def __init__(self, mu, sigma, batch_size=None, is_cholesky=False,
             use_inverse_for_batch_log=False, *args, **kwargs):
```
Two comments on the interface:

- It would be clearer to provide two alternative arguments `sigma=None` and `chol_sigma=None` and derive one from the other, rather than provide a single `sigma` and a flag `is_cholesky=False`. This is also more consistent with other distribution classes like `Bernoulli` and `Categorical`, which can take either `probs=None` or `logits=None` and derive one from the other internally. I suppose the more common should be first, but I don't know which is more common.
- Going forward we're trying to maximize compatibility with TensorFlow Distributions, so it would be preferable to rename `mu` -> `loc`, `sigma` -> `covariance_matrix`, and `sigma_cholesky` -> `scale_tril`. For reference, see https://www.tensorflow.org/versions/r1.0/api_docs/python/tf/contrib/distributions/MultivariateNormalCholesky and https://www.tensorflow.org/versions/r1.0/api_docs/python/tf/contrib/distributions/MultivariateNormalFull
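The either/or constructor pattern suggested above can be sketched as follows. This is a minimal numpy illustration with a hypothetical class name, not Pyro's actual implementation:

```python
import numpy as np

class MultivariateNormalSketch:
    """Illustrative sketch of the suggested interface: accept either
    covariance_matrix or scale_tril and derive the missing one."""

    def __init__(self, loc, covariance_matrix=None, scale_tril=None):
        if covariance_matrix is None and scale_tril is None:
            raise ValueError("specify covariance_matrix or scale_tril")
        if scale_tril is None:
            # np.linalg.cholesky returns the lower-triangular factor L
            # with L @ L.T == covariance_matrix.
            scale_tril = np.linalg.cholesky(covariance_matrix)
        elif covariance_matrix is None:
            covariance_matrix = scale_tril @ scale_tril.T
        self.loc = np.asarray(loc)
        self.covariance_matrix = np.asarray(covariance_matrix)
        self.scale_tril = np.asarray(scale_tril)
```

Either argument alone then fully determines the distribution, mirroring how `Bernoulli` handles `probs`/`logits`.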
```python
transformed_sample = self.mu + torch.mm(uncorrelated_standard_sample, self.sigma_cholesky)
return transformed_sample if self.reparameterized else transformed_sample.detach()

def batch_log_pdf(self, x, normalized=True):
```
This `normalized` flag is nice to have. Right now we don't have the plumbing to support it as an argument to `batch_log_pdf()`, but we can support it as an argument to the `__init__` method. Could you move it there and make it an instance attribute `.normalized`?
```python
batch_log_pdf_shape = self.batch_shape(x) + (1,)
x = x.view(batch_size, *self.mu.size())
normalization_factor = torch.log(
    self.sigma_cholesky.diag()).sum() + (self.mu.size()[0] / 2) * np.log(2 * np.pi) if normalized else 0
```
I guess the gradients will be incorrect if `sigma_cholesky.requires_grad=True` and `normalized=False`. Do you think it would be reasonable to either (1) raise a `NotImplementedError` in this case (to ensure users set `sigma_cholesky.requires_grad=False`), or (2) compute a numerically stable quantity whose value is zero but whose gradient agrees with the gradient of the normalization factor (this is a research question 😉)?
I see the problem there, of course the gradient of the normalization coefficient is not taken into account at the moment. I am however not completely sure how much of a problem this is, as I would presume that the normalization might not be of real importance for many applications. I must admit that I have not thought too much about this problem until now, but I will look into it. Approach (2) seems better to me though, as (1) would disable differentiation for those who don't care about normalization too.
Regarding (1), I have added a warning about the incorrect gradients to `batch_log_pdf`. This seems to be the best option to me, as not being able to differentiate when `normalized` is set to false would seriously limit the usefulness of this feature in my opinion.
Approach (2) might also not be as difficult as it seems. The gradient of the log-determinant of a symmetric X apparently is `2*X.inverse() - torch.diag(torch.diag(X.inverse()))`, a rather nice expression (see (141) in http://www.math.uwaterloo.ca/~hwolkowi//matrixcookbook.pdf), especially since the inverse is already being computed. I am however not sure how to inject that gradient into autograd; do you have any suggestions as to how that might be achieved efficiently?
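As a sanity check on identity (141), a finite-difference numpy sketch confirms the expression when a symmetric X is perturbed symmetrically (off-diagonal entries perturbed in pairs):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 3))
X = A @ A.T + 3.0 * np.eye(3)                # symmetric positive definite
inv = np.linalg.inv(X)
formula = 2 * inv - np.diag(np.diag(inv))    # the (141)-style expression

# Finite-difference gradient of log det X over the independent
# entries of a symmetric matrix.
eps = 1e-6
numeric = np.zeros_like(X)
for i in range(3):
    for j in range(i, 3):
        Xp = X.copy()
        Xp[i, j] += eps
        Xp[j, i] = Xp[i, j]                  # keep Xp symmetric
        d = (np.linalg.slogdet(Xp)[1] - np.linalg.slogdet(X)[1]) / eps
        numeric[i, j] = numeric[j, i] = d

assert np.allclose(formula, numeric, atol=1e-4)
```

The doubling of the off-diagonal entries reflects that each appears twice in a symmetric matrix.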
I think the easiest way is to create an autograd function:

```python
import torch
from torch.autograd import Function

class _NonnormalizedDeterminant(Function):
    @staticmethod
    def forward(ctx, matrix):
        ctx.save_for_backward(matrix)
        return matrix.new([1.0])  # A bogus value.

    @staticmethod
    def backward(ctx, grad_output):
        matrix, = ctx.saved_variables
        inv = matrix.inverse()
        grad = 2 * inv - torch.diag(torch.diag(inv))
        return grad_output * grad
```

You can then use this as

```python
normalization_factor = torch.log(_NonnormalizedDeterminant.apply(self.sigma)) + ...
```
Also feel free to open an issue on this so we don't lose this extended discussion.
```python
return Variable(torch.potri(var.data)) if torch.__version__ < '0.3.0' else torch.potri(var)
```

```python
class MultivariateNormal(Distribution):
```
nit: You can decorate with `@copy_docs_from(Distribution)` and then omit trivial method docstrings. See other distributions for usage.
cc @fehiepsi

Thank you for the helpful comments. I made the changes you suggested to the interface and have added a test for directly passing the Cholesky decomposition of the covariance matrix.
fritzo
left a comment
Thanks for updating so quickly!
```python
reparameterized = True

def __init__(self, loc, covariance_matrix, scale_tril=None, batch_size=None, use_inverse_for_batch_log=False,
```
I think the easiest interface to use would be to require one of the args to be specified (as with `Bernoulli`):

```python
def __init__(self, loc, covariance_matrix=None, scale_tril=None, ...):
    if covariance_matrix is None and scale_tril is None:
        raise ValueError('At least one of covariance_matrix or scale_tril must be specified')
    if scale_tril is None:
        covariance_matrix = ...
    elif covariance_matrix is None:
        scale_tril = ...
```
Okay, I am going to update that. That will also make the `None` arguments in the test obsolete.
```python
Fixture(pyro_dist=(dist.multivariate_normal, MultivariateNormal),
        scipy_dist=sp.multivariate_normal,
        examples=[
            {'loc': [2.0, 1.0], 'covariance_matrix': [[1.0, 0.5], [0.5, 1.0]], 'scale_tril': None,
```
Could you omit the `None` arguments below:

```python
{'loc': [2.0, 1.0], 'covariance_matrix': [[1.0, 0.5], [0.5, 1.0]],
 'test_data': [[2.0, 1.0], [9.0, 3.4]]},
{'loc': [2.0, 1.0], 'scale_tril': [[1.0, 0.5], [0, 3900231685776981/4503599627370496]],
 'test_data': [[2.0, 1.0], [9.0, 3.4]]},
```

```python
    loc.size()))
self.sigma = covariance_matrix
self.sigma_cholesky = Variable(
    torch.potrf(covariance_matrix.data)) if torch.__version__ < '0.3.0' else torch.potrf(covariance_matrix)
```
I like your `potri_compat()` function above. Could you factor this line out as a similar `potrf_compat()`?
No problem, I am going to include that in the next commit.
```python
x = x.view(batch_size, *self.mu.size())
normalization_factor = torch.log(
    self.sigma_cholesky.diag()).sum() + (self.mu.size()[0] / 2) * np.log(2 * np.pi) if self.normalized else 0
sigma_inverse = torch.inverse(self.sigma) if self.use_inverse_for_batch_log else \
```
As I understand, it's generally a good idea to avoid computing an inverse when solving an equation: if you have a Cholesky factor you should use a Cholesky solver, and if you have a general matrix you should use a general matrix solver. From what you know of PyTorch, is this not yet possible?
It looks like you could do this with `torch.gesv()`: http://pytorch.org/docs/master/torch.html#torch.gesv
I was thinking that it is easier to use `torch.inverse()` because `torch.gesv()` is not a batch solver.
Yeah, that really is a problem. It looks like using `torch.gesv()` only brings performance gains if the batch size is considerably smaller than the dimension of the covariance matrix.
What do you think of adding a matrix solve helper along these lines?

```python
def linear_solve_compat(matrix, matrix_chol, x):
    """Solves the equation ``torch.mm(matrix, y) = x`` for y."""
    assert matrix.requires_grad == matrix_chol.requires_grad
    if matrix.requires_grad or x.requires_grad:
        # If derivatives are required, use the more expensive gesv.
        return torch.gesv(x, matrix)
    else:
        # Use the cheaper Cholesky solver.
        return torch.potrs(x, matrix_chol)
```

That would clean up the logic. (EDIT: renamed `matrix_solve_compat` to `linear_solve_compat`)
I think the …

So I have added a …
fritzo
left a comment
Looks great. This looks ready to merge after a couple cosmetic changes.
```python
A distribution over vectors in which all the elements have a joint
Gaussian density.

:param torch.autograd.Variable loc: Mean. Must be a vector (Variable containing a 1d Tensor).
```
nit: Could you ensure that docstrings are at most 80 characters wide (here and in `.sample()` below)? This is important for users of Jupyter notebooks so that `help(MultivariateNormal)` prints correctly.
```python
x = x.view(batch_size, *self.mu.size())
normalization_factor = torch.log(
    self.sigma_cholesky.diag()).sum() + (self.mu.size()[0] / 2) * np.log(2 * np.pi) if self.normalized else 0
sigma_inverse = matrix_inverse_compat(self.sigma, self.sigma_cholesky)
```
Your review comment was helpful in understanding how we'll update this code in the future. Could you insert that comment here as a `# TODO It may be useful to switch between matrix_inverse_compat() and linear_solve_compat() ...`?
```python
def batch_log_pdf(self, x):
    if not self.normalized and self.sigma_cholesky.requires_grad:
        warnings.warn("Gradients will not take normalization into account if normalized=False.")
```
I'm inclined to raise an error rather than warn because I do not understand what is being computed in this case. Do you have a clear understanding of what is being computed and whether there are any statistical applications for computing this new quantity?
Of course the gradients would only be a rather bad estimate of the true gradient, but it looked to me like it was good enough for some things, like fitting the parameters of a kernel for a Gaussian process. But I did not investigate that very thoroughly, so I will change the code to raise an error for the moment. It looks like (2) can be implemented very soon anyway.
Thanks, raising an error seems like the right thing to do, and you've convinced me that a correct solution should be fairly cheap.
I made the cosmetic changes, so if this gets merged I am going to open an issue as you suggested to further discuss the remaining details. I will try to implement the backwards pass for the normalization constant over the next couple of days.

Great, I'll merge as soon as tests pass and we can start implementing that cheap nonnormalized gradient.
Hi @dwd31415, could you help me understand a bit of the code: `self.sigma = torch.mm(scale_tril.transpose(0, 1), scale_tril)`? (I would expect `torch.mm(scale_tril, scale_tril.transpose(0, 1))` for a lower-triangular factor.)
Oh I see, we named our variable `scale_tril`, but it is actually upper triangular.
Yeah, the name really is misleading. I must confess that I did not think about that when I renamed things to fit the TensorFlow API. I think the best way to go forward would be to change the name, as `scale_tril` suggests a lower-triangular factor.
Hmm, it would be nice to maintain compatibility with TensorFlow and pass …
Yeah, these changes should be fairly easy to implement. Support for batched parameters would also be really nice to have. But although batched …
I was wondering about batched … I think that batched … But I'm actually not sure what we should do about sampling without batched …
How about we require the …
I guess that would be a possibility, but the functionality would be seriously limited.
At least on the CPU it would be possible to fall back to …
This pull request adds a (basic) implementation of a multivariate normal distribution.

Since `torch.potrf` is differentiable as of v0.3.0, this can now be done in a manner that is differentiable using autograd. This is at least true for the samples. As the derivative for `torch.potri` is not implemented yet, the results of `log_pdf` cannot be differentiated as of now when they are computed using the Cholesky decomposition. This pull request however includes an additional implementation of the log pdf that uses the `torch.inverse` function and is differentiable, although obviously much slower.

The pull request already includes some tests; however, a workaround is used to pass the parameters to the scipy implementation because they were not passed correctly by `get_scipy_batch_logpdf`.
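For readers following along, the normalized log-density the PR computes can be sketched in numpy via an upper Cholesky factor (function name hypothetical; the actual implementation lives in the PR):

```python
import numpy as np

def mvn_log_pdf(x, loc, scale_triu):
    """Sketch of the normalized multivariate normal log-density using
    an upper Cholesky factor U with U.T @ U == sigma."""
    d = loc.shape[0]
    delta = x - loc                            # shape (batch, d)
    # With sigma = U.T @ U, the Mahalanobis term delta sigma^{-1} delta^T
    # reduces to ||z||^2 where z solves U.T @ z = delta.T.
    z = np.linalg.solve(scale_triu.T, delta.T)
    mahalanobis = np.sum(z ** 2, axis=0)
    # This matches the PR's normalization_factor:
    # sum(log diag U) + (d / 2) * log(2 * pi).
    log_norm = np.log(np.diag(scale_triu)).sum() + 0.5 * d * np.log(2 * np.pi)
    return -0.5 * mahalanobis - log_norm
```

Using a triangular solve here avoids forming the inverse covariance, which is the optimization discussed in the review thread.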