Bijectors.jl

Zygote AD & `logpdf` for transformed multivariate

Status: Open. tpgillam opened this issue 4 years ago (0 comments).

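I've found that Zygote fails to compute gradients (it returns `nothing`) when using the method of `logpdf` defined here, which evaluates a transformed multivariate distribution on a whole matrix of observations at once.

Here's an MWE: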

```julia
using Bijectors
using DistributionsAD
using Flux
using Zygote

# Standard 2-d Gaussian base distribution pushed through a single planar flow layer.
d = MvNormal(zeros(2), ones(2))
b = PlanarLayer(2)
flow = transformed(d, b)

# Each column of `x` is one 2-dimensional observation.
x = [0.42 0.24; 0.42 0.24]

"""Use the optimised `logpdf` call."""
loss_(flow, x) = -sum(logpdf(flow, x))

"""Rearrange to use default `logpdf` in `Distributions`."""
function loss2_(flow, x)
    things = map(eachcol(x)) do obs
        logpdf(flow, obs)
    end
    return -sum(things)
end

@show loss_(flow, x)
@show loss2_(flow, x)

println()

# Gradient w.r.t. the first parameter array of `b`, via the batched `logpdf`.
gs = gradient(() -> loss_(flow, x), Flux.params(b))
@show gs.grads[Flux.params(b)[1]]

# Same gradient, via the per-column `logpdf`.
gs = gradient(() -> loss2_(flow, x), Flux.params(b))
@show gs.grads[Flux.params(b)[1]];
```

With output:

```
loss_(flow, x) = 3.089176357252711
loss2_(flow, x) = 3.089176357252711

gs.grads[(Flux.params(b))[1]] = nothing
gs.grads[(Flux.params(b))[1]] = [-2.603210756288831, -4.3264084139896095]
```

Tested on Bijectors v0.10.0.
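Since both losses agree but only one gradient comes back, a central finite-difference estimate is one way to confirm that the vector from `loss2_` is the correct gradient. The helper `fd_gradient` below is a hypothetical sketch, not part of Bijectors, Flux or Zygote: it perturbs one entry of a parameter array at a time and re-evaluates a zero-argument loss closure, assuming the array is mutable and is the one the loss actually reads.

```julia
# Hypothetical helper (not part of Bijectors, Flux or Zygote): central
# finite-difference gradient of the zero-argument closure `f` with respect to
# the array `θ`, perturbing one entry at a time in place and restoring it.
function fd_gradient(f, θ::AbstractArray{<:AbstractFloat}; h = 1e-6)
    g = similar(θ)
    for i in eachindex(θ)
        orig = θ[i]
        θ[i] = orig + h
        f_plus = f()
        θ[i] = orig - h
        f_minus = f()
        θ[i] = orig                      # restore the original value
        g[i] = (f_plus - f_minus) / (2h)
    end
    return g
end

# Toy check: the gradient of sum(abs2, θ) is 2θ.
θ = [0.5, -1.0, 2.0]
g = fd_gradient(() -> sum(abs2, θ), θ)
println(g)
```

For the MWE, the analogous call would be `fd_gradient(() -> loss2_(flow, x), Flux.params(b)[1])`, to be compared against the vector printed for `loss2_`; this assumes `Flux.params(b)[1]` is the same array that `flow` reads internally.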

I'm not sure, but perhaps the optimised `logpdf` method (or some of the methods it calls) needs additional ChainRules support?
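For context, "ChainRules support" usually means defining an `rrule` with ChainRulesCore.jl, which Zygote uses instead of differentiating through the function body. The issue does not identify which method would need such a rule, so the sketch below uses a hypothetical toy function, `logsumsq`, purely to show the shape of a rule; it is not a fix for this bug.

```julia
using ChainRulesCore

# Hypothetical toy function standing in for whichever method lacks a rule.
logsumsq(x::AbstractVector{<:Real}) = log(sum(abs2, x))

# Reverse-mode rule: returns the primal value and a pullback that maps the
# output cotangent ȳ to cotangents for (the function itself, x).
function ChainRulesCore.rrule(::typeof(logsumsq), x::AbstractVector{<:Real})
    s = sum(abs2, x)
    y = log(s)
    function logsumsq_pullback(ȳ)
        # d/dx_i log(sum_j x_j^2) = 2 x_i / sum_j x_j^2
        x̄ = unthunk(ȳ) .* (2 .* x ./ s)
        return NoTangent(), x̄
    end
    return y, logsumsq_pullback
end

# Call the rule directly, as an AD system such as Zygote would internally.
x = [1.0, 2.0, 2.0]
y, back = rrule(logsumsq, x)
_, x̄ = back(1.0)
println((y, x̄))
```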

tpgillam, Apr 13 '22 08:04