Bijectors.jl
Zygote AD & `logpdf` for transformed multivariate distributions
I've found that Zygote fails to compute gradients (the gradient comes back as `nothing`) when using the method of `logpdf` defined here.
Here's an MWE:
```julia
using Bijectors
using DistributionsAD
using Flux
using Zygote

d = MvNormal(zeros(2), ones(2))
b = PlanarLayer(2)
flow = transformed(d, b)
x = [0.42 0.24; 0.42 0.24]

"""Use the optimised `logpdf` call."""
loss_(flow, x) = -sum(logpdf(flow, x))

"""Rearrange to use default `logpdf` in `Distributions`."""
function loss2_(flow, x)
    things = map(eachcol(x)) do obs
        logpdf(flow, obs)
    end
    return -sum(things)
end

@show loss_(flow, x)
@show loss2_(flow, x)
println()

gs = gradient(() -> loss_(flow, x), Flux.params(b))
@show gs.grads[Flux.params(b)[1]]

gs = gradient(() -> loss2_(flow, x), Flux.params(b))
@show gs.grads[Flux.params(b)[1]];
```
With output:
```
loss_(flow, x) = 3.089176357252711
loss2_(flow, x) = 3.089176357252711

gs.grads[(Flux.params(b))[1]] = nothing
gs.grads[(Flux.params(b))[1]] = [-2.603210756288831, -4.3264084139896095]
```
Tested on Bijectors v0.10.0.
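To double-check that the non-`nothing` gradient from `loss2_` is actually the right one, it can be compared against central finite differences. This is a minimal sketch, assuming the same package versions as the MWE and that `first(Flux.params(b))` is a plain `Vector{Float64}` shared between `b` and `flow`, so perturbing it in place changes the loss. The step size `h` and the tolerance are arbitrary choices, not values from the original report.

```julia
using Bijectors, DistributionsAD, Flux, Zygote

d = MvNormal(zeros(2), ones(2))
b = PlanarLayer(2)
flow = transformed(d, b)          # `flow` holds a reference to `b`
x = [0.42 0.24; 0.42 0.24]

# Per-column loss, equivalent to `loss2_` in the MWE.
loss2_(flow, x) = -sum(map(obs -> logpdf(flow, obs), eachcol(x)))

ps = Flux.params(b)
p = first(ps)                     # first trainable array of the layer (as in the MWE)

# Reverse-mode gradient via implicit parameters.
g_ad = gradient(() -> loss2_(flow, x), ps)[p]

# Central finite differences, perturbing `p` in place and restoring it afterwards.
h = 1e-6
g_fd = zero(p)
for i in eachindex(p)
    orig = p[i]
    p[i] = orig + h
    lp = loss2_(flow, x)
    p[i] = orig - h
    lm = loss2_(flow, x)
    p[i] = orig                   # restore the original parameter value
    g_fd[i] = (lp - lm) / (2h)
end

@show g_ad g_fd
@show isapprox(g_ad, g_fd; rtol = 1e-4)
```

If the two agree, the `loss2_` gradient can serve as a reference when checking any fix to the optimised `logpdf` path.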
I'm not sure, but perhaps the optimised `logpdf` method (or some of the methods it calls) needs additional ChainRules support?
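One way to narrow this down might be to differentiate with respect to the layer itself (Zygote's explicit mode) instead of going through `Flux.params`. This is a hedged diagnostic sketch, assuming the same setup as the MWE. The closure-over-`layer` form is just one way to write it, and the exact shape of the returned gradient (a `NamedTuple` over the layer's fields, or `nothing`) depends on the Zygote version.

```julia
using Bijectors, DistributionsAD, Flux, Zygote

d = MvNormal(zeros(2), ones(2))
b = PlanarLayer(2)
x = [0.42 0.24; 0.42 0.24]

# Same two losses as in the MWE.
loss_(flow, x) = -sum(logpdf(flow, x))
loss2_(flow, x) = -sum(map(obs -> logpdf(flow, obs), eachcol(x)))

# Explicit mode: the layer struct is the argument being differentiated,
# so no implicit-parameter bookkeeping (`Flux.params`) is involved.
(g1,) = Zygote.gradient(layer -> loss_(transformed(d, layer), x), b)
(g2,) = Zygote.gradient(layer -> loss2_(transformed(d, layer), x), b)

@show g1
@show g2
```

If `g1` comes back as a structured gradient while the implicit-parameter version in the MWE gives `nothing`, that could point to parameter tracking being bypassed (for example by a custom adjoint defined on the whole flow) rather than a missing derivative rule. If `g1` is also `nothing`, the gradient path itself is probably being cut somewhere inside the optimised method.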