
Add a vectorize flag to torch.autograd.functional.{jacobian, hessian} #50584


🚀 Feature

Add a vectorize flag to torch.autograd.functional.jacobian and torch.autograd.functional.hessian (default: False). Under the hood, when vectorize=True, these functions use vmap as the backend to compute the jacobian and hessian, providing speedups to users.

e.g.

import torch

x = torch.randn(5, requires_grad=True)
f = lambda x: x ** 2

# The loop-based default and the new vectorized path should agree.
expected = torch.autograd.functional.jacobian(f, x)
jac = torch.autograd.functional.jacobian(f, x, vectorize=True)
assert torch.allclose(jac, expected)
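
The same flag is proposed for hessian, which requires a scalar-valued function. A minimal sketch, assuming the flag lands there under the same name:

import torch

x = torch.randn(5, requires_grad=True)
g = lambda x: (x ** 3).sum()  # hessian requires a scalar-valued function

# The hessian of sum(x ** 3) is diag(6 * x).
hess = torch.autograd.functional.hessian(g, x, vectorize=True)
assert torch.allclose(hess, torch.diag(6 * x))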

Motivation

Jacobian computation (and by extension, hessian computation) in PyTorch today involves invoking torch.autograd.grad once per row of the jacobian. At a high level, torch.autograd.functional.jacobian proceeds as follows:

  • Our autograd engine computes vector-jacobian products without fully materializing a jacobian.
  • To compute the first row of the jacobian, we use a vector-jacobian product with (1, 0, 0, ...) as the vector.
  • To compute the ith row of the jacobian, we use the ith unit vector as the vector.
  • Finally, when we have all of the rows of the jacobian, we stack them together.

Assuming an N by N jacobian, we need to invoke the autograd engine N times. The per-call overhead (due to tensor creation and operator dispatch) is significant in a number of use cases, such as bayesian logistic regression. A minimal sketch of the loop is shown below.
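
For concreteness, here is a minimal sketch of that procedure (not the exact implementation inside torch.autograd.functional.jacobian; loop_jacobian is a hypothetical helper):

import torch

def loop_jacobian(f, x):
    # One autograd.grad call per output element, as described above.
    y = f(x)
    rows = []
    for i in range(y.numel()):
        v = torch.zeros_like(y)
        v.view(-1)[i] = 1.0  # the ith unit vector
        # vector-jacobian product: v^T @ J
        (row,) = torch.autograd.grad(y, x, v, retain_graph=True)
        rows.append(row)
    return torch.stack(rows).view(*y.shape, *x.shape)

x = torch.randn(5, requires_grad=True)
jac = loop_jacobian(lambda t: t ** 2, x)  # N separate calls into the autograd engine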

Alternatives

Instead of updating jacobian and hessian, we could expose vmap directly and tell users to compose (pseudocode) vmap(vjp) to compute jacobians efficiently. However, this would create a "trap" in our API where users of autograd.functional.jacobian could not benefit from these performance improvements.
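
For reference, the vmap(vjp) composition looks roughly like this with today's torch.func API (which postdates this issue; the pattern is the same as the proposed backend, and vectorized_jacobian is a hypothetical helper):

import torch
from torch.func import vjp, vmap

def vectorized_jacobian(f, x):
    y, vjp_fn = vjp(f, x)
    # One unit vector per output element, batched along a new leading dim.
    basis = torch.eye(y.numel()).view(y.numel(), *y.shape)
    # A single vmap call replaces the Python loop of N vjp calls.
    (jac,) = vmap(vjp_fn)(basis)
    return jac.view(*y.shape, *x.shape)

x = torch.randn(5)
jac = vectorized_jacobian(lambda t: t ** 2, x)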

Additional context

Not all batching rules needed for jacobian and hessian computation are implemented: #49562. A good number of these may require writing new CUDA kernels from scratch. We'd like to offer vectorize as an opt-in flag (default False) now, so that we can begin speeding up user code without making users wait until we have written a substantial number of batching rules.

cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer
