
Make iarange contexts reusable #863

Merged
martinjankowiak merged 5 commits into dev from iarange-reuse on Mar 8, 2018

Conversation

@fritzo (Member) commented Mar 8, 2018

This enhances iarange to be a reusable context manager.

Why?

Reusable iarange contexts are needed to support overlapping but non-nested independence. For example, consider an image noise model with independent noise components for each row, column, and pixel:

x_axis = iarange('x_axis', 320, dim=-1)
y_axis = iarange('y_axis', 200, dim=-2)
with x_axis:
    x_noise = sample("x_noise", Normal(mu, sigma).reshape([320]))
with y_axis:
    y_noise = sample("y_noise", Normal(mu, sigma).reshape([200, 1]))
with x_axis, y_axis:
    xy_noise = sample("xy_noise", Normal(mu, sigma).reshape([200, 320]))

This corresponds to the non-nested plate diagram

+----------------+
| x_noise    320 |
|                |
| +--------------------------------+
| | xy_noise     |  y_noise    200 |                
| +--------------------------------+
+----------------+
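The shapes in the example can be checked with a small NumPy sketch (NumPy stands in for torch here purely to illustrate the broadcasting; the shapes are the only point):

```python
import numpy as np

# Shapes from the example above: dim=-1 holds the 320 columns,
# dim=-2 holds the 200 rows.
x_noise = np.random.randn(320)        # inside x_axis only
y_noise = np.random.randn(200, 1)     # inside y_axis only
xy_noise = np.random.randn(200, 320)  # inside both contexts

# The three terms broadcast to a single (200, 320) image of noise.
noise = x_noise + y_noise + xy_noise
assert noise.shape == (200, 320)
```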

This PR also fixes a subsampling bug whereby variables such as y_noise used to be scaled by both x_axis and y_axis subsampling, but henceforth will be scaled by only y_axis subsampling.
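To make the scaling fix concrete, here is a hedged arithmetic sketch using hypothetical subsample sizes of 10 (matching the sizes used later in this thread); the scale each iarange applies is size / subsample_size:

```python
# Hypothetical sizes: x_axis subsamples 10 of 320 columns,
# y_axis subsamples 10 of 200 rows.
x_scale = 320 / 10   # applied by x_axis
y_scale = 200 / 10   # applied by y_axis

# Before this PR, y_noise (inside y_axis only) was wrongly scaled by
# both factors; after the fix it is scaled only by its own context.
buggy_y_scale = x_scale * y_scale    # 640.0
fixed_y_scale = y_scale              # 20.0
assert fixed_y_scale == 20.0
```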

How?

This changes iarange to implement the full context manager interface. We now need a new dim arg to keep track of which tensor dim will be used later on (since this is not known at the time of construction).

This functionality allows us to disallow broadcasting of log_prob within iarange contexts. Henceforth you specify, for each sampled tensor, exactly that tensor's independence dims.

Tested

  • added new shape checks to check_site_shape
  • added reuse examples to tests/infer/test_enum.py
  • added reuse examples to tests/infer/test_valid_models.py
  • added tests for broadcasting errors to tests/infer/test_valid_models.py

Closes #370 (although @martinjankowiak's #780 did most of the work)

Comment thread: pyro/util.py

  # Check for incorrect iarange placement on the right of max_iarange_nesting.
  for actual_size, expected_size in zip_longest(reversed(actual_shape), reversed(expected_shape), fillvalue=1):
-     if expected_size != -1 and actual_size not in (1, expected_size):
+     if expected_size != -1 and expected_size != actual_size:
@fritzo (Member, Author) commented Mar 8, 2018

This one-line change motivates the entire PR:

- if ... and actual_size not in (1, expected_size)  # allow broadcasting
+ if ... and actual_size != expected_size           # disallow broadcasting

With this new stricter requirement, we can more precisely track each sample statement's iaranges, and thereby support correct scaling (from subsampling) and parallel enumeration. But to keep Pyro's current flexibility despite this stricter requirement, we need iaranges to be reusable.
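The behavioral difference between the lenient and strict checks can be sketched as a standalone function (for illustration only; this is not the actual pyro/util.py code, and the function name is made up):

```python
from itertools import zip_longest

def site_shape_ok(actual_shape, expected_shape, allow_broadcast):
    # Compare shapes right-to-left, padding the shorter one with 1s,
    # in the style of the check quoted above.
    for actual, expected in zip_longest(
            reversed(actual_shape), reversed(expected_shape), fillvalue=1):
        if expected == -1:
            continue  # wildcard dim
        if allow_broadcast:
            if actual not in (1, expected):
                return False
        elif actual != expected:
            return False
    return True

# Under the old lenient check, a size-1 site passed inside a size-320
# iarange; the strict check rejects it, forcing the dim to be explicit.
assert site_shape_ok((1,), (320,), allow_broadcast=True)
assert not site_shape_ok((1,), (320,), allow_broadcast=False)
assert site_shape_ok((320,), (320,), allow_broadcast=False)
```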

martinjankowiak previously approved these changes Mar 8, 2018

@martinjankowiak (Collaborator) left a comment

generally lgtm but someone should check some of the context logic and the (minimal) poutine changes

i'd also like to discuss this a bit before merging

@fritzo fritzo requested a review from eb8680 March 8, 2018 17:25
@ngoodman (Collaborator) commented Mar 8, 2018

this seems like a nice extension.

having to be explicit about batch dimension for a context adds more work for the user, but i think it might actually be clearer in the long run than having that be magic?

eb8680 previously approved these changes Mar 8, 2018

@eb8680 (Member) left a comment

The poutine changes seem fine.

As an aside, it seems like pyro/util.py and pyro/__init__.py are getting a little unwieldy. I'll push a quick PR today that splits them up.

Comment thread: pyro/__init__.py (outdated)

self._wrapped = am_i_wrapped()
if self._wrapped:
    dim = 'auto' if self.dim is None else self.dim
    self._scale_poutine = poutine.scale(None, self.size / self.subsample_size)
Member:

You can change scale to ScaleMessenger too

Comment thread: pyro/__init__.py
raise ValueError('Expected dim < 0 to index from the right, actual {}'.format(dim))
self.name = name
self.dim = dim
self.size, self.subsample_size, self.subsample = _subsample(name, size, subsample_size, subsample, use_cuda)
Member:

Is it always the case that we'll want an iarange instance to use the same subsampled indices? I'm pretty sure it is, just want to double-check.

@fritzo (Member, Author) replied Mar 8, 2018

Yes, we want each iarange to fix a single subsample for each trace, so that subsampling agrees among tensors. For example, consider adding noise = x_noise + y_noise + xy_noise below:

def model(image):
    x_axis = iarange('x_axis', 320, dim=-1, subsample_size=10)
    y_axis = iarange('y_axis', 200, dim=-2, subsample_size=10)
    with x_axis:
        x_noise = sample("x_noise", Normal(mu, sigma).reshape([10]))
    with y_axis:
        y_noise = sample("y_noise", Normal(mu, sigma).reshape([10, 1]))
    with x_axis as xs, y_axis as ys:
        xy_noise = sample("xy_noise", Normal(mu, sigma).reshape([10, 10]))
        noise = x_noise + y_noise + xy_noise    # <----- must agree
        pyro.observe("obs", dist.Normal(noise, 0), data=image[ys, xs])

These already broadcast correctly, but we need the subsampling to agree so that the index x agrees between x_noise[x] and xy_noise[:, x].
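The alignment requirement can be sketched in NumPy (a hypothetical illustration, not Pyro's implementation): drawing the subsample indices once per trace and slicing every tensor with those same indices keeps entries aligned across sites.

```python
import numpy as np

rng = np.random.default_rng(0)

# An iarange draws its subsample indices once and reuses them at
# every site in the trace.
x_ind = rng.choice(320, size=10, replace=False)   # fixed for the trace

full_x_noise = rng.standard_normal(320)
full_xy_noise = rng.standard_normal((200, 320))

x_noise = full_x_noise[x_ind]          # shape (10,)
xy_noise = full_xy_noise[:, x_ind]     # shape (200, 10)

# Column j of xy_noise corresponds to entry j of x_noise because both
# were sliced with the same index vector.
assert x_noise.shape == (10,)
assert xy_noise.shape == (200, 10)
assert np.allclose(full_x_noise[x_ind[3]], x_noise[3])
```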


- CondIndepStackFrame = namedtuple("CondIndepStackFrame", ["name", "counter", "vectorized", "size"])
+ class CondIndepStackFrame(namedtuple("CondIndepStackFrame", ["name", "dim", "size", "counter"])):
Member:

Very slick!
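The pattern in the diff above, subclassing a namedtuple, keeps cheap tuple hashing and equality while leaving room for extra methods. A small sketch (the `vectorized` property here is hypothetical, added only to show the idea):

```python
from collections import namedtuple

class CondIndepStackFrame(namedtuple("CondIndepStackFrame",
                                     ["name", "dim", "size", "counter"])):
    @property
    def vectorized(self):
        # Hypothetical helper: a frame with a tensor dim is vectorized.
        return self.dim is not None

frame = CondIndepStackFrame(name="x_axis", dim=-1, size=320, counter=0)
assert frame.vectorized
assert frame == CondIndepStackFrame("x_axis", -1, 320, 0)  # tuple equality
```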

@fritzo fritzo dismissed stale reviews from eb8680 and martinjankowiak via 80b4bb4 March 8, 2018 19:07
@fritzo (Member, Author) commented Mar 8, 2018

> I'll push a quick PR today that splits them up.

Great idea, I agree that pyro/__init__.py is getting unwieldy.

@eb8680 (Member) commented Mar 8, 2018

@martinjankowiak you can merge this when you're ready

@martinjankowiak (Collaborator) commented:
i'd like to discuss some of the consequences of this PR w/ @fritzo before doing so


Successfully merging this pull request may close these issues: Support irange/iarange nested inside iarange.

4 participants