
Commit fdf02ef

albanD authored and facebook-github-bot committed
Add base forward grad logic (#49097)
Summary: Pull Request resolved: #49097 RFC: pytorch/rfcs#11

This PR adds the basic logic to handle forward grads as dual Tensors. It contains the following:
- A mechanism to save dual state on a Tensor and clear it when the dual level ends
- C++ and Python user-facing APIs
- An updated view system that is able to track both forward and backward views

The current PR has the following limitations:
- Extensive tests are in the next PR in the stack, as formulas are needed to write full tests.
- Only the manual formulas have been audited; no other formula is actually implemented here (they are in the next PR in the stack).
- Only level 0 is allowed for now. This was discussed and it was agreed that nested levels are not needed for the first version of this PR.
- We could save one ViewInfo creation when both the forward and backward views have the same base, by adding a boolean flag to the DifferentiableViewMeta and extra logic in the `as_view` method. This is left out to keep this PR concise.
- We could skip tracking forward views if the base has a forward grad, by adding extra logic in the `as_view` method. This is left out to keep this PR concise.

Reading guide:
- Updated view handling in [gen_variable_type.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-f6553cec68caeaea36f6c8b14ff76a6d39dfd774e0ea9ef2f76e8d81fd9af5df), [VariableTypeUtils.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-ec71cfa45954dece1236c661d170e6341879c5be637f4abf52e826d61b40695a), [variable.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-60e3bfe444e89efc7149f25b38e472710525984789934ab83f1bd5671b8ff285) (skip the code below "[Forward Grad View]" for now), [variable.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-1604bcd0e4350ed99ec45e437cee7ac9ebe337392c9ea16a236247aeeb35b02bR266-R542) and [custom_function.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-dd85f452082b5bb6612bbc12adb496f8827defa228509f7b493de1d517522d5d). This introduces the new ViewInfo that holds the view information shared by forward and backward, updates the differentiable view meta to use it, and updates the `as_view` function to handle both forward and backward views.
- New forward grad class that handles storing gradients and tracking at each level: [forward_grad.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-c6c5b9ab2d7e5dde4102495faa1b6bbbfc23aa3e47deb7359c0bfe1eb004c0cb), [forward_grad.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-de2ab54ade7312701850d71a119a4f4ee4b9fc5a9c42a467cdd4e73c033531dd) and [build_variables.bzl](https://github.com/pytorch/pytorch/pull/49097/files#diff-dfdfa2efb17beddfd9094524f95351fd197db6c8857e96b436fb599870359325). EDIT: these files also contain the new flag to globally disable forward AD, which lets us limit performance impact while this is in development.
- Lowest-level API and binding between Tensor and AutogradMeta in [TensorBody.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-7554853205392fa743357bf845ecc350a974ec049383248c12daaf2f4de04911), [TensorImpl.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-052bd9150ef8e09289ddf644b5a6830ede49207201cd41728f6d7cc6d9cead94), [TensorImpl.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-a15aae4cf23da44970db7cece62ff981265575c798c62f7b52d87c8809dfe2e1) and the rest of [variable.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-60e3bfe444e89efc7149f25b38e472710525984789934ab83f1bd5671b8ff285R557-R677)
- API to access the forward primal, which needs to be a differentiable function (and so appears in native_functions.yaml): [native_functions.yaml](https://github.com/pytorch/pytorch/pull/49097/files#diff-2f3dbd85efb9b5172f2264eedd3be47dd765e6ab7cc8bf3ade5e62c28ae35991), [NamedRegistrations.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-69bd3bea510c9b64e1633fa18c3ea63d4b8348dbad3a78ad9de844ab3e43dc1d), [VariableMethodStubs.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-23f5fcb737a2b289811fe0f4b65aef775e7c824b2e629ecd343df51405cd434f), [derivatives.yaml](https://github.com/pytorch/pytorch/pull/49097/files#diff-e4c2f99a2404e98c3586e07425da73008f36b1bada790648a7297af141d37f8c), [gen_python_functions.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-e4c2f99a2404e98c3586e07425da73008f36b1bada790648a7297af141d37f8c), [gen_trace_type.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-54e0b976027bf8debefb959ff360b89ae93466970c843365b1b3a03806d868ce), [TraceTypeManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-f34636741ad4a23d018e0c289bc750c3bad887b45660e1d6eaf440d234a78fbf) and [part of VariableTypeManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-6e19a1bce8cbdba8714b6e2c794a76bc0864b64a49cfa757cb0b5afdc937d1a4R198-R243)
- C++ API: [autograd.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-349028fbe8291a965a7a263c323b208fe071c35c66179ee997ef84fa81aa4b1e), [autograd.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-a3fe908d67dfec16a1fcde300de68b0701bf68b88db7451f29f2bee255cf30c9)
- Python binding: [init.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-c58a67c85191c22c9b3bb439117d8053edfd9dea839fa010cf967d404c3c630d)
- Python API (see the usage sketch below): [forward_ad.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-a4efad4ba18fffdfb264c21e5475997a24a743089a899f8ec1a5ff962c6738d9), [autograd/__init__.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-743abcafd32ad0e69f39ac5a91df4197b7e1921c135cacee7ef6dc829a8a7af8)
- C++ and Python printing: [Formatting.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-881dba501e71662e2e4818b4b016f739b344c8aed2f5edc6b871eda47a2aced0), [_tensor_str.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-a7911f8d5e73adbff914d99fd7818ace2a7030b6a3748abe06ec6fc6e3df9cc3)
- Utilities for formulas and updated manual functions that respect the new view system as well as forward grad: [FunctionsManual.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-6378bb6dc81a64dab676d61731341fa5d1088418f32a1473a33a0ccfc2357dc1), [FunctionsManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-4adbd88239afcd60e8198aab65d4f5e43b62314e34b80551e997a1ea503adea5), [rest of VariableTypeManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-6e19a1bce8cbdba8714b6e2c794a76bc0864b64a49cfa757cb0b5afdc937d1a4R264-R433)
- Ensure SavedVariable saves the forward grad properly: [saved_variable.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-c1b8039d776241abe177d5aa99b79dd9489a9b3e529da8ab24c2e386c1238ae2), [saved_variable.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-cc9fba479b5beae06b2eea2e390d17796e0341c5b037a20b5bcaccbb0c341030)

Test Plan: Imported from OSS

Reviewed By: mrshenli

Differential Revision: D25607503

Pulled By: albanD

fbshipit-source-id: f1396290de1d75760f3d380c43cdd56e86fa6099
1 parent befe337 commit fdf02ef

37 files changed

Lines changed: 1442 additions & 153 deletions
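To make the reading guide concrete, here is a minimal sketch of the user-facing Python API introduced by this PR, mirroring what the new TestAutogradForwardMode test further down exercises. It only shows the plumbing (creating and unpacking a dual Tensor inside a dual level); derivative formulas for most ops land in the next PR in the stack.

```python
import torch
import torch.autograd.forward_ad as fwAD

primal = torch.rand(2)
tangent = torch.rand(2)

# Forward-mode state is only tracked while a dual level is active.
with fwAD.dual_level():
    # Attach the tangent to (a view of) the primal to form a dual Tensor.
    dual = fwAD.make_dual(primal, tangent)

    # Recover both parts: the primal comes back as a view of the dual,
    # and the tangent is reused as-is (the new test checks identity).
    p, t = fwAD.unpack_dual(dual)
    assert t is tangent
# When the level exits, the dual state saved at that level is cleared.
```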

aten/src/ATen/core/Formatting.cpp

Lines changed: 5 additions & 0 deletions
@@ -292,6 +292,11 @@ std::ostream& print(std::ostream& stream, const Tensor & tensor_, int64_t linesize) {
       stream << ", axis: " << tensor_.q_per_channel_axis();
     }
   }
+
+  auto& fw_grad = tensor.fw_grad(/* level */ 0);
+  if (fw_grad.defined()) {
+    stream << ", tangent:" << std::endl << fw_grad;
+  }
   stream << " ]";
 }
 return stream;
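The hunk above makes the C++ printer emit the tangent of a dual Tensor; the reading guide notes that `_tensor_str.py` does the same on the Python side. A small sketch of how one might observe this, assuming the Python repr also appends a tangent entry (the exact formatting is defined in `_tensor_str.py`, not here):

```python
import torch
import torch.autograd.forward_ad as fwAD

with fwAD.dual_level():
    dual = fwAD.make_dual(torch.rand(2), torch.rand(2))
    # For a dual Tensor, the printed output is expected to include its
    # tangent in addition to the primal values.
    print(dual)
```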

aten/src/ATen/core/NamedRegistrations.cpp

Lines changed: 1 addition & 0 deletions
@@ -510,4 +510,5 @@ TORCH_LIBRARY_IMPL(aten, Named, m) {
   m.impl("_version", CppFunction::makeFallthrough());
   m.impl("requires_grad_", CppFunction::makeFallthrough());
   m.impl("retain_grad", CppFunction::makeFallthrough());
+  m.impl("_fw_primal", CppFunction::makeFallthrough());
 }
Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+#include <ATen/ATen.h>
+
+namespace at {
+namespace native {
+
+/// This function can be used to create a dual Tensor that holds a tangent to compute forward mode gradients.
+/// Note that the dual Tensor's primal is a view of the given primal and the given tangent is used as-is.
+/// This function is backward differentiable.
+at::Tensor make_dual(const at::Tensor& primal, const at::Tensor& tangent, int64_t level) {
+  TORCH_CHECK(!primal.fw_grad(level).defined(), "Making a dual Tensor based on a Tensor that "
+              "already has a forward gradient at the same level ", level, " is not supported.");
+
+  auto dual_tensor = primal.view(primal.sizes());
+  dual_tensor.set_fw_grad(tangent, level, /* is_inplace_op */ false);
+  return dual_tensor;
+}
+
+/// This function can be used to unpack a given dual Tensor to get its primal and tangent. The returned primal
+/// is a view of the dual and the tangent is returned as is.
+/// This function is backward differentiable.
+std::tuple<at::Tensor, at::Tensor> unpack_dual(const at::Tensor& tensor, int64_t level) {
+  return std::tuple<at::Tensor, at::Tensor>(tensor._fw_primal(level), tensor.fw_grad(level));
+}
+
+} // namespace native
+
+} // namespace at
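The `TORCH_CHECK` in `make_dual` above means a Tensor can carry at most one tangent per level. A hedged sketch of that error path from Python, assuming the check surfaces as a `RuntimeError` through the binding:

```python
import torch
import torch.autograd.forward_ad as fwAD

with fwAD.dual_level():
    dual = fwAD.make_dual(torch.rand(2), torch.rand(2))
    try:
        # `dual` already has a forward gradient at level 0, so this is
        # rejected by the TORCH_CHECK in make_dual.
        fwAD.make_dual(dual, torch.rand(2))
    except RuntimeError as err:
        print(err)
```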

aten/src/ATen/native/VariableMethodStubs.cpp

Lines changed: 4 additions & 0 deletions
@@ -40,5 +40,9 @@ void retain_grad(Tensor& self) {
   AT_ERROR("retain_grad is not implemented for Tensor");
 }
 
+Tensor _fw_primal(const Tensor& self, int64_t level) {
+  AT_ERROR("_fw_primal is not implemented for Tensor");
+}
+
 } // namespace native
 } // namespace at

aten/src/ATen/native/native_functions.yaml

Lines changed: 14 additions & 0 deletions
@@ -105,6 +105,20 @@
   manual_kernel_registration: True
   variants: method
 
+- func: _fw_primal(Tensor(a) self, int level) -> Tensor(a)
+  use_c10_dispatcher: full
+  variants: method
+  dispatch:
+    DefaultBackend: _fw_primal
+
+- func: make_dual(Tensor(a) primal, Tensor tangent, int level) -> Tensor(a)
+  use_c10_dispatcher: full
+  variants: function
+
+- func: unpack_dual(Tensor(a) dual, int level) -> (Tensor(a) primal, Tensor tangent)
+  use_c10_dispatcher: full
+  variants: function
+
 - func: rename_(Tensor(a!) self, Dimname[]? names) -> Tensor(a!)
   use_c10_dispatcher: full
   variants: method

aten/src/ATen/templates/TensorBody.h

Lines changed: 17 additions & 0 deletions
@@ -599,6 +599,23 @@ class TORCH_API Tensor {
     return impl_->grad();
   }
 
+  // The Forward AD API functions below are low level and are not to be used by end
+  // users who should use the API provided in torch/csrc/autograd.h
+
+  /// This function returns the forward gradient for this Tensor at the given level.
+  const Tensor& fw_grad(uint64_t level) const {
+    return impl_->fw_grad(level, *this);
+  }
+
+  /// This function can be used to set the value of the forward grad.
+  /// Note that the given new_grad might not be used directly if it has different
+  /// metadata (size/stride/storage offset) compared to this Tensor. In that case,
+  /// new_grad content will be copied into a new Tensor
+  void set_fw_grad(const Tensor& new_grad, uint64_t level, bool is_inplace_op) {
+    impl_->set_fw_grad(new_grad, *this, level, is_inplace_op);
+  }
+
+
   // STOP. Thinking of adding a method here, which only makes use
   // of other ATen methods? Define it in native_functions.yaml.

c10/core/TensorImpl.cpp

Lines changed: 11 additions & 0 deletions
@@ -44,6 +44,17 @@ const at::Tensor& TensorImpl::grad() const {
   return autograd_meta_->grad();
 }
 
+const at::Tensor& TensorImpl::fw_grad(uint64_t level, const at::Tensor& self) const {
+  // See TensorImpl::grad() above for explanation about the line below
+  if (!autograd_meta_) return impl::GetAutogradMetaFactory()->undefined_tensor();
+  return autograd_meta_->fw_grad(level, self);
+}
+
+void TensorImpl::set_fw_grad(const at::Tensor& new_grad, const at::Tensor& self, uint64_t level, bool is_inplace_op) {
+  if (!autograd_meta_) autograd_meta_ = impl::GetAutogradMetaFactory()->make();
+  autograd_meta_->set_fw_grad(new_grad, self, level, is_inplace_op);
+}
+
 TensorImpl::TensorImpl(
     Storage&& storage,
     DispatchKeySet key_set,
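Note that `set_fw_grad` above lazily creates the `AutogradMeta`, so a Tensor does not need to participate in backward AD to carry a tangent. The new test below relies on this by building a dual from a plain `torch.rand(2)`; a minimal sketch of the same idea:

```python
import torch
import torch.autograd.forward_ad as fwAD

primal = torch.rand(2)  # requires_grad is False, no backward graph involved
with fwAD.dual_level():
    dual = fwAD.make_dual(primal, torch.rand(2))
    # The tangent is stored even though the Tensor has no backward state.
    assert fwAD.unpack_dual(dual)[1] is not None
```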

c10/core/TensorImpl.h

Lines changed: 38 additions & 0 deletions
@@ -136,6 +136,8 @@ struct C10_API AutogradMetaInterface {
   virtual bool requires_grad() const = 0;
   virtual at::Tensor& mutable_grad() = 0;
   virtual const at::Tensor& grad() const = 0;
+  virtual const at::Tensor& fw_grad(uint64_t level, const at::Tensor& self) const = 0;
+  virtual void set_fw_grad(const at::Tensor& new_grad, const at::Tensor& self, uint64_t level, bool is_inplace_op) = 0;
   virtual ~AutogradMetaInterface();
 };

@@ -598,6 +600,42 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
   */
  const at::Tensor& grad() const;

+  /**
+   * Return the accumulated gradient of a tensor. This gradient is computed
+   * using forward mode AD.
+   *
+   * This is an internal API that should never be used by end users.
+   *
+   * The API is as follows:
+   * - "level" allows to specify the level of forward AD nesting for which the
+   *   gradient should be returned. Note that since levels are not fully
+   *   supported yet, this argument should be 0. See documentation for
+   *   torch::autograd::enter_dual_level for more details about forward AD nesting.
+   * - "self" should represent the Tensor whose forward grad is accessed. It is
+   *   required when dealing with view.
+   */
+  const at::Tensor& fw_grad(uint64_t level, const at::Tensor& self) const;
+
+  /**
+   * Sets the forward gradient for this Tensor.
+   * The given Tensor might not be used directly and its content will be copied.
+   *
+   * This is an internal API that should never be used by end users.
+   *
+   * The API is as follows:
+   * - "new_grad" is a Tensor containing the new value of the gradient that should
+   *   be set
+   * - "self" should reprensent the Tensor whose forward grad is accessed. It is
+   *   required when dealing with view.
+   * - "level" allows to specify the level of forward AD nesting for which the
+   *   gradient should be set. Note that since levels are not fully supported
+   *   yet, this argument should be 0. See documentation for torch::autograd::enter_dual_level
+   *   for more details about forward AD nesting.
+   * - "is_inplace_op" is a boolean flag that tells if this gradient was generated
+   *   by an inplace operation or an out of place one. This allows better error checking.
+   */
+  void set_fw_grad(const at::Tensor& new_grad, const at::Tensor& self, uint64_t level, bool is_inplace_op);
+
   /**
    * Return a typed data pointer to the actual data which this tensor refers to.
    * This checks that the requested type (from the template parameter) matches

test/test_autograd.py

Lines changed: 81 additions & 0 deletions
@@ -35,6 +35,7 @@
     IS_WINDOWS, IS_MACOS, CudaMemoryLeakCheck)
 from torch.autograd import Variable, Function, detect_anomaly, kineto_available
 from torch.autograd.function import InplaceFunction
+import torch.autograd.forward_ad as fwAD
 from torch.testing import randn_like
 from torch.testing._internal.common_methods_invocations import (method_tests,
                                                                 create_input, unpack_variables,

@@ -5326,6 +5327,26 @@ def fn(a, dim0_size=5):
 
         self.assertEqual(x.grad, y.grad)
 
+    def test_view_with_multi_output(self):
+        x = torch.randn(2, 2, 2, dtype=torch.double)
+
+        x1 = torch.view_as_complex(x)
+        # Taking an invalid view should always be allowed as long as it is not
+        # modified inplace
+        res = x1.unbind(0)
+
+        with self.assertRaisesRegex(RuntimeError, "output of a function that returns multiple views"):
+            res[0] += torch.rand(2, requires_grad=True)
+
+        x.requires_grad_(True)
+        x1 = torch.view_as_complex(x)
+        # Taking an invalid view should always be allowed as long as it is not
+        # modified inplace
+        res = x1.unbind(0)
+
+        with self.assertRaisesRegex(RuntimeError, "output of a function that returns multiple views"):
+            res[0] += torch.rand(2, requires_grad=True)
+
     def as_identity(self):
         # view_as_real and view_as_complex behavior should be like an identity
         def func(z):

@@ -6324,6 +6345,66 @@ def foo(a):
         self.assertEqual(hvp, torch.mm(hes, v.unsqueeze(1)).squeeze(1))
         self.assertEqual(vhp, torch.mm(v.unsqueeze(0), hes).squeeze(0))
 
+class TestAutogradForwardMode(TestCase):
+    def test_forward_level_cleanup(self):
+        import weakref
+
+        def get_tensor_and_weak_ref():
+            # Helper function to get a Tensor and a weak ref that tells us
+            # if the c++ version of this Tensor is still alive or not.
+            #
+            # Create the following reference chain to do so:
+            # - python Tensor t
+            # - c++ Tensor corresponding by t
+            # - c++ Node corresponding to t.grad_fn
+            # - python dict of metadata from this Node
+            # - an object in this dict that we can take a weakref of
+
+
+            # Create a new Tensor and Node
+            t = torch.rand(2, requires_grad=True).clone()
+            # Create the metadata dict
+            meta_dict = t.grad_fn.metadata
+            # Create the object in the dict
+
+            class Foo(object):
+                pass
+            my_obj = Foo()
+            meta_dict[0] = my_obj
+
+            # After exiting this function, the python Tensor t is the only
+            # thing keeping ref alive
+            ref = weakref.ref(my_obj)
+            return t, ref
+
+        # Sanity check that the helper function works as expected
+        t, t_ref = get_tensor_and_weak_ref()
+        self.assertIsNotNone(t_ref())
+
+        del t
+        self.assertIsNone(t_ref())
+
+        # Main test code
+        foo = torch.rand(2)
+
+        with fwAD.dual_level():
+            tangent, tangent_ref = get_tensor_and_weak_ref()
+            self.assertIsNotNone(tangent_ref())
+
+            dual = fwAD.make_dual(foo, tangent)
+            self.assertIsNotNone(tangent_ref())
+
+            # Make sure that the tangent we provided has been re-used as is
+            self.assertTrue(fwAD.unpack_dual(dual)[1] is tangent)
+
+            # Make sure that dual is keeping the tangent alive
+            del tangent
+            self.assertIsNotNone(tangent_ref())
+
+            # Make sure that the dual level does not keep the c++
+            # version of the tangent alive
+            del dual
+            self.assertIsNone(tangent_ref())
 
 # Generic device type autograd tests.
 class TestAutogradDeviceType(TestCase):

test/test_namedtuple_return_api.py

Lines changed: 5 additions & 2 deletions
@@ -12,7 +12,7 @@
 all_operators_with_namedtuple_return = {
     'max', 'min', 'median', 'nanmedian', 'mode', 'kthvalue', 'svd', 'symeig', 'eig',
     'qr', 'geqrf', 'solve', 'slogdet', 'sort', 'topk', 'lstsq',
-    'triangular_solve', 'cummax', 'cummin', 'linalg_eigh'
+    'triangular_solve', 'cummax', 'cummin', 'linalg_eigh', "unpack_dual"
 }
 
 
@@ -65,6 +65,7 @@ def test_namedtuple_return(self):
             op(operators=['triangular_solve'], input=(a,), names=('solution', 'cloned_coefficient'), hasout=True),
             op(operators=['lstsq'], input=(a,), names=('solution', 'QR'), hasout=True),
             op(operators=['linalg_eigh'], input=("L",), names=('eigenvalues', 'eigenvectors'), hasout=True),
+            op(operators=['unpack_dual'], input=(a, 0), names=('primal', 'tangent'), hasout=False),
         ]
 
         for op in operators:

@@ -75,7 +76,9 @@ def test_namedtuple_return(self):
                 for i, name in enumerate(op.names):
                     self.assertIs(getattr(ret, name), ret[i])
             else:
-                ret = getattr(a, f)(*op.input)
+                # Handle op that are not methods
+                func = getattr(a, f) if hasattr(a, f) else getattr(torch, f)
+                ret = func(*op.input)
                 for i, name in enumerate(op.names):
                     self.assertIs(getattr(ret, name), ret[i])
                 if op.hasout:
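The hunks above register `unpack_dual` as a namedtuple-returning operator and teach the test to look it up on `torch` when it is not a Tensor method. A minimal sketch of the access pattern being verified, assuming `torch.unpack_dual(tensor, level)` is reachable exactly as the test calls it:

```python
import torch

a = torch.randn(2, 2)
out = torch.unpack_dual(a, 0)
# The returned namedtuple exposes its fields both by name and by position.
assert out.primal is out[0]
assert out.tangent is out[1]
```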
