torch.cholesky with upper=True is wrong for batched CUDA inputs #57032

@IvanYashchuk

🐛 Bug

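torch.cholesky with the upper=True flag returns wrong results for batched CUDA inputs. In the example below, the batched CUDA result is a diagonal matrix holding the square roots of the input's diagonal entries. It looks as if the lower triangle (all zeros here) had been read instead of the upper one. CPU inputs and single (unbatched) CUDA inputs give the expected result.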

To Reproduce

Steps to reproduce the behavior:

In [1]: import torch
In [2]: a = torch.rand(2, 3, 3)
In [3]: a = a @ a.transpose(-2, -1)
In [4]: a_triu = a.triu() # fill the lower triangular part with zero
In [5]: a_triu
Out[5]: 
tensor([[[0.2760, 0.4856, 0.1998],
         [0.0000, 0.8911, 0.2984],
         [0.0000, 0.0000, 0.2353]],

        [[1.2673, 0.9277, 0.9247],
         [0.0000, 0.9955, 0.9757],
         [0.0000, 0.0000, 1.4462]]])
In [6]: torch.cholesky(a_triu, upper=True) # expected result
Out[6]: 
tensor([[[ 0.5253,  0.9244,  0.3804],
         [ 0.0000,  0.1912, -0.2781],
         [ 0.0000,  0.0000,  0.1154]],

        [[ 1.1257,  0.8241,  0.8214],
         [ 0.0000,  0.5625,  0.5312],
         [ 0.0000,  0.0000,  0.6995]]])
In [7]: torch.cholesky(a_triu.cuda(), upper=True) # this is wrong
Out[7]: 
tensor([[[0.5253, 0.0000, 0.0000],
         [0.0000, 0.9440, 0.0000],
         [0.0000, 0.0000, 0.4851]],

        [[1.1257, 0.0000, 0.0000],
         [0.0000, 0.9977, 0.0000],
         [0.0000, 0.0000, 1.2026]]], device='cuda:0')
In [8]: torch.cholesky(a_triu.cuda()[0], upper=True) # single input works
Out[8]: 
tensor([[ 0.5253,  0.9244,  0.3804],
        [ 0.0000,  0.1912, -0.2781],
        [ 0.0000,  0.0000,  0.1154]], device='cuda:0')
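Up to rounding, the wrong CUDA output in Out[7] equals the square roots of the diagonal of Out[5]. The minimal pure-Python sketch below makes this concrete. The helper `cholesky_upper_from` is hypothetical and is not a PyTorch API. It builds the symmetric matrix from one chosen triangle of the input, the way LAPACK-style routines reference only one half, and returns the upper factor U with A = U^T U.

Reading the upper triangle of the first matrix in Out[5] approximately reproduces Out[6]. Reading its lower triangle, which is zero off the diagonal, reproduces Out[7]. This is consistent with the batched CUDA path reading the wrong triangle, but it does not prove it.

```python
import math


def cholesky_upper_from(a, triangle):
    """Upper Cholesky factor U (with A = U^T U) of the symmetric matrix A
    built from one triangle of the square list-of-lists `a`.

    `triangle` is "upper" or "lower"; entries in the other triangle of `a`
    are ignored. Hypothetical helper for illustration, not a PyTorch API.
    """
    n = len(a)

    def entry(i, j):
        # Look up A[i][j] using only the chosen triangle of `a`.
        if triangle == "upper":
            return a[min(i, j)][max(i, j)]
        if triangle == "lower":
            return a[max(i, j)][min(i, j)]
        raise ValueError("triangle must be 'upper' or 'lower'")

    u = [[0.0] * n for _ in range(n)]
    for j in range(n):
        # Diagonal: A[j][j] = sum_{k<j} U[k][j]^2 + U[j][j]^2
        u[j][j] = math.sqrt(entry(j, j) - sum(u[k][j] ** 2 for k in range(j)))
        # Rest of row j: A[j][i] = sum_{k<j} U[k][j] U[k][i] + U[j][j] U[j][i]
        for i in range(j + 1, n):
            s = sum(u[k][j] * u[k][i] for k in range(j))
            u[j][i] = (entry(j, i) - s) / u[j][j]
    return u


# First matrix of a_triu from Out[5] (lower triangle already zeroed).
a0 = [[0.2760, 0.4856, 0.1998],
      [0.0000, 0.8911, 0.2984],
      [0.0000, 0.0000, 0.2353]]

for tri in ("upper", "lower"):
    print(tri, [[round(x, 4) for x in row] for row in cholesky_upper_from(a0, tri)])
```

The printed inputs in Out[5] are rounded to four decimals. As a result, the upper-triangle factor agrees with Out[6] only to within a few thousandths in its later entries. It still satisfies U^T U = A exactly up to floating-point error.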

Expected behavior

For batched CUDA inputs, torch.cholesky(..., upper=True) should return the same result as on the CPU (Out[6]). That is also what the single-matrix CUDA call returns (Out[8]).
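A regression check for this expected behavior could compare the batched result against per-matrix calls and verify the factorization itself. The sketch below assumes a PyTorch build where `torch.cholesky` is still available; newer releases deprecate it in favor of `torch.linalg.cholesky`. The helper name `check_batched_cholesky_upper` and its parameters are hypothetical.

The sketch makes three choices. It uses float64 for a tighter tolerance. It adds `n * I` to keep the matrices well conditioned. Like the report, it keeps only the upper triangle, so that reading the wrong triangle changes the answer.

```python
import torch


def check_batched_cholesky_upper(device="cpu", batch=2, n=3, atol=1e-6):
    """Hypothetical regression check for torch.cholesky(..., upper=True).

    Returns True if the batched result matches factorizing each matrix
    separately and U^T U reconstructs the symmetric matrix whose upper
    triangle was passed in.
    """
    torch.manual_seed(0)  # reproducible inputs (reseeds the global RNG)
    a = torch.rand(batch, n, n, dtype=torch.float64)
    # Symmetric positive definite; the n * I shift keeps it well conditioned.
    a = a @ a.transpose(-2, -1) + n * torch.eye(n, dtype=torch.float64)
    # Zero the lower triangle so that reading the wrong half changes the answer.
    a_triu = a.triu().to(device)

    batched = torch.cholesky(a_triu, upper=True)
    per_matrix = torch.stack([torch.cholesky(m, upper=True) for m in a_triu])

    # U^T U should recover the full symmetric matrix, not its triangular copy.
    reconstructed = (batched.transpose(-2, -1) @ batched).cpu()
    return bool(
        torch.allclose(batched, per_matrix, atol=atol)
        and torch.allclose(reconstructed, a, atol=atol)
    )


if __name__ == "__main__":
    print("cpu:", check_batched_cholesky_upper("cpu"))
    if torch.cuda.is_available():
        print("cuda:", check_batched_cholesky_upper("cuda"))
```

Once the bug is fixed, `check_batched_cholesky_upper("cuda")` should return True on a CUDA machine. Under the behavior reported above, it would be expected to return False.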

cc @jianyuh @nikitaved @pearu @mruberry @heitorschueroff @walterddr @IvanYashchuk @xwang233 @lezcano

Labels

module: linear algebra (Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply (matmul))
module: magma (related to MAGMA linear algebra CUDA support)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)