
Batched gradient computation w/ vmap, feature rollup #49562

@zou3519


🚀 Feature

Support efficient batched gradient computation. The primary use cases are efficient Jacobian and Hessian computation. We support batched gradient computation by performing a vmap over the backward pass created by the autograd engine.
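As a concrete sketch of the use case (assuming the torch.vmap prototype API described in this issue; some operators may still hit the slow for-loop fallback), the full Jacobian of a function can be computed by vmap-ing a single vector-Jacobian product over a batch of one-hot vectors, so the backward pass runs once for the whole batch instead of once per row:

```python
import torch

# Sketch: compute the Jacobian of y = f(x) by vmap-ing a single
# vector-Jacobian product over a batch of basis vectors.
x = torch.randn(3, requires_grad=True)
y = x * x  # elementwise square; its Jacobian is diag(2 * x)

def vjp(v):
    # One backward pass with cotangent vector v; retain the graph so the
    # slow fallback (which replays the backward per row) also works.
    return torch.autograd.grad(y, x, v, retain_graph=True)[0]

basis = torch.eye(y.numel())       # one one-hot vector per output element
jacobian = torch.vmap(vjp)(basis)  # shape (3, 3)
assert torch.allclose(jacobian, torch.diag(2 * x))
```

Without vmap, the same Jacobian would require three separate calls to torch.autograd.grad, one per basis vector.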

Feature rollup

Here is a list of (remaining) operators used in backward formulas. For each operator, we would like to write a batching rule that defines how to perform the operation on inputs with an additional batch dimension (see Context for more details).

Comparison ops (#50364)

  • torch.eq
  • torch.gt
  • torch.ge
  • torch.le
  • torch.lt
  • torch.ne

(in-place) Arithmetic

  • Tensor.add_
  • Tensor.mul_
  • Tensor.sub_
  • Tensor.div_

Factory functions

  • torch.empty_like
  • torch.zeros_like
  • torch.ones_like
  • torch.full_like

Miscellaneous torch operations

  • Tensor.all
  • Tensor.any
  • Tensor.copy_
  • torch.sum
  • torch.mean
  • torch.flip
  • torch.ger
  • torch.roll
  • torch.rot90
  • torch.tril
  • Tensor.tril_
  • torch.triu
  • Tensor.triu_
  • Tensor.type_as
  • torch.where
  • torch.masked_select
  • Tensor.masked_fill_
  • Tensor.masked_fill
  • torch.cumsum

Miscellaneous torch.nn

  • F.adaptive_avg_pool1d
  • F.adaptive_avg_pool2d
  • F.adaptive_avg_pool3d
  • F.avg_pool1d
  • F.avg_pool2d
  • F.avg_pool3d
  • F.pad

Potentially trickier ones:

  • torch.gather
  • torch.diag_embed
  • torch.take
  • torch.matmul

Backwards operators (these don't appear in our Python API)

  • aten::tanh_backward
  • aten::logit_backward
  • aten::masked_select_backward
  • aten::avg_pool2d_backward
  • aten::elu_backward
  • aten::hardtanh_backward
  • aten::threshold_backward (#49881)
  • aten::rrelu_with_noise_backward
  • aten::mse_loss_backward
  • aten::hardshrink_backward
  • aten::kl_div_backward
  • aten::l1_loss_backward
  • aten::smooth_l1_loss_backward
  • aten::leaky_relu_backward
  • aten::softplus_backward
  • aten::log_sigmoid_backward
  • aten::softshrink_backward

Operators of unknown difficulty

Here is a list of operators whose difficulty I'm not sure about. You should not attempt these, because they might become a wild goose chase:

  • aten::index_fill_.int_Scalar
  • aten::_sparse_coo_tensor_unsafe
  • aten::index_select
  • aten::index_select_backward
  • aten::cross
  • aten::cumprod_backward
  • aten::_ctc_loss_backward
  • aten::to_dense_backward
  • aten::matrix_exp_backward
  • aten::put_
  • aten::solve
  • aten::unfold_backward
  • aten::gather_backward
  • aten::_cdist_backward
  • aten::take_backward
  • aten::_index_put_impl_
  • aten::scatter_.value
  • aten::_trilinear
  • aten::index.Tensor
  • aten::col2im
  • aten::im2col
  • aten::constant_pad_nd
  • aten::index_add_
  • aten::logcumsumexp
  • aten::max_unpool2d
  • aten::nll_loss
  • aten::nll_loss2d
  • aten::reflection_pad1d
  • aten::reflection_pad2d
  • aten::reflection_pad1d_backward
  • aten::reflection_pad2d_backward
  • aten::replication_pad1d
  • aten::replication_pad2d
  • aten::replication_pad3d
  • aten::replication_pad1d_backward
  • aten::replication_pad2d_backward
  • aten::replication_pad3d_backward
  • aten::triangular_solve
  • aten::soft_margin_loss_backward
  • aten::upsample_bicubic2d.vec
  • aten::upsample_bilinear2d.vec
  • aten::upsample_linear1d.vec
  • aten::upsample_nearest1d.vec
  • aten::upsample_nearest2d.vec
  • aten::upsample_nearest3d.vec
  • aten::upsample_trilinear3d.vec

For completeness, here is a list of operators you should not attempt. These require writing brand-new kernels for best performance.

  • aten::_softmax_backward_data
  • aten::_log_softmax_backward_data
  • aten::_convolution
  • Backward pass of: F.conv1d, F.conv2d, F.conv3d
  • Backward pass of: F.conv_transpose1d, F.conv_transpose2d, F.conv_transpose3d
  • Backward pass of: nn.RNN, nn.LSTM, nn.GRU

The task

The task is to pick an operator (or a group of operators) from the list and implement batching rules for it. I will add GitHub handles next to the operators that are currently being worked on and mark completed ones as done.

How to reproduce

Pick an operator from the above list (e.g., Tensor.fill_). Attempt to run it inside a vmap:

>>> import torch
>>> from torch import vmap
>>> x = torch.randn(64, 3)
>>> result = vmap(lambda x: x.fill_(1))(x)
/scratch/rzou/pt/whiteboard-env/bin/ipython:4: UserWarning: Batching rule not implemented for aten::fill_.Scalar falling back to slow (for loop) implementation (Triggered internally
at  ../aten/src/ATen/BatchedFallback.cpp:68.)

It should emit a warning that the batching rule has not been implemented and that vmap is falling back to the slow (for-loop) implementation.

Context

vmap is a higher-order vectorization operator. It takes in a function f and returns a new function that maps f over some dimension of the inputs. The new function has similar semantics to performing a for-loop over the mapped dimension:

import torch

N = 64
inputs = torch.randn(N, 3)
f = lambda x: x.clamp(min=0)
expected = torch.stack([f(inputs[i]) for i in range(N)])
result = torch.vmap(f)(inputs)
assert torch.allclose(result, expected)

Instead of performing a for-loop to compute expected, the user could have vectorized their computation by calling clamp directly: expected2 = inputs.clamp(min=0). However, for more complicated functions, like F.conv2d, vectorizing is not as obvious. In some cases, like efficient Jacobian/Hessian computation, vectorization isn't possible for the user at all, because they cannot control what happens inside of the autograd engine. One of the targeted use cases for vmap is to speed up batched gradient computation in PyTorch.

vmap works by registering a "batching rule" for each operator. The batching rule defines how to perform the operation on inputs with extra dimensions (which we refer to as the "batch dimensions").
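To make that concrete, here is a minimal Python-level sketch of what a batching rule does (the real batching rules are written in C++ inside PyTorch; the function name here is illustrative, not the actual vmap internals). For torch.sum over a single logical dimension, where dim 0 of the physical tensor is the batch dimension, the rule simply shifts non-negative dims past the batch dim:

```python
import torch

# Illustrative sketch of a batching rule for torch.sum over one dim.
# The "physical" tensor carries the batch as its leading dimension, so a
# logical (per-example) non-negative dim maps to dim + 1; negative dims
# already count from the end and pass through unchanged.
def sum_batching_rule(physical, dim):
    return torch.sum(physical, dim + 1 if dim >= 0 else dim)

batched = torch.randn(5, 3, 4)       # a batch of five 3x4 matrices
out = sum_batching_rule(batched, 0)  # sums each matrix over its rows
expected = torch.stack([batched[i].sum(0) for i in range(5)])
assert torch.allclose(out, expected)
```

Once such a rule is registered, vmap can dispatch the batched call in one shot instead of falling back to the slow per-example for-loop.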


Labels

    actionable, module: vmap, triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions