Optional expand=True kwarg in distribution.enumerate_support#11231

Closed
neerajprad wants to merge 4 commits into pytorch:master from neerajprad:enumerate-expand

@neerajprad (Contributor) commented Sep 4, 2018

This adds an optional expand=True kwarg to the distribution.enumerate_support() method, to get a distribution's support without expanding the values over the distribution's batch_shape.

  • The default expand=True preserves the current behavior, whereas expand=False collapses each batch dimension to size 1.

e.g.

In [47]: d = dist.OneHotCategorical(torch.ones(3, 5) * 0.5)

In [48]: d.batch_shape
Out[48]: torch.Size([3])

In [49]: d.enumerate_support()
Out[49]:
tensor([[[1., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0.]],

        [[0., 1., 0., 0., 0.],
         [0., 1., 0., 0., 0.],
         [0., 1., 0., 0., 0.]],

        [[0., 0., 1., 0., 0.],
         [0., 0., 1., 0., 0.],
         [0., 0., 1., 0., 0.]],

        [[0., 0., 0., 1., 0.],
         [0., 0., 0., 1., 0.],
         [0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.]]])

In [50]: d.enumerate_support().shape
Out[50]: torch.Size([5, 3, 5])

In [51]: d.enumerate_support(expand=False)
Out[51]:
tensor([[[1., 0., 0., 0., 0.]],

        [[0., 1., 0., 0., 0.]],

        [[0., 0., 1., 0., 0.]],

        [[0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 1.]]])

In [52]: d.enumerate_support(expand=False).shape
Out[52]: torch.Size([5, 1, 5])

Motivation:

  • Currently enumerate_support builds tensors of shape support + batch_shape + event_shape, but the values are repeated over batch_shape (adding little in the way of information). When batch_shape is large, this can lead to expensive matrix operations over large tensors (see the example above), and often to OOM issues. We use expand=False in Pyro for message-passing inference, e.g. when enumerating over the state space in a Hidden Markov Model. This creates sparse tensors that capture the Markov dependence and allows for optimized matrix operations over these sparse tensors. expand=True, on the other hand, creates tensors that scale exponentially in size with the length of the Markov chain.
  • We have been using this in our patch of torch.distributions in Pyro. The interface has been stable, and it is already being used in a few Pyro algorithms. We think that this is more broadly applicable and will be of interest to the larger distributions community.
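
As a rough illustration of the first point (a sketch written for this description, not code from the PR's diff), the snippet below reuses the distribution from the example above and checks that the compact support returned with expand=False is just the un-broadcast form of the expanded one. It assumes a PyTorch build that includes the expand kwarg; the variable names full, compact, lp_full and lp_compact are made up for the illustration.

```python
import torch
import torch.distributions as dist

# Same distribution as in the example above: batch_shape (3,), event_shape (5,).
d = dist.OneHotCategorical(torch.ones(3, 5) * 0.5)

full = d.enumerate_support()                 # shape (5, 3, 5), values repeated over the batch
compact = d.enumerate_support(expand=False)  # shape (5, 1, 5), batch dim kept as size 1

# Expanding the compact support by hand recovers the default result.
same_values = torch.equal(compact.expand_as(full), full)

# log_prob broadcasts the size-1 batch dim against batch_shape, so both
# inputs score every (support value, batch element) pair.
lp_full = d.log_prob(full)
lp_compact = d.log_prob(compact)

print(same_values)
print(lp_full.shape, lp_compact.shape)
print(torch.allclose(lp_full, lp_compact))
print(full.numel(), compact.numel())
```

The compact tensor holds a factor of batch_shape.numel() fewer elements here, and downstream code that relies on broadcasting should see the same log-probabilities either way.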

cc. @apaszke, @fritzo, @alicanb

@neerajprad changed the title from "Optional expand=True kward to distribution.enumerate_support" to "Optional expand=True kwarg in distribution.enumerate_support" Sep 4, 2018

@fritzo (Collaborator) left a comment

Thanks for moving this upstream!

@neerajprad (Contributor, Author) commented

For some reason, the Test and Push jobs are all failing with the same error in test_jit, but it seems unrelated to the PR.

@facebook-github-bot (Contributor) left a comment

soumith is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

PenghuiCheng pushed a commit to PenghuiCheng/pytorch that referenced this pull request Sep 11, 2018
…#11231)

Summary:
This adds an optional `expand=True` kwarg to the `distribution.enumerate_support()` method, to get a distribution's support without expanding the values over the distribution's `batch_shape`.
 - The default `expand=True` preserves the current behavior, whereas `expand=False` collapses each batch dimension to size 1.

e.g.
```python
In [47]: d = dist.OneHotCategorical(torch.ones(3, 5) * 0.5)

In [48]: d.batch_shape
Out[48]: torch.Size([3])

In [49]: d.enumerate_support()
Out[49]:
tensor([[[1., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0.]],

        [[0., 1., 0., 0., 0.],
         [0., 1., 0., 0., 0.],
         [0., 1., 0., 0., 0.]],

        [[0., 0., 1., 0., 0.],
         [0., 0., 1., 0., 0.],
         [0., 0., 1., 0., 0.]],

        [[0., 0., 0., 1., 0.],
         [0., 0., 0., 1., 0.],
         [0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.]]])

In [50]: d.enumerate_support().shape
Out[50]: torch.Size([5, 3, 5])

In [51]: d.enumerate_support(expand=False)
Out[51]:
tensor([[[1., 0., 0., 0., 0.]],

        [[0., 1., 0., 0., 0.]],

        [[0., 0., 1., 0., 0.]],

        [[0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 1.]]])

In [52]: d.enumerate_support(expand=False).shape
Out[52]: torch.Size([5, 1, 5])
```

**Motivation:**
 - Currently `enumerate_support` builds tensors of shape `support + batch_shape + event_shape`, but the values are *repeated* over the `batch_shape` (adding little in the way of information). When `batch_shape` is large, this can lead to expensive matrix operations over large tensors (see the example above), and often to OOM issues. We use `expand=False` in Pyro for message-passing inference, e.g. when enumerating over the state space in a Hidden Markov Model. This creates sparse tensors that capture the Markov dependence and allows for optimized matrix operations over these sparse tensors. `expand=True`, on the other hand, creates tensors that scale exponentially in size with the length of the Markov chain.
 - We have been using this in our [patch](https://github.com/uber/pyro/blob/dev/pyro/distributions/torch.py) of `torch.distributions` in Pyro. The interface has been stable, and it is already being used in a few Pyro algorithms. We think that this is more broadly applicable and will be of interest to the larger distributions community.

cc. apaszke, fritzo, alicanb
Pull Request resolved: pytorch#11231

Differential Revision: D9696290

Pulled By: soumith

fbshipit-source-id: c556f8ff374092e8366897ebe3f3b349538d9318