[distributions] Low rank multivariate normal #8635
fehiepsi wants to merge 21 commits into pytorch:master
Conversation
```python
loc_shape = batch_shape + event_shape
self.loc = loc.expand(loc_shape)
self.scale_factor = scale_factor.expand(loc_shape + scale_factor.shape[-1:])
self.scale_diag = scale_diag.expand(loc_shape)
```
```python
self.covariance_matrix = _batch_inverse(precision_matrix)
batch_shape = _get_batch_shape(precision_matrix, loc)
self.precision_matrix = precision_matrix.expand(batch_shape + event_shape + event_shape)
self.covariance_matrix = _batch_inverse(precision_matrix)
```
```python
batch_shape = _get_batch_shape(L, x)
x = x.expand(batch_shape + (n,))
L = L.expand(batch_shape + (n, n))
return _batch_trtrs_lower(x.unsqueeze(-1), L).squeeze(-1).pow(2).sum(-1)
```
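The trtrs-based `_batch_mahalanobis` above rests on the identity `x.T @ inv(L @ L.T) @ x = ||inv(L) @ x||^2` for a Cholesky factor `L`. A minimal NumPy sketch of that identity (a stand-in for the batched torch code; the names and `np.linalg.solve` call are illustrative, not the PR's helpers):

```python
import numpy as np

rng = np.random.default_rng(0)
n = 4

# Build a random SPD matrix M = L @ L.T via its Cholesky factor L.
A = rng.standard_normal((n, n))
L = np.linalg.cholesky(A @ A.T + n * np.eye(n))
x = rng.standard_normal(n)

# Direct Mahalanobis distance: x.T @ inv(M) @ x.
M = L @ L.T
direct = x @ np.linalg.inv(M) @ x

# trtrs-style: solve L z = x, then the distance is ||z||^2,
# since x.T @ inv(L @ L.T) @ x = (inv(L) @ x).T @ (inv(L) @ x).
z = np.linalg.solve(L, x)  # stand-in for a triangular solve
via_solve = (z ** 2).sum()

assert np.allclose(direct, via_solve)
```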
Details:
```
============================= test session starts ==============================
platform linux -- Python 3.5.5, pytest-3.6.0, py-1.5.3, pluggy-0.6.0 -- /home/fehiepsi/miniconda3/envs/pyro/bin/python
cachedir: .pytest_cache
benchmark: 3.0.0 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=5.00us max_time=1.00s calibration_precision=10 warmup=False warmup_iterations=100000)
rootdir: /home/fehiepsi/pytorch, inifile:
plugins: xdist-1.22.2, forked-0.2, cov-2.5.1, benchmark-3.0.0, nbval-0.9.1
collecting ... collected 166 items / 155 deselected
test/test_distributions.py::TestDistributions::test_lowrank_multivariate_normal_log_prob PASSED [ 9%]
=================================== FAILURES ===================================
self = <test_distributions.TestDistributions testMethod=test_multivariate_normal_shape>
test/test_distributions.py:1604:
test/test_distributions.py:620: in _gradcheck_log_prob
msg = 'Jacobian mismatch for output 0 with respect to input 1,\nnumerical:tensor([[-0.8790, -0.8150, -0.7580, -0.5903, -0.87...
       [ -3.5956, -12.2547,  -3.2756,  -4.5524,   2.5408],
       [  2.0588,  12.0112,   1.2863,  -0.1552,  -3.4333]])\n'
E RuntimeError: Jacobian mismatch for output 0 with respect to input 1,
torch/autograd/gradcheck.py:172: RuntimeError
```
The gradient seems to be exactly 5x off.

I think I have found the bug; see #8649.
```python
def _batch_capacitance_tril(W, D):
    r"""
```
Although this is logically supposed to be faster, I am interested in knowing what kind of speedup is obtained in the calculation of the Mahalanobis distance and
@vishwakftw Here is some comparison: https://gist.github.com/fehiepsi/80810668bae0fbaa926e5b6319c1ecd3 — using trtrs is much faster than inverse when the size of the matrix is large.
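As a rough illustration of why avoiding an explicit inverse helps (this is not the gist's actual benchmark; NumPy's general `np.linalg.solve` stands in for the triangular `trtrs`, and the sizes are arbitrary):

```python
import time
import numpy as np

rng = np.random.default_rng(0)
n = 500
A = rng.standard_normal((n, n))
M = A @ A.T + n * np.eye(n)   # SPD and well-conditioned
x = rng.standard_normal(n)

# Route 1: form the inverse explicitly, then multiply.
t0 = time.perf_counter()
y_inv = np.linalg.inv(M) @ x
t_inv = time.perf_counter() - t0

# Route 2: solve the linear system directly (no explicit inverse).
t0 = time.perf_counter()
y_solve = np.linalg.solve(M, x)
t_solve = time.perf_counter() - t0

# Both routes compute the same vector; the solve is typically
# faster and more numerically stable for large matrices.
assert np.allclose(y_inv, y_solve)
print(f"inv: {t_inv:.4f}s  solve: {t_solve:.4f}s")
```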
The timings are amazing!! Thanks!!
@fehiepsi For GPU, you should always

Regarding the

@vishwakftw Since the gradients are exactly 5x off, I doubt it is a precision issue.

Ah yes, sorry - missed that detail above.
About the gradcheck that is 5x off: it seems we have caught the issue in #8649, and @ssnl will take a look at it. I am wondering about two other issues. The first is if

Another issue is when to do "broadcasting". Different from other distributions, it seems that it is better to avoid using

Regarding "broadcasting", to solve Linv_y, I broadcast L to match the batch shape of y, then use

@fritzo I think that you will have better ideas than me on these issues. Could you please give some thoughts on them?
@fehiepsi The names seem fine; it would be nice to match TensorFlow so that code is more portable.

Re: broadcasting, I've faced similar issues and haven't yet found a good solution. If you find a good pattern here, we may want to copy it elsewhere, e.g. in computing a batch of VonMises normalizing factors when

Here's a proposal:

We may additionally want to store
```python
scale_diag_sqrt_unsqueeze = self.scale_diag.sqrt().unsqueeze(-1)
Dinvsqrt_W = self.scale_factor / scale_diag_sqrt_unsqueeze
K = torch.matmul(Dinvsqrt_W, Dinvsqrt_W.transpose(-1, -2)).contiguous()
K.view(-1, n * n)[:, ::n + 1] += 1  # add identity matrix to K
```
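The snippet above forms `K = I + D^{-1/2} @ W @ W.T @ D^{-1/2}`; rescaling by `D^{1/2}` on both sides then recovers the full covariance `W @ W.T + D`. A NumPy sketch of that identity (shapes and names are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 5, 2
W = rng.standard_normal((n, m))   # low-rank factor (n x m)
D = rng.uniform(0.5, 2.0, n)      # diagonal part

Dinvsqrt_W = W / np.sqrt(D)[:, None]
K = Dinvsqrt_W @ Dinvsqrt_W.T
K[np.arange(n), np.arange(n)] += 1.0   # K = I + D^{-1/2} W W^T D^{-1/2}

# Rescaling recovers the full covariance: D^{1/2} K D^{1/2} = W W^T + D.
cov = np.sqrt(D)[:, None] * K * np.sqrt(D)[None, :]
assert np.allclose(cov, W @ W.T + np.diag(D))
```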
test/test_distributions.py (Outdated)
```python
x = dist1.sample((10,))
expected = ref_dist.logpdf(x.numpy())
print(dist1.log_prob(x), MultivariateNormal(mean, cov).log_prob(x))
```
fritzo
left a comment
Tests look great, I'll make one more pass over the math...
```python
Args:
    loc (Tensor): mean of the distribution
    scale_factor (Tensor): factor part of low-rank form of covariance matrix
    scale_diag (Tensor): diagonal part of low-rank form of covariance matrix
```
```python
Note:
    The computation for determinant and inverse of covariance matrix is saved when
```
```python
def __init__(self, loc, scale_factor, scale_diag, validate_args=None):
    if loc.dim() < 1:
        loc = loc.unsqueeze(0)
```
| """ | ||
| n = bvec.size(-1) | ||
| flat_bvec = bvec.reshape(-1, n) | ||
| flat_bmat = torch.stack([v.diag() for v in flat_bvec]) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
```python
    the log determinant.
    """
    if capacitance_tril is None:
        capacitance_tril = _batch_capacitance_tril(W, D)
```
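`_batch_capacitance_tril` presumably supplies a Cholesky factor of the capacitance matrix `C = I_m + W.T @ inv(D) @ W`, which via the matrix determinant lemma yields the log determinant without ever forming the n x n covariance. A NumPy sketch of the lemma (names and sizes are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 6, 2
W = rng.standard_normal((n, m))
D = rng.uniform(0.5, 2.0, n)

# Capacitance matrix C = I_m + W.T @ inv(D) @ W — only m x m.
C = np.eye(m) + W.T @ (W / D[:, None])

# Matrix determinant lemma:
# logdet(W W^T + D) = logdet(C) + sum(log(D))
lemma = np.linalg.slogdet(C)[1] + np.log(D).sum()
direct = np.linalg.slogdet(W @ W.T + np.diag(D))[1]
assert np.allclose(lemma, direct)
```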
```python
    Mahalanobis distance :math:`x.T @ inv(W @ W.T + D) @ x`.
    """
    if capacitance_tril is None:
        capacitance_tril = _batch_capacitance_tril(W, D)
```
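With the capacitance matrix in hand, the Mahalanobis helper can exploit the Woodbury identity `inv(W W^T + D) = inv(D) - inv(D) W inv(C) W^T inv(D)`, so only an m x m system is ever solved. A NumPy sketch of the identity (illustrative names; the PR solves against the Cholesky factor rather than calling a dense solve):

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 6, 2
W = rng.standard_normal((n, m))
D = rng.uniform(0.5, 2.0, n)
x = rng.standard_normal(n)

# Capacitance matrix C = I_m + W.T @ inv(D) @ W.
C = np.eye(m) + W.T @ (W / D[:, None])

# Woodbury: x^T inv(W W^T + D) x
#         = x^T inv(D) x - (W^T inv(D) x)^T inv(C) (W^T inv(D) x)
Wt_Dinv_x = W.T @ (x / D)
mahalanobis = (x * x / D).sum() - Wt_Dinv_x @ np.linalg.solve(C, Wt_Dinv_x)

direct = x @ np.linalg.inv(W @ W.T + np.diag(D)) @ x
assert np.allclose(mahalanobis, direct)
```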
test/test_distributions.py (Outdated)
```python
'scale_diag': torch.tensor([2.0, 0.25], requires_grad=True),
},
{
    'loc': torch.randn(2, 3, requires_grad=True),
```
test/test_distributions.py (Outdated)
```python
    scale_diag[i].diag() for i in range(0, 2)]
p = LowRankMultivariateNormal(loc[0], scale_factor[0], scale_diag[0])
q = LowRankMultivariateNormal(loc[1], scale_factor[1], scale_diag[1])
actual = kl_divergence(p, q)
```
```python
term2 = _batch_trace_XXT(torch.matmul(_batch_inverse(q.scale_tril), p.scale_tril))
term3 = _batch_mahalanobis(q.scale_tril, (q.loc - p.loc))
return term1 + 0.5 * (term2 + term3 - p.event_shape[0])
half_term1 = (_batch_diag(q._unbroadcasted_scale_tril).log().sum(-1) -
```
```python
if loc.dim() < 1:
    loc = loc.unsqueeze(0)
event_shape = torch.Size(loc.shape[-1:])
raise ValueError("loc must be at least one-dimensional.")
```
```python
def __init__(self, loc, scale_factor, scale_diag, validate_args=None):
    if loc.dim() < 1:
        raise ValueError("loc must be at least one-dimensional.")
```
```python
'Incorrect KL(MultivariateNormal, LowRankMultivariateNormal) instance {}/{}'.format(i + 1, n),
'Expected (from KL MultivariateNormal): {}'.format(expected),
'Actual (analytic): {}'.format(actual_full_lowrank),
]))
```
```python
b = 7  # Number of batches
loc = [torch.randn(b, 3) for _ in range(0, 2)]
scale_tril = [transform_to(constraints.lower_cholesky)(torch.randn(b, 3, 3)),
              transform_to(constraints.lower_cholesky)(torch.randn(3, 3))]
```
@fritzo Currently, the implementation for

Today, I checked the Slack distributions room again and found that @stepelu has suggested a very nice idea for reparameterization: take two random variables eps1 and eps2 and compute `W.eps1 + D^(1/2).eps2`. It was a "whoa" moment for me. Sampling this way is pretty fast! I will make the corresponding changes.

@fritzo @apaszke In addition, I checked TensorFlow's MultivariateNormalDiagPlusLowRank again. It is different from ours: their low-rank version is for
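The suggested reparameterization can be checked empirically: with `eps1, eps2 ~ N(0, I)`, the sample `x = loc + W @ eps1 + sqrt(D) * eps2` has covariance `W @ W.T + diag(D)`, with no n x n factorization needed. A NumPy sketch (sizes and the sample count are illustrative, and the empirical check uses a loose tolerance):

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 3, 2
loc = rng.standard_normal(n)
W = rng.standard_normal((n, m))
D = rng.uniform(0.5, 2.0, n)

# Draw x = loc + W @ eps1 + sqrt(D) * eps2 with eps1, eps2 ~ N(0, I).
num = 200_000
eps1 = rng.standard_normal((num, m))
eps2 = rng.standard_normal((num, n))
x = loc + eps1 @ W.T + np.sqrt(D) * eps2

# The empirical covariance approaches W W^T + diag(D).
emp_cov = np.cov(x, rowvar=False)
assert np.allclose(emp_cov, W @ W.T + np.diag(D), atol=0.15)
```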
Nice! That sounds very similar to our pyro.contrib.gp.kernels.Coregionalize sampler.
I agree, those look like reasonable names. Alternatively maybe we should use |
fritzo
left a comment
LGTM after fixing the .rsample() bug and adding a test.
I've only lightly reviewed the linear algebra, but your tests look thorough. Can you please confirm that every bit of math is exercised by tests?
```python
shape = self._extended_shape(sample_shape)
eps_W = self.loc.new_empty(shape[:-1] + (self.cov_factor.size(-1),)).normal_()
eps_D = self.loc.new_empty(shape).normal_()
return self.loc + _batch_mv(self.cov_factor, eps_W) + self.cov_diag * eps_D
```
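The last line of the hunk scales `eps_D` by `cov_diag` itself; for the sample covariance to come out as `W W^T + D`, the diagonal noise must be scaled by `cov_diag.sqrt()` instead, consistent with the `W.eps1 + D^(1/2).eps2` reparameterization discussed earlier. A NumPy sketch of the mismatch (the `D` values here are illustrative, not from the PR):

```python
import numpy as np

rng = np.random.default_rng(0)
D = np.array([0.25, 0.5, 2.0, 4.0])   # hypothetical cov_diag values
eps = rng.standard_normal((100_000, 4))

# Scaling N(0, 1) noise by D gives variance D**2, not D;
# the correct scale for a diagonal covariance D is sqrt(D).
buggy = (D * eps).var(axis=0)           # ends up ~ D**2
fixed = (np.sqrt(D) * eps).var(axis=0)  # ends up ~ D

assert np.allclose(fixed, D, rtol=0.05)
assert not np.allclose(buggy, D, rtol=0.05)
```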
@fritzo I have fixed the bug and reverted the examples to the old version (which previously failed because of this bug). Thanks a lot!
|
@fehiepsi Why was the |
```python
'LowRankMultivariateNormal(loc={}, cov_factor={}, cov_diag={})'
    .format(mean, cov_factor, cov_diag), multivariate=True)

def test_lowrank_multivariate_normal_properties(self):
```
fritzo
left a comment
LGTM, thanks for addressing all my review comments.
@pytorchbot retest this please
@ssnl `retest this please` doesn't apply to Travis :) Go to the Travis UI and just click the "retry" button (I have done so).
@ezyang Good to know. Thanks! :)
The caffe2 build failure looks spurious.
facebook-github-bot
left a comment
@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: This pull request implements the low rank multivariate normal distribution, where the covariance matrix has the form `W @ W.T + D`. Here D is a diagonal matrix and W has shape n x m with m << n. It uses the matrix determinant lemma and the Woodbury matrix identity to save computational cost. Along the way, I also revised the MultivariateNormal distribution a bit. Other changes:

+ `torch.trtrs` works with CUDA tensors, so I tried to use it instead of `torch.inverse`.
+ Use `torch.matmul` instead of `torch.bmm` in `_batch_mv`. The former is faster and simpler.
+ Use `torch.diagonal` for `_batch_diag`.
+ Reimplement `_batch_mahalanobis` based on `_batch_trtrs_lower`.
+ Use trtrs to compute term2 of the KL divergence.
+ `variance` relies on `scale_tril` instead of `covariance_matrix`.

TODO:
- [x] Resolve the failure at `_gradcheck_log_prob`
- [x] Add test for KL

cc @fritzo @stepelu @apaszke

Pull Request resolved: pytorch#8635
Differential Revision: D8951893
Pulled By: ezyang
fbshipit-source-id: 488ee3db6071150c33a1fb6624f3cfd9b52760c3