Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

25% performance regression from v0.1.1 to v0.2.0 when calculating Hessians #989

@yueyericardo

Description

@yueyericardo

Hi developers,

After upgrading functorch from v0.1.1 to v0.2.0, I noticed a ~25% performance regression when calculating Hessians. Please see the benchmark results below and the attached benchmark script.

Please let me know if I did anything wrong, and whether the performance regression can be fixed.
Thanks!

Benchmark result

Benchmark result on NVIDIA A100

# torch 111 and functorch 0.1.1
===== benchmark without backward =====
max pred       error: functorch: 0.00e+00
max hessian    error: functorch: 0.00e+00
reference_hessian: 61.837 ms
functorch_hessian: 29.474 ms

# torch 112 and functorch 0.2.0
===== benchmark without backward =====
max pred       error: functorch: 1.49e-08
max hessian    error: functorch: 0.00e+00
reference_hessian: 62.519 ms
functorch_hessian: 39.666 ms (0.74 X)

Benchmark result on NVIDIA A6000

# torch 111 and functorch 0.1.1
===== benchmark without backward =====
max pred       error: functorch: 1.49e-08
max hessian    error: functorch: 0.00e+00
reference_hessian: 65.984 ms
functorch_hessian: 33.662 ms

# torch 112 and functorch 0.2.0
===== benchmark without backward =====
max pred       error: functorch: 1.86e-08
max hessian    error: functorch: 0.00e+00
reference_hessian: 67.285 ms
functorch_hessian: 49.723 ms (0.68 X)

benchmark script

benchmark.py

import time
import argparse
from functorch import vmap, jacrev, jacfwd
import torch
import torch.nn as nn

# Disable TF32 for CUDA matmuls so both Hessian paths run in full fp32.
torch.backends.cuda.matmul.allow_tf32 = False


# Fixed seed: the sampled inputs and the initial weights below are reproducible.
_ = torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Problem sizes.
D1 = 2  # input dims: x, y
D2 = 3  # output dims: u, v, p
B = 10000  # batch size
x = torch.randn(B, D1).to(device)  # sampled on the CPU generator, then moved
run_backward = False  # set to True by the --backward CLI flag

# MLP: D1 -> 512, five 512 -> 512 hidden layers, -> D2, with ReLU in between.
# The starred generator is expanded in argument order, so the layers (and
# their RNG-dependent initial weights) are created in the same sequence as an
# explicitly spelled-out layer list would create them.
model = nn.Sequential(
    nn.Linear(D1, 512),
    nn.ReLU(),
    *(
        layer
        for _ in range(5)
        for layer in (nn.Linear(512, 512), nn.ReLU())
    ),
    nn.Linear(512, D2),
).to(device)


def predict(x):
    """Run the global ``model`` on ``x`` and return the output twice.

    The duplicated output exists for functorch's ``has_aux=True``: the
    transform differentiates the first element and passes the second through
    unchanged as an auxiliary value. That is how ``functorch_hessian`` gets
    the raw prediction back alongside the Hessian.

    When CUDA is available, the forward pass is wrapped in an NVTX "forward"
    range for profiling. The range is popped in ``finally``, so an exception
    raised by the model cannot leave the NVTX range stack unbalanced. NVTX is
    skipped when CUDA is unavailable, which matches the CPU fallback used for
    ``device``.
    """
    nvtx_on = torch.cuda.is_available()
    if nvtx_on:
        torch.cuda.nvtx.range_push("forward")
    try:
        out = model(x)
    finally:
        if nvtx_on:
            torch.cuda.nvtx.range_pop()
    return out, out


def reference_hessian():
    """Compute per-sample Hessians of ``predict`` with plain ``torch.autograd``.

    Each output column ``pred[:, i]`` is differentiated w.r.t. the input using
    a ones vector as ``grad_outputs``. This yields per-row gradients because
    the Linear/ReLU model processes batch rows independently. Each Jacobian
    entry ``[:, j]`` is then differentiated again. ``create_graph=True`` keeps
    the graph so the second derivative, and the optional backward pass, can
    run through it.

    Returns:
        ``(hessian, pred)``. ``hessian`` has shape ``[B, D2 * D1, D1]``: row
        ``i * D1 + j`` holds d/dx of d pred_i / dx_j. ``pred`` is the model
        output.
    """
    x_ = x.clone().requires_grad_()
    ones = torch.ones(B, device=x.device)
    nvtx_on = torch.cuda.is_available()  # NVTX only exists on the CUDA path

    def _grad(out, label):
        # One autograd.grad call wrapped in an NVTX range that is always popped.
        if nvtx_on:
            torch.cuda.nvtx.range_push(label)
        try:
            return torch.autograd.grad(out, x_, ones, create_graph=True)[0]
        finally:
            if nvtx_on:
                torch.cuda.nvtx.range_pop()

    pred, _ = predict(x_)
    # All first derivatives are computed before any second derivative, as
    # before; each entry has shape [B, D1].
    jacobian_rows = [_grad(pred[:, i], "autograd jacobian") for i in range(D2)]
    # i-major / j-minor order, so entry k corresponds to (i, j) = divmod(k, D1).
    hessian_rows = [
        _grad(jacobian_rows[i][:, j], "autograd hessian")
        for i in range(D2)
        for j in range(D1)
    ]

    hessian = torch.stack(hessian_rows)  # [D2 * D1, B, D1]
    if run_backward:
        hessian.sum().backward()
    return hessian.transpose(0, 1), pred


def functorch_hessian():
    """Compute per-sample Hessians of ``predict`` using functorch transforms.

    ``jacrev`` builds the per-sample Jacobian of ``predict``. ``jacfwd``
    applied on top of it gives the second derivative. ``vmap`` maps this over
    the batch dimension. ``has_aux=True`` carries the raw prediction (the
    second output of ``predict``) through both transforms untouched.

    If the global ``run_backward`` flag is set, the result is also summed and
    back-propagated.

    Returns:
        ``(hessian, pred)``, with ``hessian`` shaped ``[B, D2, D1, D1]``.
    """
    inputs = x.clone().requires_grad_()

    first_derivative = jacrev(predict, argnums=0, has_aux=True)
    second_derivative = jacfwd(first_derivative, argnums=0, has_aux=True)
    batched = vmap(second_derivative, in_dims=0)

    hessian, pred = batched(inputs)  # [B, D2, D1, D1]

    if not run_backward:
        return hessian, pred
    hessian.sum().backward()
    return hessian, pred


def validate_result():
    """Report how far the functorch results are from the autograd reference.

    The maximum *absolute* element-wise difference is printed, both for the
    prediction and for the Hessian. The previous code took the signed ``max``
    of ``ref - ft``, which hides any error where the reference is smaller than
    the functorch value.

    Returns:
        ``(pred_err, hessian_err)`` as Python floats, for programmatic checks.
    """
    ref_hes, ref_pred = reference_hessian()
    ft_hes, ft_pred = functorch_hessian()
    # The reference is [B, D2 * D1, D1]; split its middle dim to match
    # functorch's [B, D2, D1, D1]. Row i * D1 + j maps to index (i, j).
    ref_hes = ref_hes.view_as(ft_hes)
    pred_err = (ref_pred - ft_pred).abs().max().item()
    hessian_err = (ref_hes - ft_hes).abs().max().item()
    print(f"max pred       error: functorch: {pred_err:.2e}")
    print(f"max hessian    error: functorch: {hessian_err:.2e}")
    return pred_err, hessian_err


def benchmark(func, N=20):
    """Time ``func`` over ``N`` calls and print the mean per-call time in ms.

    Args:
        func: Zero-argument callable to time. Its return value is discarded.
        N: Number of timed calls. Defaults to 20, the previous hard-coded
            value.

    Returns:
        The mean wall-clock time per call, in milliseconds.

    Raises:
        ValueError: If ``N`` is less than 1.

    CUDA synchronization and NVTX ranges are used only when CUDA is
    available. On the CPU path they are skipped instead of raising.
    """
    if N < 1:
        raise ValueError(f"N must be >= 1, got {N}")

    use_cuda = torch.cuda.is_available()

    def _sync():
        # CUDA kernels launch asynchronously. Synchronize so that the timer
        # brackets the actual device work, not just the kernel launches.
        if use_cuda:
            torch.cuda.synchronize()

    _sync()
    start = time.perf_counter()  # monotonic, high-resolution clock

    for _ in range(N):
        if use_cuda:
            torch.cuda.nvtx.range_push(func.__name__)
        try:
            func()
        finally:
            if use_cuda:
                torch.cuda.nvtx.range_pop()

    _sync()
    time_ms = (time.perf_counter() - start) / N * 1000
    print(f"{func.__name__}: {time_ms:.3f} ms")
    return time_ms


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-b", "--backward", default=False, action="store_true")
    args = parser.parse_args()

    # Rebind the module-level flag that both Hessian functions read when called.
    # store_true yields a plain bool, and the flag already defaults to False.
    run_backward = args.backward
    print(f"===== benchmark {'with' if run_backward else 'without'} backward =====")

    # Check numerical agreement before timing anything.
    validate_result()

    # Warm-up iterations so the timed runs below exclude first-call overhead.
    for _ in range(10):
        reference_hessian()
        functorch_hessian()

    # Time each Hessian implementation, reference first.
    for fn in (reference_hessian, functorch_hessian):
        benchmark(fn)

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions