🚀 Feature
Support efficient batched gradient computation. The use cases for this are efficient Jacobian and Hessian computation. We support batched gradient computation by performing a vmap over the backward pass created by the autograd engine.
Feature rollup
Here is a list of (remaining) operators used in backward formulas. For each operator, we would like to write a batching rule that defines how to perform the operation on inputs with an additional batch dimension (see Context for more details).
Comparison ops (#50364)
(in-place) Arithmetic
Factory functions
Miscellaneous torch operations
Miscellaneous torch.nn
Potentially trickier ones:
Backwards operators (these don't appear in our Python API)
Backwards operators
Here is a list of operators whose difficulty I'm not sure about. You should not attempt these because they might turn into a wild goose chase:
aten::_index_put_impl_
For completeness, here is a list of operators you should not attempt. These require writing brand-new kernels for best performance.
The task
The task is to pick an operator from the list above (or a group of them) and implement a batching rule for it. I will add GitHub handles next to operators that are currently being worked on and mark completed ones as done.
How to reproduce
Pick an operator from the above list (e.g., Tensor.fill_). Attempt to run it inside a vmap:
>>> import torch
>>> from torch import vmap
>>> x = torch.randn(64, 3)
>>> result = vmap(lambda x: x.fill_(1))(x)
/scratch/rzou/pt/whiteboard-env/bin/ipython:4: UserWarning: Batching rule not implemented for aten::fill_.Scalar falling back to slow (for loop) implementation (Triggered internally at ../aten/src/ATen/BatchedFallback.cpp:68.)
It should emit a warning that the batching rule has not been implemented.
Context
vmap is a higher-order vectorization operator. It takes in a function f and returns a new function that maps f over some dimension of the inputs. The new function has semantics similar to performing a for-loop over the mapped dimension:
import torch

N = 64  # size of the mapped dimension
inputs = torch.randn(N, 3)
f = lambda x: x.clamp(min=0)
expected = torch.stack([f(inputs[i]) for i in range(N)])
result = torch.vmap(f)(inputs)
assert torch.allclose(result, expected)
Instead of performing a for-loop to compute expected, the user could have vectorized their computation by directly calling clamp: expected2 = inputs.clamp(min=0). However, for more complicated functions, like F.conv2d, vectorizing is not as obvious. In some cases, such as efficient Jacobian/Hessian computation, the user cannot vectorize manually because they have no control over what happens inside the autograd engine. One of the targeted use cases for vmap is to speed up batched gradient computation in PyTorch.
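To make the batched-gradient use case concrete, here is a minimal sketch (f and the vjp helper below are purely illustrative, not part of any API): the Jacobian of f at x is the stack of vector-Jacobian products against the rows of an identity matrix. Each vjp call is one backward pass; vmapping vjp over the rows is exactly a vmap over the backward pass, and it only becomes fast once the batching rules tracked in this issue are implemented.

import torch

def f(x):
    return x.clamp(min=0)

x = torch.randn(3, requires_grad=True)
y = f(x)

def vjp(v):
    # one backward pass per basis vector v
    return torch.autograd.grad(y, x, v, retain_graph=True)[0]

basis = torch.eye(y.numel())
jacobian = torch.stack([vjp(v) for v in basis])  # for-loop version
# jacobian = torch.vmap(vjp)(basis)              # vmapped version of the same backward pass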
vmap works by defining a "batching rule" for each operator. The batching rule defines how to perform the operation on inputs with extra dimensions (which we refer to as the "batch dimensions").
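As a rough sketch of the idea (actual batching rules are registered in PyTorch's C++ dispatcher; the Python function mv_batching_rule below is only an illustration), a batching rule for torch.mv can express the batched matrix-vector product as a single matmul instead of a Python loop over the batch dimension:

import torch

def mv_batching_rule(mat, batched_vec):
    # batched_vec carries an extra leading batch dimension: (B, N); mat is (M, N).
    # Express the batched op in terms of an op that already handles the extra
    # dimension: one matmul instead of B separate mv calls.
    return batched_vec @ mat.t()  # (B, M)

mat = torch.randn(4, 3)
vecs = torch.randn(8, 3)  # batch of 8 vectors
expected = torch.stack([torch.mv(mat, v) for v in vecs])
assert torch.allclose(mv_batching_rule(mat, vecs), expected)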
Resources