[distributions] Low rank multivariate normal #8635

Closed
fehiepsi wants to merge 21 commits into pytorch:master from fehiepsi:lowrank

Conversation

@fehiepsi
Contributor

@fehiepsi fehiepsi commented Jun 19, 2018

This pull request implements the low rank multivariate normal distribution, where the covariance matrix has the form W @ W.T + D. Here D is a diagonal matrix and W has shape n x m with m << n. It uses the "matrix determinant lemma" and the "Woodbury matrix identity" to save computational cost.
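The savings can be sketched with a small, framework-agnostic numpy example (the helper name and shapes are illustrative, not part of the PR): the capacitance matrix C = I + W.T @ inv(D) @ W is only m x m, so both the log-determinant and solves against the n x n covariance reduce to work on C.

```python
import numpy as np

def lowrank_logdet_and_solve(W, D_diag, x):
    """Log-determinant of M = W @ W.T + diag(D_diag) and the solve M^{-1} x
    in O(n m^2) via the matrix determinant lemma and the Woodbury identity.
    Hypothetical helper for illustration only."""
    n, m = W.shape
    Dinv = 1.0 / D_diag                         # diagonal inverse is elementwise
    # capacitance matrix C = I_m + W.T @ D^{-1} @ W  (m x m, small)
    C = np.eye(m) + W.T @ (Dinv[:, None] * W)
    L = np.linalg.cholesky(C)
    # matrix determinant lemma: log|M| = log|C| + sum(log D)
    logdet = 2.0 * np.log(np.diag(L)).sum() + np.log(D_diag).sum()
    # Woodbury: M^{-1} x = D^{-1} x - D^{-1} W C^{-1} W.T D^{-1} x
    Dinv_x = Dinv * x
    solve = Dinv_x - Dinv * (W @ np.linalg.solve(C, W.T @ Dinv_x))
    return logdet, solve
```

No n x n matrix is ever factorized; only the m x m capacitance matrix is.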

Along the way, I also revised the MultivariateNormal distribution a bit. Here are the other changes:

  • torch.trtrs works with CUDA tensors, so I use it instead of torch.inverse.
  • Use torch.matmul instead of torch.bmm in _batch_mv. The former is faster and simpler.
  • Use torch.diagonal for _batch_diag.
  • Reimplement _batch_mahalanobis based on _batch_trtrs_lower.
  • Use trtrs to compute term2 of the KL divergence.
  • variance relies on scale_tril instead of covariance_matrix.

TODO:

  • Resolve the fail at _gradcheck_log_prob
  • Add test for KL

cc @fritzo @stepelu @apaszke

loc_shape = batch_shape + event_shape
self.loc = loc.expand(loc_shape)
self.scale_factor = scale_factor.expand(loc_shape + scale_factor.shape[-1:])
self.scale_diag = scale_diag.expand(loc_shape)


self.covariance_matrix = _batch_inverse(precision_matrix)
batch_shape = _get_batch_shape(precision_matrix, loc)
self.precision_matrix = precision_matrix.expand(batch_shape + event_shape + event_shape)
self.covariance_matrix = _batch_inverse(precision_matrix)


batch_shape = _get_batch_shape(L, x)
x = x.expand(batch_shape + (n,))
L = L.expand(batch_shape + (n, n))
return _batch_trtrs_lower(x.unsqueeze(-1), L).squeeze(-1).pow(2).sum(-1)
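The idea in the snippet above, computing the squared Mahalanobis distance as ||L^{-1} x||^2 with one triangular solve rather than forming an explicit inverse, can be illustrated in plain numpy (np.linalg.solve stands in for a dedicated trtrs lower-triangular solve):

```python
import numpy as np

def mahalanobis_sq(L, x):
    """Squared Mahalanobis distance x.T @ inv(L @ L.T) @ x, where L is the
    lower Cholesky factor of the covariance. One solve, no matrix inverse."""
    z = np.linalg.solve(L, x)  # stands in for a trtrs lower-triangular solve
    return (z ** 2).sum(-1)
```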


@fehiepsi
Contributor Author

@fritzo

Details
============================= test session starts ==============================
platform linux -- Python 3.5.5, pytest-3.6.0, py-1.5.3, pluggy-0.6.0 -- /home/fehiepsi/miniconda3/envs/pyro/bin/python
cachedir: .pytest_cache
benchmark: 3.0.0 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=5.00us max_time=1.00s calibration_precision=10 warmup=False warmup_iterations=100000)
rootdir: /home/fehiepsi/pytorch, inifile:
plugins: xdist-1.22.2, forked-0.2, cov-2.5.1, benchmark-3.0.0, nbval-0.9.1
collecting ... collected 166 items / 155 deselected

test/test_distributions.py::TestDistributions::test_lowrank_multivariate_normal_log_prob PASSED [ 9%]
test/test_distributions.py::TestDistributions::test_lowrank_multivariate_normal_properties PASSED [ 18%]
test/test_distributions.py::TestDistributions::test_lowrank_multivariate_normal_sample PASSED [ 27%]
test/test_distributions.py::TestDistributions::test_lowrank_multivariate_normal_shape PASSED [ 36%]
test/test_distributions.py::TestDistributions::test_multivariate_normal_log_prob PASSED [ 45%]
test/test_distributions.py::TestDistributions::test_multivariate_normal_properties PASSED [ 54%]
test/test_distributions.py::TestDistributions::test_multivariate_normal_sample PASSED [ 63%]
test/test_distributions.py::TestDistributions::test_multivariate_normal_shape FAILED [ 72%]
test/test_distributions.py::TestRsample::test_dirichlet_multivariate PASSED [ 81%]
test/test_distributions.py::TestKL::test_kl_multivariate_normal PASSED [ 90%]
test/test_distributions.py::TestKL::test_kl_multivariate_normal_batched PASSED [100%]

=================================== FAILURES ===================================
_______________ TestDistributions.test_multivariate_normal_shape _______________

self = <test_distributions.TestDistributions testMethod=test_multivariate_normal_shape>

def test_multivariate_normal_shape(self):
    mean = torch.randn(5, 3, requires_grad=True)
    mean_no_batch = torch.randn(3, requires_grad=True)
    mean_multi_batch = torch.randn(6, 5, 3, requires_grad=True)

    # construct PSD covariance
    tmp = torch.randn(3, 10)
    cov = torch.tensor(torch.matmul(tmp, tmp.t()) / tmp.shape[-1], requires_grad=True)
    prec = torch.tensor(cov.inverse(), requires_grad=True)
    scale_tril = torch.tensor(torch.potrf(cov, upper=False), requires_grad=True)

    # construct batch of PSD covariances
    tmp = torch.randn(6, 5, 3, 10)
    cov_batched = torch.tensor((tmp.unsqueeze(-2) * tmp.unsqueeze(-3)).mean(-1), requires_grad=True)
    prec_batched = [C.inverse() for C in cov_batched.view((-1, 3, 3))]
    prec_batched = torch.stack(prec_batched).view(cov_batched.shape)
    scale_tril_batched = [torch.potrf(C, upper=False) for C in cov_batched.view((-1, 3, 3))]
    scale_tril_batched = torch.stack(scale_tril_batched).view(cov_batched.shape)

    # ensure that sample, batch, event shapes all handled correctly
    self.assertEqual(MultivariateNormal(mean, cov).sample().size(), (5, 3))
    self.assertEqual(MultivariateNormal(mean_no_batch, cov).sample().size(), (3,))
    self.assertEqual(MultivariateNormal(mean_multi_batch, cov).sample().size(), (6, 5, 3))
    self.assertEqual(MultivariateNormal(mean, cov).sample((2,)).size(), (2, 5, 3))
    self.assertEqual(MultivariateNormal(mean_no_batch, cov).sample((2,)).size(), (2, 3))
    self.assertEqual(MultivariateNormal(mean_multi_batch, cov).sample((2,)).size(), (2, 6, 5, 3))
    self.assertEqual(MultivariateNormal(mean, cov).sample((2, 7)).size(), (2, 7, 5, 3))
    self.assertEqual(MultivariateNormal(mean_no_batch, cov).sample((2, 7)).size(), (2, 7, 3))
    self.assertEqual(MultivariateNormal(mean_multi_batch, cov).sample((2, 7)).size(), (2, 7, 6, 5, 3))
    self.assertEqual(MultivariateNormal(mean, cov_batched).sample((2, 7)).size(), (2, 7, 6, 5, 3))
    self.assertEqual(MultivariateNormal(mean_no_batch, cov_batched).sample((2, 7)).size(), (2, 7, 6, 5, 3))
    self.assertEqual(MultivariateNormal(mean_multi_batch, cov_batched).sample((2, 7)).size(), (2, 7, 6, 5, 3))
    self.assertEqual(MultivariateNormal(mean, precision_matrix=prec).sample((2, 7)).size(), (2, 7, 5, 3))
    self.assertEqual(MultivariateNormal(mean, precision_matrix=prec_batched).sample((2, 7)).size(), (2, 7, 6, 5, 3))
    self.assertEqual(MultivariateNormal(mean, scale_tril=scale_tril).sample((2, 7)).size(), (2, 7, 5, 3))
    self.assertEqual(MultivariateNormal(mean, scale_tril=scale_tril_batched).sample((2, 7)).size(), (2, 7, 6, 5, 3))

    # check gradients
  self._gradcheck_log_prob(MultivariateNormal, (mean, cov))

test/test_distributions.py:1604:


test/test_distributions.py:620: in _gradcheck_log_prob
gradcheck(apply_fn, ctor_params, raise_exception=True)
torch/autograd/gradcheck.py:192: in gradcheck
'numerical:%s\nanalytical:%s\n' % (i, j, n, a))


msg = 'Jacobian mismatch for output 0 with respect to input 1,\nnumerical:tensor([[-0.8790, -0.8150, -0.7580, -0.5903, -0.87... [ -3.5956, -12.2547, -3.2756, -4.5524, 2.5408],\n [ 2.0588, 12.0112, 1.2863, -0.1552, -3.4333]])\n'

def fail_test(msg):
    if raise_exception:
      raise RuntimeError(msg)

E RuntimeError: Jacobian mismatch for output 0 with respect to input 1,
E numerical:tensor([[-0.8790, -0.8150, -0.7580, -0.5903, -0.8705],
E [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
E [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
E [ 0.7389, 1.2535, 0.3848, -0.1361, 0.8325],
E [-0.6777, -0.3120, -0.6621, -0.2594, -1.0205],
E [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
E [ 0.0148, -1.0363, 0.5786, 0.7754, -0.1522],
E [-0.7191, -2.4509, -0.6551, -0.9105, 0.5082],
E [ 0.4118, 2.4022, 0.2573, -0.0310, -0.6867]])
E analytical:tensor([[ -4.3950, -4.0751, -3.7900, -2.9516, -4.3525],
E [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
E [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
E [ 3.6945, 6.2677, 1.9241, -0.6804, 4.1626],
E [ -3.3884, -1.5600, -3.3107, -1.2972, -5.1025],
E [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
E [ 0.0740, -5.1813, 2.8930, 3.8770, -0.7608],
E [ -3.5956, -12.2547, -3.2756, -4.5524, 2.5408],
E [ 2.0588, 12.0112, 1.2863, -0.1552, -3.4333]])

torch/autograd/gradcheck.py:172: RuntimeError
============= 1 failed, 10 passed, 155 deselected in 3.91 seconds ==============

@ssnl
Collaborator

ssnl commented Jun 19, 2018

The gradient seems exactly 5x off.

@fehiepsi
Contributor Author

I think I have found the bug at #8649.



def _batch_capacitance_tril(W, D):
r"""


@vishwakftw
Contributor

vishwakftw commented Jun 19, 2018

Although this should logically be faster, I am interested in knowing what kind of speedup is obtained in the calculation of the Mahalanobis distance and scale_tril. Could you get some numbers?

@fehiepsi
Contributor Author

@vishwakftw Here is some comparison https://gist.github.com/fehiepsi/80810668bae0fbaa926e5b6319c1ecd3

Using trtrs is much faster than inverse when the matrix is large.

@vishwakftw
Contributor

The timings are amazing!! Thanks!!

@ssnl
Collaborator

ssnl commented Jun 20, 2018

@fehiepsi For GPU, you should always call torch.cuda.synchronize() before measuring time.

@vishwakftw
Contributor

Regarding the gradcheck tests, you could modify the _gradcheck_log_prob function to convert the ctor_params to double. Maybe that will help.

@ssnl
Collaborator

ssnl commented Jun 20, 2018

@vishwakftw Since the gradients are exactly 5x off, I doubt it is a precision issue.

@vishwakftw
Contributor

Ah yes, sorry - missed that detail above.

@fehiepsi
Contributor Author

About the gradcheck result being 5x off: it seems that we have caught the issue at #8649, and @ssnl will take a look at it.

I am wondering about two other issues. The first is whether scale_factor and scale_diag are good names. In TensorFlow, they use scale_perturb_factor and scale_diag. The term factor comes from the "factor loading matrix" in factor analysis; it is also called components in PCA. To me, scale_factor is a good name, but I am open to other choices.

Another issue is when to do "broadcasting". Unlike other distributions, it seems better to avoid using .expand in the first place. For example, computing scale_tril from an expanded covariance_matrix is less efficient than computing scale_tril from covariance_matrix and then expanding. So I think that, depending on which scale input (cov, tril, precision) is given, we should calculate _scale_tril_unexpand in the init method. Other properties would then be expanded from computations with this _scale_tril_unexpand. This saves both memory and computation, and makes the properties of MVN consistent with other distributions.

Regarding "broadcasting", to solve Linv_y, I broadcast L to match the batch shape of y, then use trtrs to solve. This might be inefficient when L and y have different batch shapes. Assume that the batch_shape of the distribution is 9, y has shape 10 x 9 x 8, and L has shape 3 x 8 x 8; then it is better to convert L to 9 x 8 x 8 and y to 9 x 80, apply batch trtrs with L and y, and reshape the result. I am planning to implement it that way after the gradcheck issue is fixed, but I wonder if there is a better workaround. FYI, we don't face this issue if we take the inverse of L first and then multiply, because .matmul broadcasts automatically.
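A minimal numpy illustration of the batching idea (np.linalg.solve standing in for trtrs): when several right-hand sides share the same L, folding them into the columns of a single solve replaces a Python loop over separate solves.

```python
import numpy as np

rng = np.random.default_rng(0)
n, k = 8, 10
# one well-conditioned lower-triangular L shared by k right-hand sides
L = np.tril(rng.standard_normal((n, n))) + n * np.eye(n)
y = rng.standard_normal((k, n))

# looped: one solve per vector
looped = np.stack([np.linalg.solve(L, v) for v in y])

# batched: fold the k vectors into the columns of one right-hand side
batched = np.linalg.solve(L, y.T).T

assert np.allclose(looped, batched)
```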

@fritzo I think that you will have better ideas than me on these issues. Could you please give some thoughts on them?

@fritzo
Collaborator

fritzo commented Jun 20, 2018

@fehiepsi The names seem fine; it would be nice to match tensorflow so that code is more portable.

Re: broadcasting, I've faced similar issues and haven't yet found a good solution. If you find a good pattern here, we may want to copy it elsewhere, e.g. in computing a batch of VonMises normalizing factors when scale is broadcast.

Here's a proposal:

  1. in the __init__ method, save the original unbroadcast covariance as self._unbroadcasted_covariance_matrix
  2. in the .scale_tril lazy property, use self._unbroadcasted_covariance_matrix and then apply broadcast_all or similar logic after computing the cholesky decomposition, and return this broadcasted result from the lazy property

We may additionally want to store self._unbroadcasted_scale_tril for other computations.
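A minimal numpy sketch of this proposal (the class name is hypothetical; in the actual code this would be a cached lazy_property): the Cholesky factorization runs on the small unbroadcast matrix, and only the resulting factor is broadcast.

```python
import numpy as np

class LazyScaleTrilSketch:
    """Hold the unbroadcast covariance; broadcast only after the Cholesky."""
    def __init__(self, loc, covariance_matrix):
        self.loc = np.asarray(loc)
        self._unbroadcasted_covariance_matrix = np.asarray(covariance_matrix)

    @property
    def scale_tril(self):
        # Cholesky on the small, unbroadcast matrix...
        tril = np.linalg.cholesky(self._unbroadcasted_covariance_matrix)
        # ...then broadcast the factor to the full batch shape
        batch_shape = np.broadcast_shapes(self.loc.shape[:-1], tril.shape[:-2])
        return np.broadcast_to(tril, batch_shape + tril.shape[-2:])
```

With a (5, 2) loc and a single (2, 2) covariance, only one 2 x 2 factorization happens even though scale_tril comes back with shape (5, 2, 2).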

scale_diag_sqrt_unsqueeze = self.scale_diag.sqrt().unsqueeze(-1)
Dinvsqrt_W = self.scale_factor / scale_diag_sqrt_unsqueeze
K = torch.matmul(Dinvsqrt_W, Dinvsqrt_W.transpose(-1, -2)).contiguous()
K.view(-1, n * n)[:, ::n + 1] += 1 # add identity matrix to K



x = dist1.sample((10,))
expected = ref_dist.logpdf(x.numpy())
print(dist1.log_prob(x), MultivariateNormal(mean, cov).log_prob(x))


Collaborator

@fritzo fritzo left a comment


Tests look great, I'll make one more pass over the math...


Args:
loc (Tensor): mean of the distribution
scale_factor (Tensor): factor part of low-rank form of covariance matrix


scale_diag (Tensor): diagonal part of low-rank form of covariance matrix

Note:
The computation for determinant and inverse of covariance matrix is saved when



def __init__(self, loc, scale_factor, scale_diag, validate_args=None):
if loc.dim() < 1:
loc = loc.unsqueeze(0)


"""
n = bvec.size(-1)
flat_bvec = bvec.reshape(-1, n)
flat_bmat = torch.stack([v.diag() for v in flat_bvec])


the log determinant.
"""
if capacitance_tril is None:
capacitance_tril = _batch_capacitance_tril(W, D)


Mahalanobis distance :math:`x.T @ inv(W @ W.T + D) @ x`.
"""
if capacitance_tril is None:
capacitance_tril = _batch_capacitance_tril(W, D)


'scale_diag': torch.tensor([2.0, 0.25], requires_grad=True),
},
{
'loc': torch.randn(2, 3, requires_grad=True),


scale_diag[i].diag() for i in range(0, 2)]
p = LowRankMultivariateNormal(loc[0], scale_factor[0], scale_diag[0])
q = LowRankMultivariateNormal(loc[1], scale_factor[1], scale_diag[1])
actual = kl_divergence(p, q)


term2 = _batch_trace_XXT(torch.matmul(_batch_inverse(q.scale_tril), p.scale_tril))
term3 = _batch_mahalanobis(q.scale_tril, (q.loc - p.loc))
return term1 + 0.5 * (term2 + term3 - p.event_shape[0])
half_term1 = (_batch_diag(q._unbroadcasted_scale_tril).log().sum(-1) -


Contributor Author

@fehiepsi fehiepsi left a comment


@fritzo I have addressed all your comments. I point out the changes from the last commit below for your convenience. If you have any other comments, please let me know. Thanks a lot!

if loc.dim() < 1:
loc = loc.unsqueeze(0)
event_shape = torch.Size(loc.shape[-1:])
raise ValueError("loc must be at least one-dimensional.")



def __init__(self, loc, scale_factor, scale_diag, validate_args=None):
if loc.dim() < 1:
raise ValueError("loc must be at least one-dimensional.")


'Incorrect KL(MultivariateNormal, LowRankMultivariateNormal) instance {}/{}'.format(i + 1, n),
'Expected (from KL MultivariateNormal): {}'.format(expected),
'Actual (analytic): {}'.format(actual_full_lowrank),
]))


b = 7 # Number of batches
loc = [torch.randn(b, 3) for _ in range(0, 2)]
scale_tril = [transform_to(constraints.lower_cholesky)(torch.randn(b, 3, 3)),
transform_to(constraints.lower_cholesky)(torch.randn(3, 3))]


@fehiepsi
Contributor Author

fehiepsi commented Jul 11, 2018

@fritzo Currently, the implementation of .log_prob is fast (O(NM^2)), but the implementation of .sample is slow (O(N^3)) because I just compute the scale_tril of W @ W.t() + D, which is an N x N matrix. This has been on my mind for a while, and I had thought there was no solution for it.

Today I checked the Slack distributions room again and found that @stepelu has suggested a very nice reparameterization idea: take two random variables eps1, eps2 and compute W.eps1 + D^(1/2).eps2. It was a "whoa" moment for me. Sampling this way is much faster! I will make the corresponding changes.
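A numpy sketch of that reparameterization (the function name is hypothetical): for covariance W @ W.T + diag(D), the sample loc + W @ eps1 + sqrt(D) * eps2 has exactly the right distribution and costs O(nm) per sample instead of an O(n^3) Cholesky of the full covariance.

```python
import numpy as np

def lowrank_rsample(loc, W, D_diag, sample_shape=(), rng=None):
    """Draw samples from N(loc, W @ W.T + diag(D_diag)) without forming or
    factorizing the n x n covariance: eps1 ~ N(0, I_m), eps2 ~ N(0, I_n),
    sample = loc + W @ eps1 + sqrt(D_diag) * eps2.  Illustrative sketch."""
    if rng is None:
        rng = np.random.default_rng()
    n, m = W.shape
    eps1 = rng.standard_normal(sample_shape + (m,))
    eps2 = rng.standard_normal(sample_shape + (n,))
    return loc + eps1 @ W.T + np.sqrt(D_diag) * eps2
```

The covariance works out because the two noise terms are independent: Cov(W eps1) = W W.T and Cov(sqrt(D) eps2) = D.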

@fritzo @apaszke In addition, I checked TensorFlow's MultivariateNormalDiagPlusLowRank again. It is different from ours: their low rank version is for scale, not covariance_matrix (which is scale @ scale.t()). So I wonder whether we should stick with the names scale_factor and scale_diag or change them to something like components and diagonal. If we keep scale_factor and scale_diag, people might confuse them with the TensorFlow ones.

@fritzo
Collaborator

fritzo commented Jul 11, 2018

@stepelu has suggested a very nice idea to reparameterize

Nice! That sounds very similar to our pyro.contrib.gp.kernels.Coregionalize sampler.

change [param names] to something like components and diagonal

I agree, those look like reasonable names. Alternatively maybe we should use scale_factor and cov_diag? Or cov_factor and cov_diag? Whatever you think is clearest.

Collaborator

@fritzo fritzo left a comment


LGTM after fixing the .rsample() bug and adding a test.

I've only lightly reviewed the linear algebra, but your tests look thorough. Can you please confirm that every bit of math is exercised by tests?

shape = self._extended_shape(sample_shape)
eps_W = self.loc.new_empty(shape[:-1] + (self.cov_factor.size(-1),)).normal_()
eps_D = self.loc.new_empty(shape).normal_()
return self.loc + _batch_mv(self.cov_factor, eps_W) + self.cov_diag * eps_D  # bug flagged in review: the noise should be scaled by cov_diag.sqrt(), not cov_diag


@fehiepsi
Contributor Author

@fritzo I have fixed the bug and reverted the examples to the old version (which failed because of this bug). Thanks a lot!

@fritzo
Collaborator

fritzo commented Jul 14, 2018

@fehiepsi Why was the .rsample bug not caught in the .mean and .variance tests? Are those tests missing?

'LowRankMultivariateNormal(loc={}, cov_factor={}, cov_diag={})'
.format(mean, cov_factor, cov_diag), multivariate=True)

def test_lowrank_multivariate_normal_properties(self):


Collaborator

@fritzo fritzo left a comment


LGTM, thanks for addressing all my review comments.

@fehiepsi
Contributor Author

Thank you, @fritzo!

@apaszke Could you please help me retest this? Travis CI was interrupted for some reason.

@ssnl
Collaborator

ssnl commented Jul 17, 2018

@pytorchbot retest this please

@ezyang
Contributor

ezyang commented Jul 17, 2018

@ssnl "retest this please" doesn't apply to Travis :) Go to the Travis UI and just click the "retry" button (I have done so).

@ssnl
Collaborator

ssnl commented Jul 17, 2018

@ezyang Good to know. Thanks! :)

@fehiepsi
Contributor Author

@ssnl @ezyang we have a failure in the caffe2 build; should we retest, or is this ready to merge?

@ezyang
Contributor

ezyang commented Jul 23, 2018

The caffe2 build failure looks spurious.

Contributor

@facebook-github-bot facebook-github-bot left a comment


@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

jramseyer pushed a commit to jramseyer/pytorch that referenced this pull request Jul 30, 2018
Summary:
This pull request implements the low rank multivariate normal distribution, where the covariance matrix has the form `W @ W.T + D`. Here D is a diagonal matrix and W has shape n x m with m << n. It uses the "matrix determinant lemma" and the "Woodbury matrix identity" to save computational cost.

Along the way, I also revised the MultivariateNormal distribution a bit. Here are the other changes:
+ `torch.trtrs` works with CUDA tensors, so I use it instead of `torch.inverse`.
+ Use `torch.matmul` instead of `torch.bmm` in `_batch_mv`. The former is faster and simpler.
+ Use `torch.diagonal` for `_batch_diag`
+ Reimplement `_batch_mahalanobis` based on `_batch_trtrs_lower`.
+ Use trtrs to compute term2 of KL.
+ `variance` relies on `scale_tril` instead of `covariance_matrix`

TODO:
- [x] Resolve the fail at `_gradcheck_log_prob`
- [x] Add test for KL

cc fritzo stepelu apaszke
Pull Request resolved: pytorch#8635

Differential Revision: D8951893

Pulled By: ezyang

fbshipit-source-id: 488ee3db6071150c33a1fb6624f3cfd9b52760c3
goodlux pushed a commit to goodlux/pytorch that referenced this pull request Aug 15, 2018