Optional expand=True kwarg in distribution.enumerate_support#11231
Closed
neerajprad wants to merge 4 commits into pytorch:master from
Conversation
fritzo (Collaborator) approved these changes on Sep 4, 2018:

Thanks for moving this upstream!
neerajprad (Contributor, Author) commented:

For some reason, the
soumith approved these changes on Sep 7, 2018.

facebook-github-bot (Contributor) left a comment:

soumith is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
PenghuiCheng pushed a commit to PenghuiCheng/pytorch that referenced this pull request on Sep 11, 2018:
…#11231)

Summary:
This adds an optional `expand=True` kwarg to the `distribution.enumerate_support()` method, to get a distribution's support without expanding the values over the distribution's `batch_shape`.

- The default `expand=True` preserves the current behavior, whereas `expand=False` collapses the batch dimensions, e.g.:

```python
In [47]: d = dist.OneHotCategorical(torch.ones(3, 5) * 0.5)

In [48]: d.batch_shape
Out[48]: torch.Size([3])

In [49]: d.enumerate_support()
Out[49]:
tensor([[[1., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0.]],

        [[0., 1., 0., 0., 0.],
         [0., 1., 0., 0., 0.],
         [0., 1., 0., 0., 0.]],

        [[0., 0., 1., 0., 0.],
         [0., 0., 1., 0., 0.],
         [0., 0., 1., 0., 0.]],

        [[0., 0., 0., 1., 0.],
         [0., 0., 0., 1., 0.],
         [0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.]]])

In [50]: d.enumerate_support().shape
Out[50]: torch.Size([5, 3, 5])

In [51]: d.enumerate_support(expand=False)
Out[51]:
tensor([[[1., 0., 0., 0., 0.]],

        [[0., 1., 0., 0., 0.]],

        [[0., 0., 1., 0., 0.]],

        [[0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 1.]]])

In [52]: d.enumerate_support(expand=False).shape
Out[52]: torch.Size([5, 1, 5])
```

**Motivation:**

- Currently `enumerate_support` builds up tensors of size `support + batch_shape + event_shape`, but the values are *repeated* over the `batch_shape`, adding little information. This can lead to expensive matrix operations over large tensors when `batch_shape` is large (see the example above), often leading to OOM issues. We use `expand=False` in Pyro for message passing inference, e.g. when enumerating over the state space in a Hidden Markov Model. This creates sparse tensors that capture the Markov dependence, and allows for the possibility of using optimized matrix operations over these sparse tensors. `expand=True`, on the other hand, will create tensors that scale exponentially in size with the length of the Markov chain.
- We have been using this in our [patch](https://github.com/uber/pyro/blob/dev/pyro/distributions/torch.py) of `torch.distributions` in Pyro. The interface has been stable, and it is already being used in a few Pyro algorithms. We think that this is more broadly applicable and will be of interest to the larger distributions community.

cc. apaszke, fritzo, alicanb

Pull Request resolved: pytorch#11231
Differential Revision: D9696290
Pulled By: soumith
fbshipit-source-id: c556f8ff374092e8366897ebe3f3b349538d9318
This adds an optional `expand=True` kwarg to the `distribution.enumerate_support()` method, to get a distribution's support without expanding the values over the distribution's `batch_shape`. The default `expand=True` preserves the current behavior, whereas `expand=False` collapses the batch dimensions (see the example above).

Motivation:

- Currently `enumerate_support` builds up tensors of size `support + batch_shape + event_shape`, but the values are repeated over the `batch_shape`, adding little information. This can lead to expensive matrix operations over large tensors when `batch_shape` is large (see the example above), often leading to OOM issues. We use `expand=False` in Pyro for message passing inference, e.g. when enumerating over the state space in a Hidden Markov Model. This creates sparse tensors that capture the Markov dependence, and allows for the possibility of using optimized matrix operations over these sparse tensors. `expand=True`, on the other hand, will create tensors that scale exponentially in size with the length of the Markov chain.
- We have been using this in our patch of `torch.distributions` in Pyro. The interface has been stable, and it is already being used in a few Pyro algorithms. We think that this is more broadly applicable and will be of interest to the larger distributions community.

cc. @apaszke, @fritzo, @alicanb
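The shape semantics described above can be sketched without torch. This is a minimal illustration (not the PR's implementation), assuming only that `expand=False` replaces every batch dimension with size 1 so that the result still broadcasts against the fully expanded shape:

```python
def enumerate_support_shape(support_size, batch_shape, event_shape, expand=True):
    """Sketch of the result shape of enumerate_support.

    expand=True : support values are repeated over batch_shape,
                  giving (support_size,) + batch_shape + event_shape.
    expand=False: batch dimensions collapse to 1; the singleton dims
                  still broadcast against batch_shape in downstream ops.
    """
    batch = batch_shape if expand else tuple(1 for _ in batch_shape)
    return (support_size,) + batch + event_shape

# Mirrors the OneHotCategorical example: batch_shape (3,), event_shape (5,)
assert enumerate_support_shape(5, (3,), (5,)) == (5, 3, 5)
assert enumerate_support_shape(5, (3,), (5,), expand=False) == (5, 1, 5)
```

Because the collapsed result keeps one singleton dimension per batch dimension, operations such as `log_prob` on the enumerated values can still broadcast back to the full `support + batch_shape` result without materializing the repeated tensor.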