Closed
Labels: enhancement (New feature or request)

Description
As discussed in #32, we have now implemented linear extrapolation outside the bounds of the Bernstein polynomial.
The feature becomes active if `linear=True`.
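The idea behind the feature can be sketched in a few lines. This is a generic, simplified illustration of linear extrapolation at a domain boundary (first-order continuation from the clamped point), not zuko's actual implementation; the function names and the example `f` are hypothetical:

```python
import torch

def linear_extrapolate(f, df, x, low=-10.0, high=10.0):
    # Hypothetical sketch: evaluate f inside [low, high] and continue
    # linearly outside, using the boundary value and the boundary slope.
    xc = x.clamp(low, high)           # clamp the argument into the domain
    value = f(xc)                     # function value at the clamped point
    slope = df(xc)                    # derivative at the boundary
    return value + slope * (x - xc)   # first-order (linear) continuation

# Toy example with f(x) = x**2 on [-10, 10]
f = lambda x: x ** 2
df = lambda x: 2 * x
x = torch.tensor([-12.0, 0.0, 12.0])
print(linear_extrapolate(f, df, x))  # tensor([140., 0., 140.])
```

Inside the domain the clamp is a no-op and the slope term vanishes, so the transform is unchanged there; outside, the log-determinant of the Jacobian reduces to the log of the constant boundary slope.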
Here is a simple Python script to visualize the resulting effect:
```python
# %% Imports
import torch
from matplotlib import pyplot as plt

from zuko.transforms import BernsteinTransform

# %% Globals
M = 10
batch_size = 10
torch.manual_seed(1)
theta = torch.rand(size=(M,)) * 500  # creates a random parameter vector

# %% Sigmoid
bpoly = BernsteinTransform(theta=theta, linear=False)
x = torch.linspace(-15, 15, 2000)
y = bpoly(x)
adj = bpoly.log_abs_det_jacobian(x, y).detach()
J = torch.diag(torch.autograd.functional.jacobian(bpoly, x)).abs().log()

# %% Plot
fig, axs = plt.subplots(2, sharex=True)
fig.suptitle("Bernstein polynomial with Sigmoid")
axs[0].plot(x, y, label="Bernstein polynomial")
axs[0].scatter(
    torch.linspace(-10, 10, bpoly.order + 1),
    bpoly.theta.numpy().flatten(),
    label="Bernstein coefficients",
)
axs[0].legend()
axs[1].plot(x, adj, label="ladj")
# axs[1].scatter(
#     torch.linspace(-10, 10, bpoly.order),
#     bpoly.dtheta.numpy().flatten(),
#     label="dtheta",
# )
axs[1].plot(x, J, label="ladj (autograd)")
axs[1].legend()
fig.tight_layout()
fig.savefig("sigmoid.png")

# %% Extrapolation
bpoly = BernsteinTransform(theta=theta, linear=True)
x = torch.linspace(-15, 15, 2000)
y = bpoly(x)
adj = bpoly.log_abs_det_jacobian(x, y).detach()
J = torch.diag(torch.autograd.functional.jacobian(bpoly, x)).abs().log()

# %% Plot
fig, axs = plt.subplots(2, sharex=True)
fig.suptitle("Bernstein polynomial with linear extrapolation")
axs[0].plot(x, y, label="Bernstein polynomial")
axs[0].scatter(
    torch.linspace(-10, 10, bpoly.order + 1),
    bpoly.theta.numpy().flatten(),
    label="Bernstein coefficients",
)
axs[0].legend()
axs[1].plot(x, adj, label="ladj")
# axs[1].scatter(
#     torch.linspace(-10, 10, bpoly.order),
#     bpoly.dtheta.numpy().flatten(),
#     label="dtheta",
# )
axs[1].plot(x, J, label="ladj (autograd)")
axs[1].legend()
fig.tight_layout()
fig.savefig("linear.png")
```

This makes the BPF implementation more robust to data lying outside the domain of the Bernstein polynomial, without the need for the non-linear sigmoid function.
@oduerr Do you have anything else to add?
Implementation
The implementation can be found in the bpf_extrapolation branch of my fork.
My changes specifically include:
- optional linear extrapolation in the call method,
- a custom implementation of `log_abs_det_jacobian`, since the gradient does not seem to pass through the `torch.where` statement, and to improve numerical stability in the sigmoidal case.
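For context on the second point: `torch.where` evaluates both branches, so a NaN or infinity produced by the branch that is *not* selected can still poison the gradient during backpropagation. This is a well-known PyTorch pitfall and a common reason to hand-write the Jacobian instead. A minimal, generic reproduction (plain PyTorch, not zuko code):

```python
import torch

x = torch.tensor([-1.0, 4.0], requires_grad=True)

# torch.where evaluates BOTH branches: sqrt(-1.0) is nan, and the nan
# contaminates the gradient at x = -1 even though that branch is unselected.
y = torch.where(x > 0, torch.sqrt(x), torch.zeros_like(x))
y.sum().backward()
print(x.grad)  # tensor([nan, 0.2500])
```

The gradient at `x = 4` is the correct `0.5 / sqrt(4) = 0.25`, but at `x = -1` it is NaN instead of the expected 0, because the backward pass multiplies the mask by the (NaN) derivative of the untaken branch.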