Create a MockNormal class; drop instance .reparameterized #631

@fritzo

Background

Currently .reparameterized is a class-level attribute, but it can be overridden by an instance attribute. This override is used only in testing, where we check that TraceGraph_ELBO performs the same inference on a reparameterized and a non-reparameterized Normal.
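To make the current mechanism concrete, here is a minimal sketch of a class-level flag with an optional per-instance override. The classes and constructor signatures are simplified, hypothetical stand-ins, not Pyro's actual source; they only illustrate how the override works and why subclasses have to forward extra arguments.

```python
class Distribution(object):
    # Class-level default: each subclass declares whether it supports
    # reparameterized sampling.
    reparameterized = False

    def __init__(self, reparameterized=None):
        # Optional instance-level override (hypothetical signature);
        # in practice only tests use it.
        if reparameterized is not None:
            self.reparameterized = reparameterized


class Normal(Distribution):
    reparameterized = True

    def __init__(self, mu, sigma, *args, **kwargs):
        self.mu = mu
        self.sigma = sigma
        # Every subclass must forward *args, **kwargs, otherwise the
        # override never reaches Distribution.__init__.
        super(Normal, self).__init__(*args, **kwargs)


default = Normal(0.0, 1.0)
overridden = Normal(0.0, 1.0, reparameterized=False)
print(Normal.reparameterized, default.reparameterized, overridden.reparameterized)
# True True False
```

The override lives in the instance's __dict__ and shadows the class attribute only for that one object; the class itself still reports True.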

The problem

The instance attribute .reparameterized causes a few problems:

  1. It requires extra plumbing in every Distribution class: each must pass *args, **kwargs through to super(...).__init__().
  2. It is not supported by torch.distributions, to which we are migrating.
  3. It is not respected by RandomPrimitive, which simply ignores the instance attribute and returns the class attribute. (This is necessary because RandomPrimitive constructs instances only lazily, so there may be no instance to examine.)
  4. It complicates the torch.distributions wrapper layer. The wrapper dynamically chooses between the Pyro and torch implementations (torch does not support fancy things like batch_size). Some torch distributions (Gamma, Beta, Dirichlet) are reparameterized whereas their Pyro equivalents are not. Therefore RandomPrimitive cannot statically determine whether a distribution is reparameterized.
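Problem 3 can be seen in a small sketch. LazyPrimitive below is a hypothetical, simplified stand-in for RandomPrimitive (not its real implementation): it wraps a class and builds an instance only when called. For simplicity, calling it returns the constructed instance rather than a sample, so the per-instance override stays visible.

```python
class Distribution(object):
    reparameterized = False

    def __init__(self, reparameterized=None):
        # Instance-level override of the class attribute (hypothetical).
        if reparameterized is not None:
            self.reparameterized = reparameterized


class Normal(Distribution):
    reparameterized = True

    def __init__(self, mu, sigma, **kwargs):
        self.mu, self.sigma = mu, sigma
        super(Normal, self).__init__(**kwargs)


class LazyPrimitive(object):
    """Hypothetical stand-in for RandomPrimitive: wraps a distribution
    class and constructs an instance only when called."""

    def __init__(self, dist_class):
        self.dist_class = dist_class

    @property
    def reparameterized(self):
        # No instance exists when this is queried, so the only thing
        # that can be consulted is the class attribute.
        return self.dist_class.reparameterized

    def __call__(self, *args, **kwargs):
        # The instance, and any instance-level override, appears only here.
        return self.dist_class(*args, **kwargs)


normal = LazyPrimitive(Normal)
d = normal(0.0, 1.0, reparameterized=False)
print(normal.reparameterized, d.reparameterized)  # True False
```

The wrapper and the instance it produces disagree, which is exactly the inconsistency described above.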

Proposed solution

We should eliminate the instance-level .reparameterized attribute. For the few tests that need a non-reparameterized Normal, we should simply implement a NonreparamNormal class (or even a Nonreparameterized class-to-class transform).
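A minimal sketch of the proposal, again using simplified stand-in classes rather than Pyro's real ones. NonreparamNormal flips only the class-level flag, and Nonreparameterized is one possible (hypothetical) way to write the class-to-class transform, using the three-argument form of type(). Whether a real non-reparameterized sampler should also block gradients through its samples is out of scope here; the sketch only covers the flag.

```python
class Distribution(object):
    reparameterized = False


class Normal(Distribution):
    reparameterized = True

    def __init__(self, mu, sigma):
        self.mu = mu
        self.sigma = sigma


class NonreparamNormal(Normal):
    """Test-only Normal that inference treats as non-reparameterized.

    Only the class-level flag changes; __init__ is inherited unchanged,
    so no *args/**kwargs plumbing is needed."""
    reparameterized = False


def Nonreparameterized(dist_class):
    """Hypothetical class -> class transform: return a subclass of
    dist_class whose class-level reparameterized flag is False."""
    return type("Nonreparam" + dist_class.__name__, (dist_class,),
                {"reparameterized": False})


GeneratedNonreparamNormal = Nonreparameterized(Normal)
print(GeneratedNonreparamNormal.__name__, GeneratedNonreparamNormal.reparameterized)
# NonreparamNormal False
```

Because the flag now lives only on the class, a lazy wrapper such as RandomPrimitive can read it without constructing an instance, and the answer always agrees with what the instances report.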
