
Commit bffb67a

lezcano authored and zou3519 committed
Make svd / svdvals fully functorch compatible (#72181)
Summary: This should (hopefully) make all the CI from `functorch` go green (including jvp's!) after replacing `VARIADIC_BDIMS_BOXED(_svd_helper);` with `VARIADIC_BDIMS_BOXED(_linalg_svd);` and removing all the skips and xfails associated with `linalg.svdvals`. Locally, just one test started failing because of this: `test_vmapjvpall_norm_nuc_cpu_float32`. I have no idea what's going on there, but it's a jvp test, so not a regression, and it may well be caused by the jvp of another operation within `norm_nuc`, as this is a composite operation.

Pull Request resolved: #72181
Reviewed By: ngimel
Differential Revision: D33952744
Pulled By: zou3519
fbshipit-source-id: 2a2510d97eed4a0bfc25615264ddd36e38856efe
(cherry picked from commit 5805fa1)
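For context, the functorch-side change named in the summary swaps which op the variadic-batch-dims rule is registered on. A minimal sketch of that registration follows; the two macro lines come from the summary above, while the surrounding TORCH_LIBRARY_IMPL block and its placement in functorch's batch-rule files are assumptions, not part of this commit's diff:

TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
  // before: the boxed batching rule was registered on the old helper
  // VARIADIC_BDIMS_BOXED(_svd_helper);
  // after: register the same boxed rule on the new composite op instead
  VARIADIC_BDIMS_BOXED(_linalg_svd);
}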
1 parent ef6c1f8 commit bffb67a

2 files changed: 37 additions & 45 deletions

aten/src/ATen/native/BatchLinearAlgebra.cpp

Lines changed: 4 additions & 1 deletion
@@ -10,6 +10,7 @@
 #include <ATen/native/Resize.h>
 #include <ATen/native/cpu/zmath.h>
 #include <ATen/Parallel.h>
+#include <ATen/TensorSubclassLikeUtils.h>

 #include <c10/util/irange.h>

@@ -3062,7 +3063,9 @@ Tensor& linalg_svdvals_out(const Tensor& A, Tensor & S) {
 }

 Tensor linalg_svdvals(const Tensor& A) {
-  const bool A_requires_grad = (at::GradMode::is_enabled() && A.requires_grad());
+  const bool A_requires_grad = (at::GradMode::is_enabled() && A.requires_grad())
+                               || A._fw_grad(/*level */0).defined()
+                               || isTensorSubclassLike(A);
   return std::get<1>(at::_linalg_svd(A, /*full_matrices=*/false, /*compute_uv=*/A_requires_grad));
 }
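The broadened check above computes U and V whenever any differentiation machinery might need them: reverse-mode autograd (GradMode plus requires_grad), a forward-mode tangent (`_fw_grad`), or a tensor subclass such as functorch's batched tensors, for which the plain requires_grad check may not fire. A minimal standalone restatement of the predicate (needs_uv_for_grad is a hypothetical name used only for illustration):

#include <ATen/ATen.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <c10/core/GradMode.h>

// Hypothetical free-function version of the check in linalg_svdvals above:
// U and V must be materialized iff some AD path may consume them.
bool needs_uv_for_grad(const at::Tensor& A) {
  return (at::GradMode::is_enabled() && A.requires_grad())  // reverse mode
      || A._fw_grad(/*level=*/0).defined()                  // forward mode (jvp)
      || at::isTensorSubclassLike(A);                       // e.g. functorch BatchedTensor
}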

torch/csrc/autograd/FunctionsManual.cpp

Lines changed: 33 additions & 44 deletions
@@ -2386,6 +2386,10 @@ std::tuple<Tensor, Tensor, Tensor> linalg_svd_jvp(const Tensor& dA,
   // dX = dP - dS
   // E_{jk} = S_k^2 - S_j^2 if j != k
   //          1 otherwise
+
+  // Checks compute_uv=true
+  TORCH_INTERNAL_ASSERT(U_.dim() >= 2 && Vh_.dim() >= 2);
+
   const auto is_complex = dA.is_complex();
   const auto m = dA.size(-2);
   const auto n = dA.size(-1);
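The assert added here (and its twin in svd_backward below) guards that the decomposition was computed with compute_uv=true, i.e. that U_ and Vh_ are actual matrices rather than empty placeholders — which matters now that linalg_svdvals may call _linalg_svd with compute_uv=false. For reference, the jvp quantities named in the surrounding comments, restated as display math (a transcription of the comments with A = U S V^H, not a new derivation):

$$ dP = U^H \, dA \, V, \qquad dS = \operatorname{Re}\bigl(\operatorname{diag}(dP)\bigr), \qquad dX = dP - dS, \qquad E_{jk} = \begin{cases} S_k^2 - S_j^2 & j \neq k \\ 1 & j = k \end{cases} $$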
@@ -2396,53 +2400,47 @@
   const auto V = Vh.mH();

   // dP = U^H dA V
-  const auto dP = m >= n ? at::matmul(U.mH(), at::matmul(dA, V))
-                         : at::matmul(at::matmul(U.mH(), dA), V);
+  auto dP = m >= n ? at::matmul(U.mH(), at::matmul(dA, V))
+                   : at::matmul(at::matmul(U.mH(), dA), V);

-  // We don't want dS to be a view into a larger tensor so we clone it
-  auto dS = is_complex ? at::real(dP.diagonal(0, -2, -1)).clone()
-                       : dP.diagonal(0, -2, -1).clone();
+  auto dS = is_complex ? at::real(dP.diagonal(0, -2, -1))
+                       : dP.diagonal(0, -2, -1);

   // dX = dP - dS
-  // Here we update dP in-place rather than
-  if (is_complex) {
-    at::real(dP.diagonal(0, -2, -1)).zero_();
-  } else {
-    dP.diagonal(0, -2, -1).zero_();
-  }
+  dP = dP - dS.diag_embed();

   auto E = [&S]{
     const auto S2 = S * S;
     auto ret = S2.unsqueeze(-2) - S2.unsqueeze(-1);
-    // Any number a != 0 would, as we are going to compute 0 / a
+    // Any number a != 0 would, as we are just going to use it to compute 0 / a later on
     ret.diagonal(0, -2, -1).fill_(1);
     return ret;
   }();

   const auto sym = [](const Tensor& X) { return X + X.mH(); };

-  // Im(diag(dP)) / (2S)
-  auto ImdiagdP2S = is_complex ? at::imag(dP.diagonal(0, -2, -1)).div(2. * S)
-                               : Tensor{};
+  // diag(dP) / (2S)
+  auto diagdP2S = is_complex ? dP.diagonal(0, -2, -1).div(2. * S)
+                             : Tensor{};

   // dU = U (sym(dP S) / E) + i Im(diag(dP)) / (2S)
   auto dU = [&] {
     auto dUaux = sym(dP * S.unsqueeze(-2)) / E;
     if (is_complex) {
-      at::imag(dUaux.diagonal(0, -2, -1)).copy_(ImdiagdP2S);
+      dUaux = dUaux + diagdP2S.diag_embed();
     }
     return at::matmul(U, dUaux);
   }();
   if (m > n) {
     // dU += (I_m - UU^H) dA V S^{-1}
     const auto dAVSinv = at::matmul(dA, V / S.unsqueeze(-2));
-    dU += dAVSinv - at::matmul(U, at::matmul(U.mH(), dAVSinv));
+    dU = dU + dAVSinv - at::matmul(U, at::matmul(U.mH(), dAVSinv));

     // To "fix" the full_matrices case (the full_matrices case should not be differentiable...)
     if (full_matrices) {
       auto shape = dU.sizes().vec();
       shape.end()[-1] = m - n;
-      dU = at::cat({dU, at::zeros(shape, dU.options())}, /*dim=*/-1);
+      dU = at::cat({dU, dU.new_zeros(shape)}, /*dim=*/-1);
     }
   }
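The theme of the hunk above (and the one below for dVh) is replacing in-place mutation of views — `.zero_()`, `.copy_()`, `+=` — and the defensive `.clone()` with out-of-place arithmetic, which is what lets the jvp compose with functorch's vmap: batching rules cannot see writes that go through a view of a batched tensor. Likewise, at::zeros(shape, options) becomes dU.new_zeros(shape), which dispatches on dU so tensor subclasses can intercept the allocation. Note that once dP = dP - dS.diag_embed() has removed the real part of the diagonal, diag(dP) is purely imaginary, so the new diagdP2S = diag(dP) / (2S) carries the same data as the old ImdiagdP2S — added via diag_embed instead of copied into at::imag. A minimal sketch of the out-of-place idiom (zero_diagonal is a hypothetical helper, not from the commit):

#include <ATen/ATen.h>

// Out-of-place equivalent of x.diagonal(0, -2, -1).zero_(): build a new
// tensor rather than writing through a view, so vmap and forward AD compose.
at::Tensor zero_diagonal(const at::Tensor& x) {
  const auto diag = x.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1);
  return x - diag.diag_embed(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1);
}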

@@ -2451,20 +2449,20 @@
   auto dVh = [&] {
     auto dVhaux = sym(dP * (-S).unsqueeze(-1)) / E;
     if (is_complex) {
-      at::imag(dVhaux.diagonal(0, -2, -1)).copy_(ImdiagdP2S);
+      dVhaux = dVhaux + diagdP2S.diag_embed();
     }
     return at::matmul(dVhaux, Vh);
   }();
   if (m < n) {
     // dVh += S^{-1} U^H dA (I_n - VV^H)
     const auto UHdASinv = at::matmul(U.mH() / S.unsqueeze(-1), dA);
-    dVh += UHdASinv - at::matmul(at::matmul(UHdASinv, V), Vh);
+    dVh = dVh + UHdASinv - at::matmul(at::matmul(UHdASinv, V), Vh);

     // To "fix" the full_matrices case (the full_matrices case should not be differentiable...)
     if (full_matrices) {
       auto shape = dVh.sizes().vec();
       shape.end()[-2] = n - m;
-      dVh = at::cat({dVh, at::zeros(shape, dVh.options())}, /*dim=*/-2);
+      dVh = at::cat({dVh, dVh.new_zeros(shape)}, /*dim=*/-2);
     }
   }

@@ -2589,6 +2587,9 @@ Tensor svd_backward(const Tensor& gU,
   // where we have used that Im(diag(U^H gU)) = - Im(diag(V^h gV)) to group the diagonal imaginary terms into one
   // that just depends on U^H gU.

+  // Checks compute_uv=true
+  TORCH_INTERNAL_ASSERT(U.dim() >= 2 && Vh.dim() >= 2);
+
   // Trivial case
   if (!gS.defined() && !gU.defined() && !gVh.defined()) {
     return {};
@@ -2605,8 +2606,9 @@
   // At this point, at least one of gU, gVh is defined

   const bool is_complex = U.is_complex();
-  const auto UhgU = gU.defined() ? at::matmul(U.mH(), gU) : Tensor{};
-  const auto VhgV = gVh.defined() ? at::matmul(Vh, gVh.mH()) : Tensor{};
+  const auto skew = [](const Tensor& A) { return A - A.mH(); };
+  const auto UhgU = gU.defined() ? skew(at::matmul(U.mH(), gU)) : Tensor{};
+  const auto VhgV = gVh.defined() ? skew(at::matmul(Vh, gVh.mH())) : Tensor{};

   // Check for the invariance of the loss function, i.e.
   // Im(diag(U^H gU)) + Im(diag(V^H gV)) = 0
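Hoisting skew into the definitions of UhgU and VhgV is what allows the diagonal term of gA to change form below. A quick check (an editor's derivation, not from the commit) that the old and new comments agree: for M = U^H gU,

$$ \operatorname{diag}\bigl(\operatorname{skew}(M)\bigr) = \operatorname{diag}(M - M^H) = 2i \operatorname{Im}\bigl(\operatorname{diag}(M)\bigr), \qquad \text{so} \qquad \frac{\operatorname{diag}(\operatorname{skew}(M))}{2S} = \frac{i \operatorname{Im}(\operatorname{diag}(M))}{S}, $$

which is exactly the i Im(diag(U^H gU)) / S term in the old comment.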
@@ -2622,12 +2624,10 @@
              "it ill-defined.");
   }

-  // gA = (skew(U^H gU) / E) S + S ((skew(V^H gV) / E) + I o (gS + i Im(diag(U^H gU)) / S)
+  // gA = ((U^H gU) / E) S + S (((V^H gV) / E) + I o (gS + diag(U^H gU) / (2 * S))
   Tensor gA = [&] {
     // ret holds everything but the diagonal of gA
     auto ret = [&] {
-      const auto skew = [](const Tensor& A) { return A - A.mH(); };
-
       const auto E = [&S]{
         const auto S2 = S * S;
         auto ret = S2.unsqueeze(-2) - S2.unsqueeze(-1);
@@ -2638,31 +2638,20 @@

       if (gU.defined()) {
         if (gVh.defined()) {
-          return (skew(UhgU) * S.unsqueeze(-2) + S.unsqueeze(-1) * skew(VhgV)) / E;
+          return (UhgU * S.unsqueeze(-2) + S.unsqueeze(-1) * VhgV) / E;
         } else {
-          return (skew(UhgU) / E) * S.unsqueeze(-2);
+          return (UhgU / E) * S.unsqueeze(-2);
         }
       } else { // gVh.defined();
-        return S.unsqueeze(-1) * (skew(VhgV) / E);
+        return S.unsqueeze(-1) * (VhgV / E);
       }
     }();
     // Fill the diagonal
-    const auto diag = ret.diagonal(0, -2, -1);
+    if (gS.defined()) {
+      ret = ret + gS.diag_embed();
+    }
     if (is_complex && gU.defined() && gVh.defined()) {
-      if (gS.defined()) {
-        at::real(diag).copy_(gS);
-      } else {
-        // Not strictly necessary, but we do it for good measure
-        at::real(diag).zero_();
-      }
-      at::imag(diag).copy_(at::imag(UhgU.diagonal(0, -2, -1)) / S);
-    } else {
-      if (gS.defined()) {
-        diag.copy_(gS);
-      } else {
-        // Not strictly necessary, but we do it for good measure
-        diag.zero_();
-      }
+      ret = ret + (UhgU.diagonal(0, -2, -1) / (2. * S)).diag_embed();
     }
     return ret;
   }();
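The inline gA comments above have unbalanced parentheses in both the old and new versions; restated with the grouping inferred from the code (an editor's reading, not the commit's text), and with skew(X) = X - X^H already folded into UhgU and VhgV:

$$ gA = \frac{\operatorname{skew}(U^H gU)}{E} \, S + S \, \frac{\operatorname{skew}(V^H gV)}{E} + I \circ \left( gS + \frac{\operatorname{diag}(\operatorname{skew}(U^H gU))}{2S} \right) $$

The same out-of-place rewrite as in the jvp applies here: the diagonal is now assembled with diag_embed additions rather than copy_/zero_ into a view, so the backward also composes with vmap.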
