pinv could be differentiable on a wider range of inputs #65911

@nikitaved

🐛 Bug

The paper
G. H. Golub and V. Pereyra, The Differentiation of Pseudo-Inverses and Nonlinear Least Squares Problems Whose Variables Separate, SIAM Journal on Numerical Analysis, Vol. 10, No. 2 (Apr. 1973), pp. 413-432
states that pinv is Fréchet-differentiable in a rank-preserving neighborhood.
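
In the real case and for a rank-preserving perturbation dA, the differential derived there can be written as

d(A^+) = -A^+ dA A^+ + A^+ (A^+)^T dA^T (I - A A^+) + (I - A^+ A) dA^T (A^+)^T A^+,

that is, it only involves A and A^+ themselves, with no reference to an SVD.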

However, since pinv is implemented via the SVD, its implicit SVD-based backward inherits the same limitation: repeated singular values are not allowed.

To Reproduce

Here is an example of a rank-preserving transformation that should be differentiable but is not:

In [1]: import torch

In [2]: x = torch.rand(30, 1, dtype=torch.double, requires_grad=True)

In [3]: y = torch.rand(30, 1, dtype=torch.double, requires_grad=True)                                                                                                                                                                                                                                                       

In [4]: torch.autograd.gradcheck(lambda a, b: torch.linalg.pinv(a @ b.t()), [x, y])
---------------------------------------------------------------------------
GradcheckError                            Traceback (most recent call last)
<ipython-input-4-5e7e12923389> in <module>()
----> 1 torch.autograd.gradcheck(lambda a, b: torch.linalg.pinv(a @ b.t()), [x, y])

~/git/Quansight/pytorch/torch/autograd/gradcheck.py in gradcheck(func, inputs, eps, atol, rtol, raise_exception, check_sparse_nnz, nondet_tol, check_undefined_grad, check_grad_dtypes, check_batched_grad, check_forward_ad, check_backward_ad, fast_mode)                                                                 
   1271             return False
   1272     else:
-> 1273         return _gradcheck_helper(**args)
   1274
   1275

~/git/Quansight/pytorch/torch/autograd/gradcheck.py in _gradcheck_helper(func, inputs, eps, atol, rtol, check_sparse_nnz, nondet_tol, check_undefined_grad, check_grad_dtypes, check_batched_grad, check_forward_ad, check_backward_ad, fast_mode)                                                                          
   1286     _gradcheck_real_imag(gradcheck_fn, func, func_out, tupled_inputs, outputs, eps,
   1287                          rtol, atol, check_grad_dtypes, check_forward_ad=check_forward_ad,
-> 1288                          check_backward_ad=check_backward_ad, nondet_tol=nondet_tol)
   1289     # Short circuit because remaining tests rely on backward AD to be implemented
   1290     if not check_backward_ad:

~/git/Quansight/pytorch/torch/autograd/gradcheck.py in _gradcheck_real_imag(gradcheck_fn, func, func_out, tupled_inputs, outputs, eps, rtol, atol, check_grad_dtypes, check_forward_ad, check_backward_ad, nondet_tol)                                                                                                      
    946         else:
    947             gradcheck_fn(func, func_out, tupled_inputs, outputs, eps,
--> 948                          rtol, atol, check_grad_dtypes, nondet_tol)
    949
    950     if check_forward_ad:

~/git/Quansight/pytorch/torch/autograd/gradcheck.py in _slow_gradcheck(func, func_out, tupled_inputs, outputs, eps, rtol, atol, check_grad_dtypes, nondet_tol, use_forward_ad, complex_indices, test_imag)                                                                                                                  
    994             for j, (a, n) in enumerate(zip(analytical, numerical[i])):
    995                 if not _allclose_with_type_promotion(a, n.to(a.device), rtol, atol):
--> 996                     raise GradcheckError(_get_notallclose_msg(a, n, i, j, complex_indices, test_imag))
    997
    998     return True

GradcheckError: Jacobian mismatch for output 0 with respect to input 0,
numerical:tensor([[ 9.9343e-04, -1.7439e-05, -5.0706e-06,  ..., -4.6110e-05,
         -9.7824e-05, -5.8797e-05],
        [-1.7439e-05,  9.1318e-04, -2.4387e-05,  ..., -2.2177e-04,
         -4.7049e-04, -2.8279e-04],
        [-5.0706e-06, -2.4387e-05,  9.8997e-04,  ..., -6.4482e-05,
         -1.3680e-04, -8.2224e-05],
        ...,
        [-1.2011e-05, -5.7770e-05, -1.6797e-05,  ...,  3.6748e-03,
         -3.2406e-04, -1.9478e-04],
        [-2.5482e-05, -1.2256e-04, -3.5635e-05,  ..., -3.2406e-04,
          3.1401e-03, -4.1322e-04],
        [-1.5316e-05, -7.3665e-05, -2.1419e-05,  ..., -1.9478e-04,
         -4.1322e-04,  3.5792e-03]], dtype=torch.float64)
analytical:tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        ...,
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], dtype=torch.float64)

Additional context

I still have to read the paper in depth to see whether we can fix the backward once and for all.
Even without the paper, though, we could use the SVD in the backward and only keep the leading rank-many singular triplets of the input. We would still suffer from the SVD backward issue (no repeated singular values among those kept), but at least we could handle inputs of arbitrary rank as long as the non-zero singular values are distinct.

Once we have an explicit backward for pinv, it could be reused in lstsq to make it differentiable as well.
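
As a proof of concept, here is a minimal sketch of such an explicit backward, built directly from the Golub-Pereyra differential quoted above. It assumes a single real matrix and rank-preserving perturbations, and the class name below is made up for illustration; since the formula only uses A and pinv(A), it avoids the repeated-singular-value restriction of the SVD-based backward. With it, the gradcheck from the reproduction above passes.

import torch

class PinvGolubPereyra(torch.autograd.Function):
    # Hypothetical sketch: pinv with an explicit backward taken from the
    # Golub-Pereyra differential, valid for real matrices under
    # rank-preserving perturbations; repeated singular values are fine.

    @staticmethod
    def forward(ctx, a):
        pinv = torch.linalg.pinv(a)
        ctx.save_for_backward(a, pinv)
        return pinv

    @staticmethod
    def backward(ctx, grad):
        a, p = ctx.saved_tensors

        def t(m):
            # matrix transpose over the last two dimensions
            return m.transpose(-2, -1)

        i_m = torch.eye(a.shape[-2], dtype=a.dtype, device=a.device)
        i_n = torch.eye(a.shape[-1], dtype=a.dtype, device=a.device)
        # Adjoint (w.r.t. dA) of
        #   d(A^+) = -P dA P + P P^T dA^T (I - A P) + (I - P A) dA^T P^T P,
        # where P = pinv(A):
        return (-t(p) @ grad @ t(p)
                + (i_m - a @ p) @ t(grad) @ p @ t(p)
                + t(p) @ p @ t(grad) @ (i_n - p @ a))

# The reproduction from above, now run against the explicit backward:
x = torch.rand(30, 1, dtype=torch.double, requires_grad=True)
y = torch.rand(30, 1, dtype=torch.double, requires_grad=True)
torch.autograd.gradcheck(lambda a, b: PinvGolubPereyra.apply(a @ b.t()), [x, y])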

cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @lezcano @Varal7 @jianyuh @mruberry @walterddr @IvanYashchuk @xwang233
