
Add Extrapolation to Bernstein Polynomial Flow #36

@MArpogaus

Description


As discussed in #32, we have now implemented linear extrapolation outside the bounds of the Bernstein polynomial.
The feature becomes active when `linear=True`.

Here is a simple Python script to visualize the resulting effect:

# %% Imports
import torch

from matplotlib import pyplot as plt
from zuko.transforms import BernsteinTransform

# %% Globals
M = 10
batch_size = 10
torch.manual_seed(1)
theta = torch.rand(size=(M,)) * 500  # creates a random parameter vector

# %% Sigmoid
bpoly = BernsteinTransform(theta=theta, linear=False)

x = torch.linspace(-15, 15, 2000)
y = bpoly(x)

adj = bpoly.log_abs_det_jacobian(x, y).detach()
J = torch.diag(torch.autograd.functional.jacobian(bpoly, x)).abs().log()

# %% Plot
fig, axs = plt.subplots(2, sharex=True)
fig.suptitle("Bernstein polynomial with Sigmoid")
axs[0].plot(x, y, label="Bernstein polynomial")
axs[0].scatter(
    torch.linspace(-10, 10, bpoly.order + 1),
    bpoly.theta.numpy().flatten(),
    label="Bernstein coefficients",
)
axs[0].legend()
axs[1].plot(x, adj, label="ladj")
# axs[1].scatter(
#     torch.linspace(-10, 10, bpoly.order),
#     bpoly.dtheta.numpy().flatten(),
#     label="dtheta",
# )
axs[1].plot(x, J, label="ladj (autograd)")
axs[1].legend()
fig.tight_layout()
fig.savefig("sigmoid.png")

(figure: sigmoid.png, the plot produced by the script above)

# %% Extrapolation
bpoly = BernsteinTransform(theta=theta, linear=True)

x = torch.linspace(-15, 15, 2000)
y = bpoly(x)

adj = bpoly.log_abs_det_jacobian(x, y).detach()
J = torch.diag(torch.autograd.functional.jacobian(bpoly, x)).abs().log()

# %% Plot

fig, axs = plt.subplots(2, sharex=True)
fig.suptitle("Bernstein polynomial with linear extrapolation")
axs[0].plot(x, y, label="Bernstein polynomial")
axs[0].scatter(
    torch.linspace(-10, 10, bpoly.order + 1),
    bpoly.theta.numpy().flatten(),
    label="Bernstein coefficients",
)
axs[0].legend()
axs[1].plot(x, adj, label="ladj")
# axs[1].scatter(
#     torch.linspace(-10, 10, bpoly.order),
#     bpoly.dtheta.numpy().flatten(),
#     label="dtheta",
# )
axs[1].plot(x, J, label="ladj (autograd)")
axs[1].legend()
fig.tight_layout()
fig.savefig("linear.png")

(figure: linear.png, the plot produced by the script above)

This makes the BPF implementation more robust to data lying outside the domain of the Bernstein polynomial, without the need for the non-linear sigmoid function.
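To illustrate the idea independently of zuko, here is a minimal sketch of linear tail extrapolation. The helper `eval_with_linear_tails`, its boundary values `lo`/`hi`, and the finite-difference slope estimate are all assumptions for illustration, not the actual implementation in the branch:

```python
import torch


def eval_with_linear_tails(f, x, lo=-10.0, hi=10.0, eps=1e-4):
    # Hypothetical helper: evaluate f on [lo, hi] and continue it
    # linearly outside, using finite-difference slopes at the bounds.
    lo_t, hi_t = torch.tensor(lo), torch.tensor(hi)
    y = f(x.clamp(lo, hi))
    slope_lo = (f(lo_t + eps) - f(lo_t)) / eps
    slope_hi = (f(hi_t) - f(hi_t - eps)) / eps
    y = torch.where(x < lo, f(lo_t) + slope_lo * (x - lo), y)
    y = torch.where(x > hi, f(hi_t) + slope_hi * (x - hi), y)
    return y


x = torch.linspace(-15.0, 15.0, 7)
y = eval_with_linear_tails(torch.sigmoid, x)
```

Outside `[lo, hi]` the result grows linearly with the boundary slope, so the transform stays bijective with a constant, non-vanishing Jacobian in the tails.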

@oduerr Do you have anything else to add?

Implementation

The implementation can be found in the bpf_extrapolation branch of my fork.

My changes specifically include:

  1. Optional linear extrapolation in the call method
  2. A custom implementation of log_abs_det_jacobian, since the gradient does not seem to pass through the torch.where statement, and to improve numerical stability in the sigmoidal case
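The `torch.where` gradient issue in point 2 is a known PyTorch behavior: the backward pass computes gradients for both branches before masking, so a NaN in the gradient of the unselected branch still poisons the result. A minimal reproduction, independent of the Bernstein code:

```python
import torch

# x is negative, so the sqrt branch is never selected in the forward pass
x = torch.tensor(-1.0, requires_grad=True)
y = torch.where(x > 0, torch.sqrt(x), x)
y.backward()

# yet sqrt's gradient at -1 (0.5 / sqrt(-1) = nan) leaks through,
# because backward evaluates 0 * nan, which is nan
print(x.grad)  # tensor(nan)
```

The usual workaround is to feed both branches only values that are safe for their gradients (e.g. clamp the input before `torch.sqrt`), which is presumably why a custom `log_abs_det_jacobian` was needed here.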
