
torch.cholesky fails exceedingly slowly on non-PD matrices #34272

@Balandat

Description


🐛 Bug

If a tensor is not positive definite, torch.cholesky takes a very long time to raise an error. In the example below, raising the error takes roughly 80 times longer than actually performing the decomposition of a positive definite matrix of the same size (about 4.6 ms vs. 55 µs), and roughly 140 times longer than NumPy takes to fail on the same input. This happens even if the matrix has strongly negative eigenvalues, or if it is far from symmetric.

This is really problematic, since attempting a Cholesky decomposition is a standard way of determining whether a matrix is positive definite.
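In PyTorch releases newer than this issue (`torch.linalg.cholesky_ex` was added in 1.9, to our knowledge), a positive-definiteness check can read the `info` output of the factorization instead of catching an exception. The sketch below assumes such a release; `is_positive_definite` is a hypothetical helper name, and we make no claim about how its speed compares to the timings reported here.

```python
import torch


def is_positive_definite(A: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: True where the Cholesky factorization of A succeeds.

    Accepts a single matrix or a batch (..., n, n) and returns a bool tensor
    with the batch shape.
    """
    # cholesky_ex does not raise on failure (check_errors defaults to False).
    # info == 0 means success; info == k > 0 means the leading minor of
    # order k is not positive definite.
    _, info = torch.linalg.cholesky_ex(A)
    return info == 0


if __name__ == "__main__":
    n = 100
    # float64 keeps rounding error well below the smallest eigenvalue of `good`.
    a = torch.rand(n, n, dtype=torch.float64)
    good = a @ a.t() + torch.diag_embed(torch.rand(n, dtype=torch.float64))
    bad = a @ a.t() + torch.diag_embed(-10 + torch.rand(n, dtype=torch.float64))
    print(is_positive_definite(good).item(), is_positive_definite(bad).item())
```

Because the result is a tensor, the same call checks a whole batch of matrices at once without any Python-level exception handling.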

To Reproduce

import torch
from numpy import linalg as nla

n = 100
a = torch.rand(n, n)
H = a @ a.t() + torch.diag_embed(torch.rand(n))
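Why `H` is positive definite: `a @ a.t()` is positive semidefinite for any real `a`, and the entries $d_i$ of the added diagonal, drawn by `torch.rand` from $[0, 1)$, are almost surely strictly positive. Hence for every nonzero $x$:

$$
x^\top H x = x^\top a a^\top x + \sum_{i=1}^{n} d_i x_i^2 = \lVert a^\top x \rVert_2^2 + \sum_{i=1}^{n} d_i x_i^2 > 0 .
$$

In the failing example further down, the diagonal entries are $d_i = -10 + r_i < 0$, so this argument no longer applies.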

For a positive definite matrix, the runtimes are comparable:

%timeit torch.cholesky(H)
55.5 µs ± 4.16 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit nla.cholesky(H)
59.2 µs ± 1.56 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

The time to catch a failure, however, is not comparable at all, even if the matrix has strongly negative eigenvalues:

n = 100
a = torch.rand(n, n)
H = a @ a.t() + torch.diag_embed(-10 + torch.rand(n))
def cholesky_fail(H):
    try:
        torch.cholesky(H)
    except RuntimeError:  # raised by torch.cholesky when H is not positive definite
        return H

%timeit cholesky_fail(H)
4.61 ms ± 206 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
def cholesky_fail_np(H):
    try:
        nla.cholesky(H)
    except nla.LinAlgError:  # NumPy's error for a non-positive-definite input
        return H

%timeit cholesky_fail_np(H)
32.3 µs ± 961 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
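To reproduce the comparison outside IPython, where `%timeit` is unavailable, a minimal sketch using the standard-library `timeit` module might look as follows. `time_call` is a hypothetical helper, not part of either library; it assumes that a failed factorization raises `RuntimeError` (torch) or `numpy.linalg.LinAlgError` (NumPy), as in the helpers above.

```python
import timeit

import numpy as np


def time_call(fn, *args, number=100):
    """Return the mean wall-clock seconds per call of fn(*args).

    Hypothetical helper: exceptions signalling a failed factorization are
    swallowed, so the failure path is timed exactly like the success path.
    """
    def run():
        try:
            fn(*args)
        except (np.linalg.LinAlgError, RuntimeError):
            pass  # assumed failure exceptions for NumPy and torch respectively

    return timeit.timeit(run, number=number) / number


if __name__ == "__main__":
    n = 100
    rng = np.random.default_rng(0)
    a = rng.random((n, n))
    good = a @ a.T + np.diag(rng.random(n))       # positive definite
    bad = a @ a.T + np.diag(-10 + rng.random(n))  # not positive definite
    print(f"success: {time_call(np.linalg.cholesky, good) * 1e6:.1f} us")
    print(f"failure: {time_call(np.linalg.cholesky, bad) * 1e6:.1f} us")
```

Passing `torch.cholesky` and the torch tensors from the repro instead would time the PyTorch paths the same way.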

Things look just as bleak if the matrix is not symmetric:

M = torch.rand(n, n)
%timeit cholesky_fail(M)
4.67 ms ± 169 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit cholesky_fail_np(M)
33.4 µs ± 1.77 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
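Note that asymmetry by itself is not what gets detected here. NumPy documents that `numpy.linalg.cholesky` reads only the lower-triangular and diagonal elements of its input, so a non-symmetric matrix fails only if the symmetric matrix implied by its lower triangle is not positive definite. The sketch below illustrates this with NumPy; whether `torch.cholesky` reads the same triangle is an assumption we have not checked.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5
a = rng.random((n, n))
S = a @ a.T + np.eye(n)                   # symmetric positive definite
noise = np.triu(rng.random((n, n)), k=1)  # strictly upper-triangular junk
M = S + noise                             # far from symmetric

# Both calls succeed and give the same factor, because the strict upper
# triangle of M is never read.
L_S = np.linalg.cholesky(S)
L_M = np.linalg.cholesky(M)
print(np.allclose(L_S, L_M), np.allclose(M, M.T))
```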

Expected behavior

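torch.cholesky should fail fast: detecting that a matrix is not positive definite should take no longer than factorizing a positive definite matrix of the same size.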
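For intuition on why failure should be cheap: Cholesky factorization proceeds column by column and meets a non-positive pivot at the first leading principal minor that is not positive definite, at which point it can stop. Failure therefore never needs more work than a successful factorization of the same size. The following is an illustrative, unblocked NumPy sketch of that early exit (`cholesky_or_fail` is a hypothetical name); LAPACK's `potrf` uses blocked algorithms but reports failure through its `info` output in the same spirit.

```python
import numpy as np


def cholesky_or_fail(A):
    """Unblocked, column-by-column (left-looking) lower Cholesky factorization.

    Returns (L, k). On success k == 0 and L @ L.T reproduces A up to rounding.
    On failure L is None and k is the 1-based index of the first non-positive
    pivot, i.e. the order of the first leading minor that is not positive
    definite. Only the lower triangle of A is read.
    """
    A = np.asarray(A, dtype=float)
    n = A.shape[0]
    L = np.zeros_like(A)
    for j in range(n):
        # Pivot: diagonal entry minus the squared norm of row j computed so far.
        pivot = A[j, j] - L[j, :j] @ L[j, :j]
        if not pivot > 0.0:  # also treats NaN as failure
            # Stop immediately: no more work than a full factorization.
            return None, j + 1
        L[j, j] = np.sqrt(pivot)
        # Fill column j below the diagonal from lower-triangle entries of A.
        L[j + 1:, j] = (A[j + 1:, j] - L[j + 1:, :j] @ L[j, :j]) / L[j, j]
    return L, 0
```

On the positive definite `H` this returns the same factor as `numpy.linalg.cholesky`; on the shifted matrix it returns as soon as the first bad pivot appears.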

Environment

Reproducible both on macOS with the latest conda package and on FB infra.

cc @vincentqb @vishwakftw @jianyuh @nikitaved @pearu @VitalyFedyunin @ngimel @nazanint (who brought this to my attention), @gchanan


Labels: module: linear algebra, module: performance, triaged
