Make iarange contexts reusable #863
Conversation
```diff
 # Check for incorrect iarange placement on the right of max_iarange_nesting.
 for actual_size, expected_size in zip_longest(reversed(actual_shape), reversed(expected_shape), fillvalue=1):
-    if expected_size != -1 and actual_size not in (1, expected_size):
+    if expected_size != -1 and expected_size != actual_size:
```
This one-line change motivates the entire PR:
```diff
- if ... and actual_size not in (1, expected_size)  # allow broadcasting
+ if ... and actual_size != expected_size  # disallow broadcasting
```

With this new stricter requirement, we can more precisely track each sample statement's `iarange`s, and thereby support correct scaling (from subsampling) and parallel enumeration. But to keep Pyro's current flexibility despite this stricter requirement, we need `iarange`s to be reusable.
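To illustrate, here is a standalone sketch of the stricter check (a hypothetical `check_shape` helper, not the exact Pyro code; it keeps the `-1` wildcard convention from the diff above):

```python
from itertools import zip_longest

def check_shape(actual_shape, expected_shape):
    # Compare sizes from the rightmost dim; -1 in expected_shape means "any size".
    for actual_size, expected_size in zip_longest(
            reversed(actual_shape), reversed(expected_shape), fillvalue=1):
        if expected_size != -1 and expected_size != actual_size:
            raise ValueError('shape mismatch: {} vs {}'.format(actual_shape, expected_shape))

check_shape((10, 10), (10, 10))  # ok: exact match
check_shape((2, 10), (-1, 10))   # ok: -1 matches any size
try:
    check_shape((1, 10), (10, 10))
except ValueError as e:
    print(e)  # size-1 broadcasting is now disallowed
```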
martinjankowiak left a comment
generally lgtm but someone should check some of the context logic and the (minimal) poutine changes
i'd also like to discuss this a bit before merging
this seems like a nice extension. having to be explicit about batch dimension for a context adds more work for the user, but i think it might actually be clearer in the long run than having that be magic?
eb8680 left a comment
The poutine changes seem fine.
As an aside, it seems like `pyro/util.py` and `pyro/__init__.py` are getting a little unwieldy. I'll push a quick PR today that splits them up.
```python
self._wrapped = am_i_wrapped()
if self._wrapped:
    dim = 'auto' if self.dim is None else self.dim
    self._scale_poutine = poutine.scale(None, self.size / self.subsample_size)
```
You can change `scale` to `ScaleMessenger` too.
```python
raise ValueError('Expected dim < 0 to index from the right, actual {}'.format(dim))
self.name = name
self.dim = dim
self.size, self.subsample_size, self.subsample = _subsample(name, size, subsample_size, subsample, use_cuda)
```
Is it always the case that we'll want an `iarange` instance to use the same subsampled indices? I'm pretty sure it is, just want to double-check.
Yes, we want each `iarange` to fix a single subsample for each trace, so that subsampling agrees among tensors. For example, consider adding `noise = x_noise + y_noise + xy_noise` below:
```python
def model(image):
    x_axis = iarange('x_axis', 320, dim=-1, subsample_size=10)
    y_axis = iarange('y_axis', 200, dim=-2, subsample_size=10)
    with x_axis:
        x_noise = sample("x_noise", Normal(mu, sigma).reshape([10]))
    with y_axis:
        y_noise = sample("y_noise", Normal(mu, sigma).reshape([10, 1]))
    with x_axis as xs, y_axis as ys:
        xy_noise = sample("xy_noise", Normal(mu, sigma).reshape([10, 10]))
        noise = x_noise + y_noise + xy_noise  # <----- must agree
        pyro.observe("obs", dist.Normal(noise, 0), data=image[ys, xs])
```

These already broadcast correctly, but we need the subsampling to agree so that the x agrees between `x_noise[x]` and `xy_noise[:, x]`.
```diff
-CondIndepStackFrame = namedtuple("CondIndepStackFrame", ["name", "counter", "vectorized", "size"])
+class CondIndepStackFrame(namedtuple("CondIndepStackFrame", ["name", "dim", "size", "counter"])):
```
Great idea, I agree that …
@martinjankowiak you can merge this when you're ready
i'd like to discuss some of the consequences of this PR w/ @fritzo before doing so |
This enhances `iarange` to be a reusable context manager.

Why?
Reusable `iarange` contexts are needed to support overlapping but non-nested independence. For example, consider an image noise model with independent noise components for each row, column, and pixel (sketched below); this corresponds to a non-nested plate diagram.
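A minimal sketch of such a model, mirroring the example from the review thread above (PR-era API; `mu` and `sigma` are assumed to be defined elsewhere):

```python
def model(image):
    x_axis = iarange('x_axis', 320, dim=-1, subsample_size=10)  # columns
    y_axis = iarange('y_axis', 200, dim=-2, subsample_size=10)  # rows
    with x_axis:
        x_noise = sample("x_noise", Normal(mu, sigma).reshape([10]))
    with y_axis:
        y_noise = sample("y_noise", Normal(mu, sigma).reshape([10, 1]))
    with x_axis as xs, y_axis as ys:
        # Pixel noise lives in both plates at once: overlapping, not nested.
        xy_noise = sample("xy_noise", Normal(mu, sigma).reshape([10, 10]))
```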
This PR also fixes a subsampling bug whereby variables such as `y_noise` used to be scaled by both `x_axis` and `y_axis` subsampling, but henceforth will be scaled by only `y_axis` subsampling.
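Concretely, with the sizes from the example above (worked arithmetic only, not Pyro code):

```python
# y_noise lives only inside y_axis, so its log_prob scale should be
scale_y = 200 / 10                      # = 20.0, from y_axis subsampling alone
# The old behavior wrongly multiplied in the x_axis factor as well:
buggy_scale = (200 / 10) * (320 / 10)   # = 640.0
```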
How?

This changes `iarange` to implement the full context manager interface. We now need a new `dim` arg to keep track of which tensor dim will be used later on (since this is not known at time of construction).

This functionality allows us to disallow broadcasting of log_prob within `iarange` contexts. Henceforth you can specify for each sampled tensor exactly the independence dims of that tensor.
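For example, under the stricter rule (a sketch with the same assumed `mu` and `sigma` as above):

```python
with iarange('data', 100, dim=-1):
    # The rightmost batch dim must now be exactly 100; a broadcastable
    # size-1 dim is rejected by the stricter check_site_shape.
    z = sample('z', Normal(mu, sigma).reshape([100]))
```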
Tested

- `check_site_shape`
- `tests/infer/test_enum.py`
- `tests/infer/test_valid_models.py`

Closes #370 (although @martinjankowiak's #780 did most of the work)