Implemented Relaxed Distributions #113
Conversation
Could you update the docs and examples?
ExpRelaxedCategorical seems ok too, since OneHot is kind of the only way to relax.
- Since `Tensor` and `Variable` are being merged, this might not be necessary.
- Prefer `torch.is_tensor()`, since `torch.Tensor` is an alias rather than a base class:
  isinstance(torch.FloatTensor([0]), torch.Tensor)
  # True
  isinstance(torch.DoubleTensor([0]), torch.Tensor)
  # False
  torch.Tensor is torch.FloatTensor
  # True
- I'm hoping it's not necessary soon, but it definitely is right now because `log_softmax` takes `Variable`s and raises when called on a `Tensor`. Can we leave the shim until the merge?
- Definitely, didn't realize this was the case.
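For reference, a minimal sketch of what such a shim might look like; the helper name is hypothetical, and this assumes the pre-merge behavior where `F.log_softmax` only accepts `Variable`s:

```python
import torch
from torch.autograd import Variable
import torch.nn.functional as F

def _log_softmax_any(x, dim=-1):
    # Hypothetical shim: wrap plain Tensors in a Variable so log_softmax
    # does not raise, then unwrap the result; Variables pass straight through.
    if torch.is_tensor(x):
        return F.log_softmax(Variable(x), dim=dim).data
    return F.log_softmax(x, dim=dim)
```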
Is __neg__ cheaper or more stable than .mul(-1)?
Can I call it via -(-(uniforms.log()).log())?
Yes, I just assume it's a tiny bit cheaper.
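For context, a small sketch of that double-log sampling step written out with the unary minus; the clamp guarding against log(0) is an extra safeguard, not something taken from the thread:

```python
import torch

# uniforms ~ U(0, 1); the clamp keeps the first log() finite
uniforms = torch.rand(4, 3).clamp(min=1e-10)

neg_log_u = -(uniforms.log())   # -log(u), positive, via unary __neg__
gumbels = -(neg_log_u.log())    # -log(-log(u)): standard Gumbel(0, 1) noise
```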
I would expect either the transform to be ExpTransform().inv or the base distribution to be LogRelaxedOneHotCategorical. Is the latter a reasonable renaming, analogous to LogNormal?
Right, I agree with LogRelaxedOneHotCategorical as well. However, TF uses the version I have because that's what the paper names it. My thinking is that it's probably more important just to document it and be consistent with the other interfaces, since no user will need to see it.
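To make the relationship concrete, a sketch of the composition under discussion, written against the current `torch.distributions` module layout (roughly how `RelaxedOneHotCategorical` is put together there, not necessarily this PR's exact code):

```python
import torch
from torch.distributions import TransformedDistribution
from torch.distributions.relaxed_categorical import ExpRelaxedCategorical
from torch.distributions.transforms import ExpTransform

# The base distribution lives in log-space; ExpTransform maps samples back
# onto the probability simplex, giving the relaxed one-hot distribution.
temperature = torch.tensor(0.5)
probs = torch.tensor([0.2, 0.3, 0.5])
base = ExpRelaxedCategorical(temperature, probs)
relaxed = TransformedDistribution(base, [ExpTransform()])

value = relaxed.sample()        # lies on the simplex (sums to 1)
log_p = relaxed.log_prob(value)
```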
Some transforms should already know their event shape, in which case this `.sum(-1)` would sum too much. Also, sometimes `log_prob` is just the number 0 here (e.g. for `identity_transform`). I think it would be safer to do one of:
- add an `event_shape` or `extra_event_shape` arg to pointwise transforms
- implement a `TensorTransform(base_transform, event_dim=1)` to wrap `ExpTransform`
- make this hack a tiny bit more robust:
  log_prob = ...sum of log_abs_det_jacobian_terms...
  base_log_prob = self.base_dist.log_prob(y)  # knows the correct dim
  if not isinstance(log_prob, numbers.Number):
      while log_prob.dim() > base_log_prob.dim():
          log_prob = log_prob.sum(-1)
  log_prob += base_log_prob
I like solution 1 the best. I think fixing the hack as you did doesn't work because log_prob isn't a number most of the time - for example it can be [sample_shape x batch_shape x ...]. Let me rethink this and get back to you.
After thinking about this a bit, I think a 4th approach could be minimally invasive:
- add a static `event_dim` attribute to each `Transform` (only about 4 lines of code diff)
- update the `log_prob` accumulation loop in `TransformedDistribution.log_prob()`

I'll sketch this in a PR just to discuss. EDIT: here it is: #116
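A rough sketch of what that accumulation loop could look like under this approach; the helper and function names are illustrative, assuming every `Transform` exposes an `event_dim` (0 for pointwise transforms like `ExpTransform`):

```python
import numbers
import torch

def _sum_rightmost(value, dim):
    # Sum out the rightmost `dim` dimensions; numbers and dim == 0 pass through.
    if dim == 0 or isinstance(value, numbers.Number):
        return value
    return value.reshape(value.shape[:-dim] + (-1,)).sum(-1)

def transformed_log_prob(dist, value):
    # dist is assumed to look like a TransformedDistribution: it has
    # .transforms, .base_dist and .event_shape, and each transform
    # knows its own event_dim.
    event_dim = len(dist.event_shape)
    log_prob = 0.0
    y = value
    for transform in reversed(dist.transforms):
        x = transform.inv(y)
        jac = transform.log_abs_det_jacobian(x, y)
        # Only sum the event dims the transform itself did not account for.
        log_prob = log_prob - _sum_rightmost(jac, event_dim - transform.event_dim)
        y = x
    base = dist.base_dist.log_prob(y)
    return log_prob + _sum_rightmost(base, event_dim - len(dist.base_dist.event_shape))
```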
Force-pushed from 9e64063 to ca14e21.
Ok, made changes based on comments, and added the …
Added tests that: …
Just curious, have you considered implementing this as a TransformedDistribution on top of Gumbel? That probably wouldn't be as good as your implementation here, but I'm curious why TransformedDistribution wouldn't work, and how we could improve it to be suitable in this context. For example, could we define this as a TransformedDistribution(Gumbel(...), BoltzmannTransform()) and implement a suitable BoltzmannTransform.log_abs_det_jacobian()?
We can definitely implement this via a TransformedDistribution of Gumbel + BoltzmannTransform; I just didn't see the BoltzmannTransform. I'll make an issue about implementing the log_abs_det_jacobian and then revisit it when that's been solved?
I think this will be tricky, but let's move discussion to that issue. Feel free to send this PR upstream before the refactoring.
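To make the suggestion concrete, a rough sketch of what such a transform might look like; `BoltzmannTransform` is just the name floated above, and `log_abs_det_jacobian` is exactly the tricky part, so it is left unimplemented here:

```python
import torch
from torch.distributions import constraints
from torch.distributions.transforms import Transform

class BoltzmannTransform(Transform):
    """Hypothetical transform taking (logits + Gumbel noise) to the simplex
    via a tempered softmax, as floated in the discussion above."""
    domain = constraints.real_vector
    codomain = constraints.simplex

    def __init__(self, temperature):
        super().__init__()
        self.temperature = temperature

    def _call(self, x):
        # softmax(x / T): approaches a one-hot argmax(x) as T -> 0
        return torch.softmax(x / self.temperature, dim=-1)

    def log_abs_det_jacobian(self, x, y):
        # The open question: softmax maps R^n many-to-one onto the
        # (n-1)-simplex, so the change-of-variables term needs care.
        raise NotImplementedError
```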
Not sure if this is the right name for this distribution (Concrete / GumbelSoftmax are other ideas), but this is what TensorFlow calls it. This PR uses the transforms machinery :)
I had to edit `TransformedDistribution`'s `log_prob` method to take into account `event_shape` - this is probably not the right way to do it, but a quick first try that makes it work. cc @fritzo