Skip to content

working towards nested iarange in tracegraph_elbo#780

Merged
fritzo merged 30 commits intodevfrom
nested-iarange
Feb 27, 2018
Merged

working towards nested iarange in tracegraph_elbo#780
fritzo merged 30 commits intodevfrom
nested-iarange

Conversation

@martinjankowiak
Copy link
Copy Markdown
Collaborator

@martinjankowiak martinjankowiak commented Feb 15, 2018

add nested iarange support to tracegraph_elbo. clean-up to follow in #816

Copy link
Copy Markdown
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice tests in test_compute_downstream_costs.py!


import warnings

from operator import itemgetter
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this should be in the builtins section above. I recommend setting up isort integration in vim.

Comment thread pyro/infer/tracegraph_elbo.py Outdated
topo_sort_guide_nodes = list(reversed(list(networkx.topological_sort(guide_trace))))
topo_sort_guide_nodes = [x for x in topo_sort_guide_nodes
if guide_trace.nodes[x]["type"] == "sample"]
ordered_guide_nodes_dict = dict(list(zip(topo_sort_guide_nodes, list(range(len(topo_sort_guide_nodes))))))
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe simplify to {n: i for i, n in enumerate(topo_sorted_guide_nodes)}?

Comment thread pyro/infer/tracegraph_elbo.py Outdated
if model_trace.nodes[x]["type"] == "sample"]

for node in all_nodes:
if printhappy:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove before merging

Comment thread pyro/infer/tracegraph_elbo.py Outdated
"Falling back to higher-variance gradient estimator. "
"Try to avoid these issues in your model and guide:\n{}".format("\n".join(
guide_vec_md_info["warnings"] | model_vec_md_info["warnings"])))
# guide_vec_md_condition = guide_vec_md_info['rao-blackwellization-condition']
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove before merging

Comment thread pyro/infer/util.py Outdated
for k in range(1, 1 + source.dim()):
if source.size(-k) > target.size(-k):
source = source.sum(-k, keepdim=True)
# XXX make this more efficient
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove comment, this is fine

Comment thread pyro/poutine/trace_poutine.py Outdated

# construct data structure consumed by tracegraph_kl_qp
if vectorized_map_data_info['rao-blackwellization-condition']:
if True or vectorized_map_data_info['rao-blackwellization-condition']:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove before submitting?

with pyro.iarange("iarange_triple1", 6) as ind_triple1:
with pyro.iarange("iarange_triple2", 7) as ind_triple2:
with pyro.iarange("iarange_triple3", 9) as ind_triple3:
pyro.sample("z0", dist.Bernoulli(p2).reshape(sample_shape=[len(ind_triple1),
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be

sample_shape=[len(ind_triple3), len(ind_triple2), len(ind_triple1)]

This will be checked more strictly in my #806

Comment thread tests/infer/test_valid_models.py Outdated
assert_error(model, guide, trace_graph=trace_graph)


@pytest.mark.skip('TODO FIX ME')
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be xfail rather than skip

Comment thread tests/infer/test_valid_models.py Outdated


def test_nested_iarange_iarange_warning():
@pytest.mark.xfail(reason="https://github.com/uber/pyro/issues/370")
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this still xfail? It passes on my branch in #806 (with some changes) Ditto for all tests marked xfail below.

@martinjankowiak
Copy link
Copy Markdown
Collaborator Author

@fritzo ready for merge?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants