This repository was archived by the owner on Aug 21, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 108
This repository was archived by the owner on Aug 21, 2025. It is now read-only.
25% Performance regression from v0.1.1 to 0.2.0 when calculating hessian #989
Copy link
Copy link
Open
Description
Hi developers,
After I upgraded functorch from v0.1.1 to 0.2.0, I noticed a 25% performance regression when calculating the Hessian; please check the following benchmark results and the attached benchmark script.
Please let me know if I did anything wrong, and also whether the perf regression could be fixed.
Thanks!
Benchmark result
Benchmark result on NVIDIA A100
# torch 111 and functorch 0.1.1
===== benchmark without backward =====
max pred error: functorch: 0.00e+00
max hessian error: functorch: 0.00e+00
reference_hessian: 61.837 ms
functorch_hessian: 29.474 ms
# torch 112 and functorch 0.2.0
===== benchmark without backward =====
max pred error: functorch: 1.49e-08
max hessian error: functorch: 0.00e+00
reference_hessian: 62.519 ms
functorch_hessian: 39.666 ms (0.75 X)
Benchmark result on NVIDIA A6000
# torch 111 and functorch 0.1.1
===== benchmark without backward =====
max pred error: functorch: 1.49e-08
max hessian error: functorch: 0.00e+00
reference_hessian: 65.984 ms
functorch_hessian: 33.662 ms
# torch 112 and functorch 0.2.0
===== benchmark without backward =====
max pred error: functorch: 1.86e-08
max hessian error: functorch: 0.00e+00
reference_hessian: 67.285 ms
functorch_hessian: 49.723 ms (0.68 X)
benchmark script
benchmark.py
import time
import argparse
# functorch transforms: vmap for batching, jacrev/jacfwd for reverse-/forward-mode Jacobians.
from functorch import vmap, jacrev, jacfwd
import torch
import torch.nn as nn
# Disable TF32 matmuls so numerical results are comparable across torch versions.
torch.backends.cuda.matmul.allow_tf32 = False
# Fixed seed for reproducible weights and inputs.
_ = torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
D1 = 2 # input dimension (x, y)
D2 = 3 # output dimension (u, v, p)
B = 10000 # batch size
x = torch.randn(B, D1).to(device)
run_backward = False  # toggled by the --backward CLI flag in __main__
# 6-hidden-layer ReLU MLP mapping D1 -> D2, the function whose Hessian is benchmarked.
model = nn.Sequential(
    nn.Linear(D1, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, D2),
).to(device)
def predict(x):
    """Forward pass of the MLP, wrapped in an NVTX range for profiling.

    Returns the output twice because jacrev(..., has_aux=True) expects a
    (differentiated_output, aux) pair.
    """
    torch.cuda.nvtx.range_push("forward")
    result = model(x)
    torch.cuda.nvtx.range_pop()
    return result, result
def reference_hessian():
    """Compute the batched Hessian of `predict` with plain torch.autograd.

    Returns:
        (hessian, pred): hessian is [B, D2 * D1, D1] after the transpose;
        pred is the network output for the cloned input.
    """
    x_ = x.clone().requires_grad_()
    ones = torch.ones(B, device=x.device)
    pred, _ = predict(x_)
    jacobian_rows = [None] * D2
    hessian_rows = [None] * (D2 * D1)
    # First-order pass: one grad call per output component, keeping the
    # graph so the second-order pass can differentiate through it.
    for i in range(D2):
        torch.cuda.nvtx.range_push("autograd jacobian")
        jacobian_rows[i] = torch.autograd.grad(pred[:, i], x_, ones, create_graph=True)[
            0
        ]
        torch.cuda.nvtx.range_pop()
    # Second-order pass: differentiate each Jacobian entry w.r.t. the input.
    for i in range(D2):
        for j in range(D1):
            # fixed profiling-label typo: was "autograd hesian"
            torch.cuda.nvtx.range_push("autograd hessian")
            hessian_rows[i * D1 + j] = torch.autograd.grad(
                jacobian_rows[i][:, j], x_, ones, create_graph=True
            )[0]
            torch.cuda.nvtx.range_pop()
    # NOTE(review): `jacobian` is never used; kept so the measured work
    # matches the original benchmark script exactly.
    jacobian = torch.stack(jacobian_rows)  # [D2, B, D1]
    hessian = torch.stack(hessian_rows)  # [D2 * D1, B, D1]
    if run_backward:
        hessian.sum().backward()
    return hessian.transpose(0, 1), pred
def functorch_hessian():
    """Compute the per-sample Hessian of `predict` via functorch transforms.

    Returns:
        (hessian, pred): hessian has shape [B, D2, D1, D1].
    """
    inputs = x.clone().requires_grad_()
    # Forward-over-reverse composition: jacfwd(jacrev(f)) yields the Hessian.
    hessian_fn = jacfwd(
        jacrev(predict, argnums=0, has_aux=True), argnums=0, has_aux=True
    )
    hessian, pred = vmap(hessian_fn, in_dims=0)(inputs)
    if run_backward:
        hessian.sum().backward()
    return hessian, pred
def validate_result():
    """Check that the functorch Hessian/prediction match the autograd reference."""
    ref_hes, ref_pred = reference_hessian()
    ft_hes, ft_pred = functorch_hessian()
    ref_hes = ref_hes.view_as(ft_hes)
    # .abs() before .max(): without it, negative deviations are masked and the
    # reported "max error" can understate (or hide) real disagreement.
    print(f"max pred error: functorch: {(ref_pred - ft_pred).abs().max():.2e}")
    print(f"max hessian error: functorch: {(ref_hes - ft_hes).abs().max():.2e}")
def benchmark(func):
    """Time `func` over N iterations and print the mean latency in milliseconds.

    Args:
        func: zero-argument callable to benchmark; its __name__ labels the output.
    """
    N = 20
    torch.cuda.synchronize()  # drain pending kernels so the clock starts clean
    # perf_counter is monotonic and high-resolution; time.time() can step
    # backwards/forwards with wall-clock adjustments and skew measurements.
    start = time.perf_counter()
    for _ in range(N):
        torch.cuda.nvtx.range_push(func.__name__)
        func()
        torch.cuda.nvtx.range_pop()
    torch.cuda.synchronize()  # wait for all launched kernels before stopping the clock
    time_ms = ((time.perf_counter() - start) / N) * 1000
    print(f"{func.__name__}: {time_ms:.3f} ms")
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-b", "--backward", default=False, action="store_true")
    args = parser.parse_args()
    if args.backward:
        run_backward = True
        print("===== benchmark with backward =====")
    else:
        print("===== benchmark without backward =====")
    validate_result()
    # warm up (kernel caches, cuDNN autotune) before timing
    for _ in range(10):
        reference_hessian()
        functorch_hessian()
    # benchmark hessian computation
    benchmark(reference_hessian)
    benchmark(functorch_hessian)
Metadata
Metadata
Assignees
Labels
No labels