## Background
Currently `.reparameterized` is a class-level attribute but can be overridden by instance attributes. This override is used only in testing, where we check that `TraceGraph_ELBO` performs the same inference on a reparameterized and a non-reparameterized `Normal`.
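As a minimal sketch of the pattern described above (the `Distribution` and `Normal` classes here are simplified stand-ins, not the actual Pyro implementations), an instance attribute shadows the class-level flag only on the instance that sets it:

```python
class Distribution:
    reparameterized = False  # class-level default


class Normal(Distribution):
    reparameterized = True  # class-level: reparameterized by default

    def __init__(self, reparameterized=None):
        # Instance-level override, used only in testing.
        if reparameterized is not None:
            self.reparameterized = reparameterized


d = Normal()
print(d.reparameterized)      # True: attribute lookup falls back to the class
d_non = Normal(reparameterized=False)
print(d_non.reparameterized)  # False: instance attribute shadows the class
print(Normal.reparameterized) # True: the class attribute itself is unchanged
```

Note that every subclass taking part in this pattern must thread the extra keyword argument through its constructor, which is the plumbing burden described below.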
## The problem
The instance attribute `.reparameterized` causes a few problems:
- It requires extra plumbing for all `Distribution` classes (they must pass `*args, **kwargs` to `super(...).__init__()`).
- It is not supported by `torch.distributions`, to which we are migrating.
- It is not respected by `RandomPrimitive`, which simply ignores the instance attribute and returns the class attribute. (This is necessary because `RandomPrimitive` constructs instances only lazily, so there may be no instance to examine.)
- It complicates the `torch.distributions` wrapper layer. The wrapper dynamically chooses between the Pyro and torch implementations (torch does not support fancy features like `batch_size`). Some torch distributions are reparameterized whereas their Pyro equivalents are not (`Gamma`, `Beta`, `Dirichlet`), so `RandomPrimitive` cannot statically determine whether a distribution is reparameterized.
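To make the `RandomPrimitive` problem concrete, here is a hypothetical sketch (the class bodies are illustrative, not Pyro's actual code) of why a lazily-constructing wrapper can only consult the class attribute, silently dropping any instance-level override:

```python
class Normal:
    reparameterized = True

    def __init__(self, loc, scale, reparameterized=None):
        self.loc, self.scale = loc, scale
        if reparameterized is not None:
            self.reparameterized = reparameterized


class RandomPrimitive:
    def __init__(self, dist_class):
        self.dist_class = dist_class
        # No instance exists yet at wrapping time, so only the
        # class attribute is visible here.
        self.reparameterized = dist_class.reparameterized

    def __call__(self, *args, **kwargs):
        # The instance is constructed only when the primitive is called.
        return self.dist_class(*args, **kwargs)


normal = RandomPrimitive(Normal)
print(normal.reparameterized)  # True, regardless of any later override
d = normal(0.0, 1.0, reparameterized=False)
print(d.reparameterized)       # False: the override exists on the instance,
                               # but the primitive above never sees it
```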
## Proposed solution
We should eliminate the instance-level `.reparameterized` property. For the few cases where we need a non-reparameterized `Normal` in testing, we should simply implement a `NonreparamNormal` class (or even a `Nonreparameterized` class -> class transform).
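The class -> class transform could be sketched as follows (the `nonreparameterized` helper is a hypothetical name, not an existing Pyro API): it derives a subclass whose `.reparameterized` is `False` at the class level, so wrappers like `RandomPrimitive` can read the flag statically without any instance-level plumbing.

```python
def nonreparameterized(dist_class):
    """Return a subclass of dist_class with reparameterized = False
    set at the class level (a class -> class transform)."""
    return type(
        "Nonreparam" + dist_class.__name__,  # e.g. NonreparamNormal
        (dist_class,),
        {"reparameterized": False},
    )


class Normal:
    reparameterized = True


NonreparamNormal = nonreparameterized(Normal)
print(NonreparamNormal.reparameterized)  # False, visible on the class itself
print(Normal.reparameterized)            # True, the original class is untouched
```

Because the flag lives on the class rather than the instance, no constructor needs to accept or forward a `reparameterized` keyword, and the test-only variant stays out of the main class hierarchy.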