
Implemented Relaxed Distributions #113

Closed

rachtsingh wants to merge 7 commits into probtorch:master from rachtsingh:concrete

Conversation


@rachtsingh rachtsingh commented Jan 29, 2018

Not sure if this is the right name for this distribution (Concrete / GumbelSoftmax are other ideas), but this is what TensorFlow calls it. This PR uses the transforms machinery :)

I had to edit TransformedDistribution's log_prob method to take event_shape into account - this is probably not the right way to do it, but it's a quick first try that makes it work.

  • Implement RelaxedOneHotCategorical
  • Implement RelaxedBernoulli

cc @fritzo
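
For context, a minimal usage sketch of the two distributions being added (argument names follow the torch.distributions interface this PR targets and may differ slightly from the final code):

import torch
from torch.distributions import RelaxedBernoulli, RelaxedOneHotCategorical

temperature = torch.tensor(0.67)  # lower temperature -> samples closer to discrete

rb = RelaxedBernoulli(temperature, probs=torch.tensor([0.1, 0.9]))
b_sample = rb.rsample()  # reparameterized sample in (0, 1)

roc = RelaxedOneHotCategorical(temperature, probs=torch.tensor([0.2, 0.3, 0.5]))
c_sample = roc.rsample()  # lies on the simplex; approaches one-hot as temperature -> 0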


@fritzo fritzo left a comment


Nice factorization into 2 classes!

@fritzo

Could you update the docs and example sample?

@fritzo fritzo Jan 30, 2018

ExpRelaxedCategorical seems ok too, since OneHot is kind of the only way to relax.

@fritzo

  1. Since Tensor and Variable are being merged, this might not be necessary.
  2. Prefer torch.is_tensor(), since torch.Tensor is an alias rather than a base class:
isinstance(torch.FloatTensor([0]), torch.Tensor)
# True
isinstance(torch.DoubleTensor([0]), torch.Tensor)
# False
torch.Tensor is torch.FloatTensor
# True

@rachtsingh (Author)

  1. I'm hoping it's not necessary soon, but it definitely is right now because log_softmax takes Variables and raises when called on a Tensor. Can we leave the shim until the merge?
  2. Definitely, didn't realize this was the case.
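
For reference, a minimal sketch of the preferred check (pre-0.4 semantics assumed, where torch.Tensor aliases torch.FloatTensor):

import torch

# torch.is_tensor() recognizes every tensor type, not just the torch.Tensor alias
torch.is_tensor(torch.FloatTensor([0]))   # True
torch.is_tensor(torch.DoubleTensor([0]))  # True
torch.is_tensor([0])                      # False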

@fritzo

Document temperature parameter
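
A sketch of what such documentation could look like (the wording and signature are illustrative only, not the PR's actual docstring):

from torch.distributions import TransformedDistribution

class RelaxedOneHotCategorical(TransformedDistribution):
    """
    Relaxed (Concrete / Gumbel-Softmax) distribution over one-hot vectors.

    Args:
        temperature (Tensor): relaxation temperature; as it approaches 0,
            samples approach one-hot vectors, and as it becomes very large,
            samples approach equal values at every index.
        probs (Tensor): event probabilities.
        logits (Tensor): unnormalized log probabilities.
    """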

@fritzo

Is __neg__ cheaper or more stable than .mul(-1)?

@rachtsingh (Author)

Can I call it via -(-(uniforms.log()).log())?

@fritzo

Yes, I just assume it's a tiny bit cheaper.
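
A small sketch of the sampling expression under discussion (the tensor name and shape are illustrative):

import torch

uniforms = torch.rand(3, 4).clamp(min=1e-10)  # avoid log(0)
gumbels = -(-(uniforms.log())).log()          # standard Gumbel noise, using unary negation
gumbels_alt = uniforms.log().mul(-1).log().mul(-1)  # equivalent form with .mul(-1)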

@fritzo

I would expect either the transform to be ExpTransform().inv or the base distribution to be LogRelaxedOneHotCategorical. Is the latter a reasonable renaming, analogous to LogNormal?

@rachtsingh (Author)

Right, I agree with LogRelaxedOneHotCategorical as well. However, TF uses the version I have because that's what the paper names it. My thinking is that it's probably more important just to document it and be consistent with the other interfaces, since no user will need to see it.

@fritzo

Sounds good, let's keep it standard.
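
To make the naming discussion concrete, a rough sketch of how the two classes compose (module paths and argument names are assumed to match this PR and may differ in detail):

import torch
from torch.distributions import TransformedDistribution
from torch.distributions.transforms import ExpTransform
from torch.distributions.relaxed_categorical import ExpRelaxedCategorical  # the base class from this PR

temperature = torch.tensor(0.5)
probs = torch.tensor([0.2, 0.3, 0.5])

base = ExpRelaxedCategorical(temperature, probs)         # samples are log-simplex points
relaxed = TransformedDistribution(base, ExpTransform())  # exp() maps them onto the simplex
sample = relaxed.rsample()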

@fritzo fritzo Jan 30, 2018

Some transforms should already know their event shape, in which case this .sum(-1) would sum too much. Also sometimes log_prob is just the number 0 here (e.g. for identity_transform). I think it would be safer to do one of:

  1. add an event_shape or extra_event_shape arg to pointwise transforms
  2. implement a TensorTransform(base_transform, event_dim=1) to wrap ExpTransform
  3. make this hack a tiny bit more robust:
log_prob = ...sum of log_abs_det_jacobian_terms...
base_log_prob = self.base_dist.log_prob(y)  # knows the correct dim
if not isinstance(log_prob, numbers.Number):
    while log_prob.dim() > base_log_prob.dim():
        log_prob = log_prob.sum(-1)
log_prob += base_log_prob

@rachtsingh (Author)

I like solution 1 the best. I think fixing the hack as you did doesn't work because log_prob isn't a number most of the time - for example it can be [sample_shape x batch_shape x ...]. Let me rethink this and get back to you.

@fritzo

👍 We should also make constraints more aware of event dim. @tbrx is already doing this for multivariate normal in #52 and we'll need it e.g. for AffineOperatorTransform which transforms from constraints.real_vector to constraints.real_vector.

@fritzo fritzo Jan 30, 2018

After thinking about this a bit, I think a 4th approach could be minimally invasive:

  • add a static event_dim attribute to each Transform (only about 4 lines of code diff)
  • update the log_prob accumulation loop in TransformedDistribution.log_prob()

I'll sketch this in a PR just to discuss. EDIT: here it is: #116
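
A rough sketch of that 4th approach (the helper and the event_dim attribute are assumptions for illustration, not the actual diff in #116):

from torch.distributions import TransformedDistribution

def _sum_rightmost(value, n):
    # sum out the n rightmost dimensions (no-op when n == 0)
    for _ in range(n):
        value = value.sum(-1)
    return value

class EventAwareTransformedDistribution(TransformedDistribution):
    # assumes each Transform carries a static event_dim attribute
    # (0 for pointwise transforms such as ExpTransform)
    def log_prob(self, value):
        event_dim = len(self.event_shape)
        log_prob = 0.0
        y = value
        for transform in reversed(self.transforms):
            x = transform.inv(y)
            jac = transform.log_abs_det_jacobian(x, y)
            # reduce each jacobian term down to this distribution's event shape
            log_prob = log_prob - _sum_rightmost(jac, event_dim - transform.event_dim)
            y = x
        log_prob = log_prob + _sum_rightmost(
            self.base_dist.log_prob(y),
            event_dim - len(self.base_dist.event_shape))
        return log_prob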

@rachtsingh rachtsingh commented Jan 31, 2018

Ok, made changes based on comments, and added the RelaxedBernoulli distribution. No longer blocked, so ready for another review if necessary @fritzo.

@rachtsingh rachtsingh changed the title from "Implemented RelaxedOneHotCategorical Distribution" to "Implemented Relaxed Distributions" on Feb 1, 2018

In some cases when there are two different versions of cudnn installed,
one under /usr/local/cuda and the other under a virtual env such as conda or
under the main system path /usr/include, the compiler would pick up the
cudnn.h from the virtual env/system path first. This is because cmake
generates C_INCLUDES and CXX_INCLUDES flags with the system include path
first. All this may lead to linking problems as described in issue pytorch#4869.

Fixes pytorch#4869
@rachtsingh (Author)

Added tests that:

  1. Rounding the RelaxedBernoulli distribution gives the corresponding Bernoulli distribution (a rough sketch of this check follows below)
  2. Taking the argmax of the RelaxedOneHotCategorical gives the corresponding Categorical distribution
  3. As the temperature becomes very large, the former consistently gives 0.5 and the latter gives approximately equal values at each index.
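
A rough sketch of the first check (the sample count and tolerance are illustrative, not the values in the actual test):

import torch
from torch.distributions import Bernoulli, RelaxedBernoulli

def check_rounded_relaxed_bernoulli(probs=torch.tensor([0.1, 0.5, 0.9]),
                                    temperature=torch.tensor(0.67),
                                    num_samples=20000):
    relaxed = RelaxedBernoulli(temperature, probs)
    rounded = relaxed.rsample((num_samples,)).round()
    # rounding a relaxed sample should behave like a draw from Bernoulli(probs)
    empirical = rounded.mean(0)
    assert (empirical - Bernoulli(probs).mean).abs().max() < 0.05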

zdevito and others added 4 commits February 5, 2018 10:43
* Remove addValues and use WithInsertPoint

* Use blocks to simplify differentiate

Using @ezyang's suggestion, this change uses a block rather than
staging annotations to represent the reverse pass. This allows us
to reuse the machinery to copy graphs/blocks to extract the
reverse pass concisely.

This also changes the input order of Gradient's df to:
   [output vjps][temporary vjps][captures]

In addition to being simpler to generate in this order, it will also
allow ExecutionPlan to append the captures onto the already-existing
input list of vjps that are given by the autograd, rather than having
to prepend them, which should be slightly cheaper.

* Enforce that input captures are before outputs

This changes the Gradient struct to enforce that input
captures appear before output captures in the capture list,
which makes it easier to use in ExecutionPlan.
* Don't allow scalars where vectors are required in mv, addmv, ger, addr. (pytorch#5003)

* Fix scalar_tensor_test for ger.

* Address review comments.

* Fix merge.
Once Variable and Tensor are merged the existing Variable test would
cause an infinite recursion. Instead, modify the Variables directly
inside a `no_grad()` block.
sspaddmm, mm for sparse tensors to come in another pr; they're a little more involved.

@fritzo fritzo left a comment


LGTM

Comment thread on test/test_distributions.py (outdated)
@fritzo

Nice test!

@fritzo

Just curious, have you considered implementing this as a TransformedDistribution on top of Gumbel? That probably wouldn't be as good as your implementation here, but I'm curious why TransformedDistribution wouldn't work, and how we could improve it to be suitable in this context. For example, could we define this as a TransformedDistribution(Gumbel(...), BoltzmannTransform()) and implement a suitable BoltzmannTransform.log_abs_det_jacobian()?

@rachtsingh (Author)

We can definitely implement this via a TransformedDistribution of Gumbel + BoltzmannTransform; I just didn't see the BoltzmannTransform. I'll make an issue about implementing the log_abs_det_jacobian and then revisit it when that's been solved?
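
A skeleton of what that could look like (every name here is hypothetical; log_abs_det_jacobian is exactly the piece the proposed issue would need to work out):

import torch
import torch.nn.functional as F
from torch.distributions import Gumbel, TransformedDistribution, constraints
from torch.distributions.transforms import Transform

class BoltzmannTransform(Transform):
    domain = constraints.real
    codomain = constraints.simplex
    bijective = False  # the tempered softmax collapses one degree of freedom

    def __init__(self, temperature):
        super(BoltzmannTransform, self).__init__()
        self.temperature = temperature

    def _call(self, x):
        return F.softmax(x / self.temperature, dim=-1)

    def log_abs_det_jacobian(self, x, y):
        raise NotImplementedError("the open question in this thread")

logits = torch.tensor([0.2, 0.3, 0.5]).log()
relaxed = TransformedDistribution(Gumbel(logits, 1.0), BoltzmannTransform(0.5))
sample = relaxed.rsample()  # sampling works; log_prob needs the jacobian above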

@fritzo

I think this will be tricky, but let's move discussion to that issue. Feel free to send this PR upstream before the refactoring.

li-roy and others added 2 commits February 5, 2018 12:28
* add reduce=True argument to MultiLabelMarginLoss

* Fix lint

* Addressed comments

* Remove unneeded syncthreads calls