
Commit 09a1b1c

albanD authored and facebook-github-bot committed
Forward AD formulas batch 1 (#57768)
Summary: Pull Request resolved: #57768

Note that this PR implements formulas only for ops that are supported by OpInfo.

Test Plan: Imported from OSS

Reviewed By: zou3519, malfet

Differential Revision: D28387766

Pulled By: albanD

fbshipit-source-id: b4ba1cf1ac1dfd46cdd889385c9c2d5df3cf7a71
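For context on what these formulas enable: forward-mode AD is driven through the dual-tensor API, so once an op has a `result:` entry in derivatives.yaml its tangent is propagated alongside the primal value. A minimal sketch of hitting one of the newly covered ops (the values and the choice of op are illustrative, assuming a PyTorch build that ships torch.autograd.forward_ad):

import torch
import torch.autograd.forward_ad as fwAD

# Keep the primal inside (-1, 1) so acos is differentiable there.
x = torch.empty(3, dtype=torch.double).uniform_(-0.9, 0.9)
t = torch.randn(3, dtype=torch.double)  # tangent (direction of differentiation)

with fwAD.dual_level():
    dual_x = fwAD.make_dual(x, t)   # pack primal and tangent together
    dual_y = torch.acos(dual_x)     # uses the new forward formula (result: auto_element_wise)
    y, y_t = fwAD.unpack_dual(dual_y)

# The propagated tangent is the JVP: d/dx acos(x) = -1 / sqrt(1 - x^2), applied to t.
assert torch.allclose(y_t, -t / torch.sqrt(1 - x * x))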
1 parent b4f3a98 commit 09a1b1c

4 files changed: 48 additions and 8 deletions

tools/autograd/derivatives.yaml

Lines changed: 16 additions & 0 deletions
@@ -182,6 +182,7 @@
 
 - name: acos(Tensor self) -> Tensor
   self: grad * -((-self * self + 1).rsqrt()).conj()
+  result: auto_element_wise
 
 - name: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
   self: handle_r_to_c(self.scalar_type(), grad)
@@ -190,26 +191,31 @@
 
 - name: add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
   self: handle_r_to_c(self.scalar_type(), grad)
+  result: self_t
 
 - name: addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
   self: maybe_multiply(grad, beta.conj())
   batch1: grad.unsqueeze(0).expand({ batch1.size(0), batch1.size(1), batch2.size(2) }).bmm(batch2.transpose(1, 2).conj()) * alpha.conj()
   batch2: batch1.transpose(1, 2).conj().bmm(grad.unsqueeze(0).expand({ batch1.size(0), batch1.size(1), batch2.size(2) })) * alpha.conj()
+  result: maybe_multiply(self_t, beta) + maybe_multiply(batch1_t.bmm(batch2_p).sum(0), alpha) + maybe_multiply(batch1_p.bmm(batch2_t).sum(0), alpha)
 
 - name: addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor
   self: handle_r_to_c(self.scalar_type(), grad)
   tensor1: handle_r_to_c(tensor1.scalar_type(), grad * (value / tensor2).conj())
   tensor2: handle_r_to_c(tensor2.scalar_type(), -grad * (value * tensor1 / (tensor2 * tensor2)).conj())
+  result: self_t + maybe_multiply(tensor1_t / tensor2_p, value) - maybe_multiply(tensor2_t * (tensor1_p / tensor2_p) / tensor2_p, value)
 
 - name: addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor
   self: handle_r_to_c(self.scalar_type(), grad)
   tensor1: handle_r_to_c(tensor1.scalar_type(), grad * (tensor2 * value).conj())
   tensor2: handle_r_to_c(tensor2.scalar_type(), grad * (tensor1 * value).conj())
+  result: self_t + maybe_multiply(tensor1_t * tensor2_p, value) + maybe_multiply(tensor2_t * tensor1_p, value)
 
 - name: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
   self: maybe_multiply(grad, beta.conj())
   mat1: mm_mat1_backward(grad, mat2, mat1.sizes(), mat1.strides(), alpha)
   mat2: mm_mat2_backward(grad, mat1, mat2.sizes(), mat2.strides(), alpha)
+  result: maybe_multiply(self_t, beta) + maybe_multiply(mat1_t.mm(mat2_p), alpha) + maybe_multiply(mat1_p.mm(mat2_t), alpha)
 
 - name: _sparse_addmm(Tensor self, Tensor sparse, Tensor dense, *, Scalar beta=1, Scalar alpha=1) -> Tensor
   self: maybe_multiply(grad, beta)
@@ -220,20 +226,24 @@
   self: maybe_multiply(grad, beta.conj())
   mat: grad.ger(vec.conj()) * alpha.conj()
   vec: mat.t().conj().mv(grad) * alpha.conj()
+  result: maybe_multiply(self_t, beta) + maybe_multiply(mat_t.mv(vec_p), alpha) + maybe_multiply(mat_p.mv(vec_t), alpha)
 
 - name: addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
   self: maybe_multiply(grad, beta.conj())
   vec1: grad.mv(vec2.conj()) * alpha.conj()
   vec2: grad.t().mv(vec1.conj()) * alpha.conj()
+  result: maybe_multiply(self_t, beta) + maybe_multiply(vec1_t.outer(vec2_p), alpha) + maybe_multiply(vec1_p.outer(vec2_t), alpha)
 
 - name: affine_grid_generator(Tensor theta, int[] size, bool align_corners) -> Tensor
   theta: affine_grid_generator_backward(grad, size, align_corners)
 
 - name: alias(Tensor(a) self) -> Tensor(a)
   self: grad
+  result: self_t
 
 - name: angle(Tensor self) -> Tensor
   self: angle_backward(grad, self)
+  result: handle_r_to_c(result.scalar_type(), angle_backward(self_t, self_p))
 
 # The four items below are necessary because TensorIterator doesn't work on
 # Variables (codegen does not unwrap the input Tensor for all() and any() ).
@@ -251,18 +261,21 @@
 
 - name: acosh(Tensor self) -> Tensor
   self: grad * (self.pow(2) - 1).rsqrt().conj()
+  result: auto_element_wise
 
 - name: acosh_(Tensor(a!) self) -> Tensor(a!)
   self: not_implemented("inplace version of acosh")
 
 - name: asinh(Tensor self) -> Tensor
   self: grad * (self.pow(2) + 1).rsqrt().conj()
+  result: auto_element_wise
 
 - name: asinh_(Tensor(a!) self) -> Tensor(a!)
   self: not_implemented("inplace version of asinh")
 
 - name: atanh(Tensor self) -> Tensor
   self: grad * 1 / (1 - self.pow(2)).conj()
+  result: auto_element_wise
 
 - name: atanh_(Tensor(a!) self) -> Tensor(a!)
   self: not_implemented("inplace version of atanh")
@@ -272,9 +285,11 @@
 
 - name: asin(Tensor self) -> Tensor
   self: grad * (-self * self + 1).rsqrt().conj()
+  result: auto_element_wise
 
 - name: atan(Tensor self) -> Tensor
   self: grad / (self * self + 1).conj()
+  result: auto_element_wise
 
 - name: atan2(Tensor self, Tensor other) -> Tensor
   self, other: atan2_backward(grad, self, other, grad_input_mask)
@@ -362,6 +377,7 @@
 
 - name: _conj(Tensor self) -> Tensor
   self: grad.conj()
+  result: self_t.conj()
 
 - name: copysign.Tensor(Tensor self, Tensor other) -> Tensor
   self: copysign_tensor_self_backward(grad, self, result)
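The `result:` lines above are Jacobian-vector products written in terms of primals (`*_p`) and tangents (`*_t`). For addmm, which computes beta * self + alpha * (mat1 @ mat2), the product rule gives beta * self_t + alpha * (mat1_t @ mat2_p + mat1_p @ mat2_t), i.e. exactly the formula in the diff. A quick numerical cross-check of that identity (illustrative values, not part of this PR):

import torch

beta, alpha = 0.5, 2.0
self_p = torch.randn(3, 4, dtype=torch.double)
mat1_p = torch.randn(3, 5, dtype=torch.double)
mat2_p = torch.randn(5, 4, dtype=torch.double)
# One tangent per differentiable input.
self_t, mat1_t, mat2_t = (torch.randn_like(p) for p in (self_p, mat1_p, mat2_p))

def f(s, m1, m2):
    return torch.addmm(s, m1, m2, beta=beta, alpha=alpha)

# JVP computed by autograd (via the double-backward trick).
_, jvp = torch.autograd.functional.jvp(f, (self_p, mat1_p, mat2_p), (self_t, mat1_t, mat2_t))

# The same JVP spelled out as in derivatives.yaml.
expected = beta * self_t + alpha * (mat1_t.mm(mat2_p) + mat1_p.mm(mat2_t))
assert torch.allclose(jvp, expected)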

tools/autograd/gen_variable_type.py

Lines changed: 4 additions & 4 deletions
@@ -752,10 +752,10 @@ def emit_fw_derivatives() -> List[str]:
            # Handle functions like stack
            # For these, we don't unpack anything and always call the user function
            if not (len(differentiable_inputs) == 1 and is_tensor_list_type(differentiable_inputs[0].type)):
-               raise RuntimeError(f'No differentiable input to "{name}" is a differentiable Tensor even though a '
-                                  'forward gradient formula has been defined for it. This case should only happen '
-                                  'for function that take a single TensorList as input. All other cases are not '
-                                  'supported right now.')
+               raise RuntimeError(f'No differentiable input to "{name}" is a differentiable Tensor (as the provided '
+                                  'forward AD formula does not use any input tangent) even though a forward gradient '
+                                  'formula has been defined for it. This case should only happen for functions that '
+                                  'take a single TensorList as input. All other cases are not supported right now.')
            requires_fw_grad = "true"
            unpacked_arguments = ""
            for inp in differentiable_inputs:

tools/autograd/load_derivatives.py

Lines changed: 9 additions & 4 deletions
@@ -170,21 +170,23 @@ def find_required_inputs(formula: str, postfix: str) -> Tuple[str, ...]:
                               "forward definition of gradient as element_wise but it does not "
                               "defines the gradient formula for its argument which is required.")
        # This transformation is based on the observation that for element-wise functions, the Jacobian
-       # matrix is diagonal and thus doing J * v or v * J gives the same result.
+       # matrix is diagonal and thus doing J * v is the same as (v^T J)^T (in practice, we ignore the transpositions)
+       # For the complex case, we use hermitian transpose and get (v.conj() J).conj()
        # So here we are going to re-use the backward formula and replace two things:
-       # 1) all occurrences of "grad" with "foo_t", where foo is the name of the unique differentiable input.
+       # 1) all occurrences of "grad" with "foo_t.conj()", where foo is the name of the unique differentiable input.
        # 2) all usage of an original input "foo" with its primal value "foo_p".
+       # 3) conjugate the final result
        # For example, for abs, the backward formula is:
        #   grad * self.sgn()
        # And this function generates a forward formula that is:
-       #   self_t * self_p.sgn()
+       #   (self_t.conj() * self_p.sgn()).conj()
 
        backward_formula = derivatives[0].original_formula
        input_name = args_with_derivatives[0].name
 
        # Do replacement 1) of the grad
        def repl(m: Any) -> str:
-           return f"{m.group(1)}{input_name}_t{m.group(2)}"
+           return f"{m.group(1)}{input_name}_t.conj(){m.group(2)}"
        fw_formula = re.sub(IDENT_REGEX.format("grad"), repl, backward_formula)
 
        # Do replacement 2) of the input variables
@@ -195,6 +197,9 @@ def repl(m: Any) -> str:
            return f"{m.group(1)}{arg_name}_p{m.group(2)}"
        fw_formula = re.sub(IDENT_REGEX.format(arg_name), repl, fw_formula)
 
+       # Do the final conjugate 3)
+       fw_formula = f"({fw_formula}).conj()"
+
        # Since there is a single differentiable inputs and we necessarily need its tangent we can
        # simply require all differentiable input's tangent.
        required_inputs_tangent = tuple(all_arg_names)

torch/testing/_internal/common_methods_invocations.py

Lines changed: 19 additions & 0 deletions
@@ -3888,6 +3888,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
                    # "rsqrt_cpu" not implemented for 'BFloat16'
                    backward_dtypesIfCPU=all_types_and_complex_and(torch.bool),
                    assert_autodiffed=True,
+                   supports_forward_ad=True,
                    decorators=(precisionOverride({torch.float16: 1e-2,
                                                   torch.bfloat16: 1e-1,
                                                   torch.complex64: 1e-2}),),
@@ -3916,6 +3917,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
                    safe_casts_outputs=True,
                    decorators=(precisionOverride({torch.bfloat16: 5e-2}),),
                    supports_inplace_autograd=False,
+                   supports_forward_ad=True,
                    skips=(
                        SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_extremal',
                                 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
@@ -3966,6 +3968,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
            dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []),
            assert_autodiffed=True,
            supports_inplace_autograd=False,
+           supports_forward_ad=True,
            gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
            sample_inputs_func=sample_inputs_addmm),
     OpInfo('addmm',
@@ -3977,6 +3980,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
            dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []),
            assert_autodiffed=True,
            supports_inplace_autograd=False,
+           supports_forward_ad=True,
            gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
            autodiff_nonfusible_nodes=['aten::add', 'aten::mm'],
            sample_inputs_func=partial(sample_inputs_addmm, alpha=1, beta=1)),
@@ -3987,6 +3991,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
                                             *[torch.bfloat16] if CUDA11OrLater else []),
            dtypesIfROCM=floating_types_and(torch.half),
            supports_inplace_autograd=False,
+           supports_forward_ad=True,
            skips=(
                # issue may fix: https://github.com/pytorch/pytorch/issues/55589
                # AssertionError: UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
@@ -4000,6 +4005,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
            dtypesIfCPU=all_types_and_complex_and(torch.float16, torch.bfloat16),
            dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []),
            dtypesIfROCM=floating_types_and(torch.half),
+           supports_forward_ad=True,
            skips=(
                # addbmm does not correctly warn when resizing out= inputs
                SkipInfo('TestCommon', 'test_out'),
@@ -4065,6 +4071,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
            backward_dtypesIfCUDA=all_types_and_complex_and(torch.bool),
            # Reference: https://github.com/pytorch/pytorch/issues/50747
            supports_inplace_autograd=False,
+           supports_forward_ad=True,
            skips=(
                # Reference: https://github.com/pytorch/pytorch/issues/50747
                SkipInfo('TestCommon', 'test_variant_consistency_eager',
@@ -4075,6 +4082,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
            dtypes=all_types_and_complex(),
            dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
            assert_autodiffed=True,
+           supports_forward_ad=True,
            supports_inplace_autograd=False,
            skips=(
                # TODO: update sample inputs with for_inplace_variant kwarg to support this test
@@ -4084,6 +4092,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
            dtypes=floating_and_complex_types(),
            dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
            supports_inplace_autograd=False,
+           supports_forward_ad=True,
            skips=(
                # TODO: update sample inputs with for_inplace_variant kwarg to support this test
                SkipInfo('TestCommon', 'test_variant_consistency_eager'),),
@@ -4107,6 +4116,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
                    ref=np.arcsin,
                    domain=(-1, 1),
                    supports_sparse=True,
+                   supports_forward_ad=True,
                    decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
                    safe_casts_outputs=True,
                    dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
@@ -4137,6 +4147,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
                    safe_casts_outputs=True,
                    decorators=(precisionOverride({torch.bfloat16: 5e-2}),),
                    supports_inplace_autograd=False,
+                   supports_forward_ad=True,
                    skips=(
                        SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_extremal',
                                 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
@@ -4150,13 +4161,18 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
                        SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_hard',
                                 device_type='cuda', dtypes=[torch.cdouble],
                                 active_if=IS_WINDOWS),
+                       # Complex gradcheck tests asinh at points 0 + ix for x > 1 which are points
+                       # where asinh is not differentiable
+                       SkipInfo('TestGradients', 'test_forward_mode_AD',
+                                dtypes=complex_types())
                    )),
     UnaryUfuncInfo('atan',
                    aliases=('arctan', ),
                    ref=np.arctan,
                    dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
                    dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    assert_autodiffed=True,
+                   supports_forward_ad=True,
                    decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
                    safe_casts_outputs=True,
                    skips=(
@@ -4191,6 +4207,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
                    safe_casts_outputs=True,
                    decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
                    supports_inplace_autograd=False,
+                   supports_forward_ad=True,
                    skips=(
                        SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_extremal',
                                 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
@@ -4309,6 +4326,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
                    ref=np.conj,
                    dtypes=all_types_and_complex_and(torch.bool,
                                                     torch.bfloat16, torch.half),
+                   supports_forward_ad=True,
                    skips=(
                        # File "test_unary_ufuncs.py", line 289, in test_reference_numerics
                        #  if not torch.can_cast(numpy_to_torch_dtype_dict[expected.dtype.type], dtype):
@@ -5750,6 +5768,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
                    decorators=(precisionOverride({torch.float16: 1e-2,
                                                   torch.bfloat16: 1e-2}),),
                    safe_casts_outputs=True,
+                   supports_forward_ad=True,
                    supports_complex_to_float=True),
     OpInfo('linalg.solve',
            aten_name='linalg_solve',
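The supports_forward_ad=True flags opt these OpInfos into the forward-mode gradient checks (the TestGradients test_forward_mode_AD test referenced in the asinh skip above). Conceptually that check boils down to running gradcheck with forward AD enabled; a hand-rolled version for a single op might look like the sketch below (illustrative, and it assumes a PyTorch build where gradcheck accepts the check_forward_ad keyword):

import torch
from torch.autograd import gradcheck

# asin is only differentiable on (-1, 1), so sample the primal from that interval.
x = torch.empty(4, dtype=torch.double).uniform_(-0.9, 0.9).requires_grad_()

# Compares the forward-mode formula (result: auto_element_wise for asin)
# against numerically computed Jacobian-vector products.
assert gradcheck(torch.asin, (x,), check_forward_ad=True)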
