
Master Task - Update broadcasting and batching semantics for Pyro Distributions #638

@neerajprad


We have made some major changes to the torch.distributions API, and would like to make these available in Pyro in their full generality. However, the current Pyro shape semantics are incompatible with this new API.

Some of these incompatibilities are as follows:

  • The batch and sample dimensions in Pyro are conflated, which makes it confusing for the user to determine what shape the distribution parameters should have. For example, it is common to supply identical (repeated) distribution parameters only to impose a shape on the sample.
  • Batching (i.e. sample dimensions) is restricted to a single dimension; arbitrarily shaped samples are not supported.
  • In the absence of scalar support, batch_log_pdf upcasts the resulting score by appending an extra rightmost dimension of shape (1,). In PyTorch, we have tried to localize this behavior so that a result is upcast only when a dimension reduction would otherwise produce a 0-dimensional tensor. This makes the eventual migration to torch.Scalar simpler.
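To make the batch/sample distinction concrete, the shape bookkeeping can be modeled in plain Python. This is a minimal sketch, assuming the torch.distributions convention that `sample(sample_shape)` returns `sample_shape + batch_shape + event_shape`, that `log_prob` returns `sample_shape + batch_shape`, and that `batch_shape` comes from broadcasting the parameter shapes. All helper names are hypothetical and are not part of the Pyro or PyTorch API; the last two helpers only model the two upcasting rules described in the list above.

```python
from itertools import zip_longest


def broadcast_shapes(*shapes):
    """Broadcast shape tuples right-aligned, following NumPy/PyTorch rules."""
    result = []
    # Walk dimensions from the right; missing leading dims count as size 1.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"shapes {shapes} are not broadcastable")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


def sample_result_shape(sample_shape, batch_shape, event_shape=()):
    """Shape of a draw: sample dims first, then batch dims, then event dims."""
    return tuple(sample_shape) + tuple(batch_shape) + tuple(event_shape)


def log_prob_result_shape(sample_shape, batch_shape):
    """log_prob reduces over event dims, leaving sample + batch dims."""
    return tuple(sample_shape) + tuple(batch_shape)


def legacy_pyro_log_pdf_shape(shape):
    """Hypothetical model of old Pyro batch_log_pdf: always append (1,)."""
    return tuple(shape) + (1,)


def localized_upcast_shape(shape):
    """Hypothetical model of the localized PyTorch rule: pad only a 0-dim result."""
    return tuple(shape) if shape else (1,)


# Parameters of shape (3,) and (2, 1) broadcast to batch_shape (2, 3).
batch_shape = broadcast_shapes((3,), (2, 1))      # -> (2, 3)
# sample_shape (5,) stacks independent draws on the left of the batch dims.
print(sample_result_shape((5,), batch_shape))     # -> (5, 2, 3)
print(log_prob_result_shape((5,), batch_shape))   # -> (5, 2, 3)
# Old Pyro rule as described above: every score gains a trailing (1,) dim.
print(legacy_pyro_log_pdf_shape((4,)))            # -> (4, 1)
# Localized rule: only a fully reduced (scalar) score is padded to (1,).
print(localized_upcast_shape(()))                 # -> (1,)
print(localized_upcast_shape((5, 2, 3)))          # -> (5, 2, 3)
```

Keeping `sample_shape` separate from `batch_shape` is what lets a user draw many samples from a single set of parameters, instead of repeating identical parameters just to impose a shape on the sample.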

By the next PyTorch release, we would like this richer, updated API to be available in Pyro. @fritzo and I decided that the transition will be simpler if Pyro's dev branch maintains compatibility only with PyTorch master, rather than with the release version as it does now (refer to Phase 1.5).

The transition plan is as follows:

Phase 1 - Changes in PyTorch:

Phase 1.5 - Infra / internal changes to help with migration:

This can be done in parallel with Phase 1.

  • Create a lightweight torch wheel that is periodically rebuilt from PyTorch master, so that CI can test Pyro dev against PyTorch master. Refer to pytorch#4178 and "Have Pyro dev branch track PyTorch's master branch" #670.
  • Make tutorials testable: tutorials should be written in a format that allows API-related changes to be tested easily and quickly.
  • Add versioning to tutorials: since some of our changes will not be backwards compatible, we should version our tutorials so that users can easily find the version that matches their installed version of Pyro. (Ref. Add versioning to tutorials #639)

Phase 2 - Internal change to Distributions API:
