@@ -2386,6 +2386,10 @@ std::tuple<Tensor, Tensor, Tensor> linalg_svd_jvp(const Tensor& dA,
23862386 // dX = dP - dS
23872387 // E_{jk} = S_k^2 - S_j^2 if j != k
23882388 // 1 otherwise
2389+
2390+ // Checks compute_uv=true
2391+ TORCH_INTERNAL_ASSERT (U_.dim () >= 2 && Vh_.dim () >= 2 );
2392+
23892393 const auto is_complex = dA.is_complex ();
23902394 const auto m = dA.size (-2 );
23912395 const auto n = dA.size (-1 );
@@ -2396,53 +2400,47 @@ std::tuple<Tensor, Tensor, Tensor> linalg_svd_jvp(const Tensor& dA,
23962400 const auto V = Vh.mH ();
23972401
23982402 // dP = U^H dA V
2399- const auto dP = m >= n ? at::matmul (U.mH (), at::matmul (dA, V))
2400- : at::matmul (at::matmul (U.mH (), dA), V);
2403+ auto dP = m >= n ? at::matmul (U.mH (), at::matmul (dA, V))
2404+ : at::matmul (at::matmul (U.mH (), dA), V);
24012405
2402- // We don't want dS to be a view into a larger tensor so we clone it
2403- auto dS = is_complex ? at::real (dP.diagonal (0 , -2 , -1 )).clone ()
2404- : dP.diagonal (0 , -2 , -1 ).clone ();
2406+ auto dS = is_complex ? at::real (dP.diagonal (0 , -2 , -1 ))
2407+ : dP.diagonal (0 , -2 , -1 );
24052408
24062409 // dX = dP - dS
2407- // Here we update dP in-place rather than
2408- if (is_complex) {
2409- at::real (dP.diagonal (0 , -2 , -1 )).zero_ ();
2410- } else {
2411- dP.diagonal (0 , -2 , -1 ).zero_ ();
2412- }
2410+ dP = dP - dS.diag_embed ();
24132411
24142412 auto E = [&S]{
24152413 const auto S2 = S * S;
24162414 auto ret = S2.unsqueeze (-2 ) - S2.unsqueeze (-1 );
2417- // Any number a != 0 would, as we are going to compute 0 / a
2415+ // Any number a != 0 would do, as we are just going to use it to compute 0 / a later on
24182416 ret.diagonal (0 , -2 , -1 ).fill_ (1 );
24192417 return ret;
24202418 }();
24212419
24222420 const auto sym = [](const Tensor& X) { return X + X.mH (); };
24232421
2424- // Im( diag(dP) ) / (2S)
2425- auto ImdiagdP2S = is_complex ? at::imag ( dP.diagonal (0 , -2 , -1 ) ).div (2 . * S)
2426- : Tensor{};
2422+ // diag(dP) / (2S)
2423+ auto diagdP2S = is_complex ? dP.diagonal (0 , -2 , -1 ).div (2 . * S)
2424+ : Tensor{};
24272425
24282426 // dU = U (sym(dP S) / E) + i Im(diag(dP)) / (2S)
24292427 auto dU = [&] {
24302428 auto dUaux = sym (dP * S.unsqueeze (-2 )) / E;
24312429 if (is_complex) {
2432- at::imag ( dUaux. diagonal ( 0 , - 2 , - 1 )). copy_ (ImdiagdP2S );
2430+ dUaux = dUaux + diagdP2S. diag_embed ( );
24332431 }
24342432 return at::matmul (U, dUaux);
24352433 }();
24362434 if (m > n) {
24372435 // dU += (I_m - UU^H) dA V S^{-1}
24382436 const auto dAVSinv = at::matmul (dA, V / S.unsqueeze (-2 )) ;
2439- dU += dAVSinv - at::matmul (U, at::matmul (U.mH (), dAVSinv));
2437+ dU = dU + dAVSinv - at::matmul (U, at::matmul (U.mH (), dAVSinv));
24402438
24412439 // To "fix" the full_matrices case (the full_matrices case should not be differentiable...)
24422440 if (full_matrices) {
24432441 auto shape = dU.sizes ().vec ();
24442442 shape.end ()[-1 ] = m - n;
2445- dU = at::cat ({dU, at::zeros (shape, dU.options () )}, /* dim=*/ -1 );
2443+ dU = at::cat ({dU, dU.new_zeros (shape )}, /* dim=*/ -1 );
24462444 }
24472445 }
24482446
@@ -2451,20 +2449,20 @@ std::tuple<Tensor, Tensor, Tensor> linalg_svd_jvp(const Tensor& dA,
24512449 auto dVh = [&] {
24522450 auto dVhaux = sym (dP * (-S).unsqueeze (-1 )) / E;
24532451 if (is_complex) {
2454- at::imag ( dVhaux. diagonal ( 0 , - 2 , - 1 )). copy_ (ImdiagdP2S );
2452+ dVhaux = dVhaux + diagdP2S. diag_embed ( );
24552453 }
24562454 return at::matmul (dVhaux, Vh);
24572455 }();
24582456 if (m < n) {
24592457 // dVh += S^{-1} U^H dA (I_n - VV^H)
24602458 const auto UHdASinv = at::matmul (U.mH () / S.unsqueeze (-1 ), dA) ;
2461- dVh += UHdASinv - at::matmul (at::matmul (UHdASinv, V), Vh);
2459+ dVh = dVh + UHdASinv - at::matmul (at::matmul (UHdASinv, V), Vh);
24622460
24632461 // To "fix" the full_matrices case (the full_matrices case should not be differentiable...)
24642462 if (full_matrices) {
24652463 auto shape = dVh.sizes ().vec ();
24662464 shape.end ()[-2 ] = n - m;
2467- dVh = at::cat ({dVh, at::zeros (shape, dVh.options () )}, /* dim=*/ -2 );
2465+ dVh = at::cat ({dVh, dVh.new_zeros (shape )}, /* dim=*/ -2 );
24682466 }
24692467 }
24702468
@@ -2589,6 +2587,9 @@ Tensor svd_backward(const Tensor& gU,
25892587 // where we have used that Im(diag(U^H gU)) = - Im(diag(V^H gV)) to group the diagonal imaginary terms into one
25902588 // that just depends on U^H gU.
25912589
2590+ // Checks compute_uv=true
2591+ TORCH_INTERNAL_ASSERT (U.dim () >= 2 && Vh.dim () >= 2 );
2592+
25922593 // Trivial case
25932594 if (!gS .defined () && !gU .defined () && !gVh .defined ()) {
25942595 return {};
@@ -2605,8 +2606,9 @@ Tensor svd_backward(const Tensor& gU,
26052606 // At this point, at least one of gU, gVh is defined
26062607
26072608 const bool is_complex = U.is_complex ();
2608- const auto UhgU = gU .defined () ? at::matmul (U.mH (), gU ) : Tensor{};
2609- const auto VhgV = gVh .defined () ? at::matmul (Vh, gVh .mH ()) : Tensor{};
2609+ const auto skew = [](const Tensor& A) { return A - A.mH (); };
2610+ const auto UhgU = gU .defined () ? skew (at::matmul (U.mH (), gU )) : Tensor{};
2611+ const auto VhgV = gVh .defined () ? skew (at::matmul (Vh, gVh .mH ())) : Tensor{};
26102612
26112613 // Check for the invariance of the loss function, i.e.
26122614 // Im(diag(U^H gU)) + Im(diag(V^H gV)) = 0
@@ -2622,12 +2624,10 @@ Tensor svd_backward(const Tensor& gU,
26222624 " it ill-defined." );
26232625 }
26242626
2625- // gA = (skew (U^H gU) / E) S + S ((skew (V^H gV) / E) + I o (gS + i Im( diag(U^H gU)) / S )
2627+ // gA = (skew(U^H gU) / E) S + S (skew(V^H gV) / E) + I o (gS + diag(skew(U^H gU)) / (2 * S))
26262628 Tensor gA = [&] {
26272629 // ret holds everything but the diagonal of gA
26282630 auto ret = [&] {
2629- const auto skew = [](const Tensor& A) { return A - A.mH (); };
2630-
26312631 const auto E = [&S]{
26322632 const auto S2 = S * S;
26332633 auto ret = S2.unsqueeze (-2 ) - S2.unsqueeze (-1 );
@@ -2638,31 +2638,20 @@ Tensor svd_backward(const Tensor& gU,
26382638
26392639 if (gU .defined ()) {
26402640 if (gVh .defined ()) {
2641- return (skew ( UhgU) * S.unsqueeze (-2 ) + S.unsqueeze (-1 ) * skew ( VhgV) ) / E;
2641+ return (UhgU * S.unsqueeze (-2 ) + S.unsqueeze (-1 ) * VhgV) / E;
26422642 } else {
2643- return (skew ( UhgU) / E) * S.unsqueeze (-2 );
2643+ return (UhgU / E) * S.unsqueeze (-2 );
26442644 }
26452645 } else { // gVh.defined();
2646- return S.unsqueeze (-1 ) * (skew ( VhgV) / E);
2646+ return S.unsqueeze (-1 ) * (VhgV / E);
26472647 }
26482648 }();
26492649 // Fill the diagonal
2650- const auto diag = ret.diagonal (0 , -2 , -1 );
2650+ if (gS .defined ()) {
2651+ ret = ret + gS .diag_embed ();
2652+ }
26512653 if (is_complex && gU .defined () && gVh .defined ()) {
2652- if (gS .defined ()) {
2653- at::real (diag).copy_ (gS );
2654- } else {
2655- // Not strictly necessary, but we do it for good measure
2656- at::real (diag).zero_ ();
2657- }
2658- at::imag (diag).copy_ (at::imag (UhgU.diagonal (0 , -2 , -1 )) / S);
2659- } else {
2660- if (gS .defined ()) {
2661- diag.copy_ (gS );
2662- } else {
2663- // Not strictly necessary, but we do it for good measure
2664- diag.zero_ ();
2665- }
2654+ ret = ret + (UhgU.diagonal (0 , -2 , -1 ) / (2 . * S)).diag_embed ();
26662655 }
26672656 return ret;
26682657 }();
0 commit comments