Skip to content

Use torch.linalg.cholesky_ex instead of torch.cholesky#1586

Merged
gpleiss merged 4 commits intomasterfrom
fast_psd_safe_chol
May 3, 2021
Merged

Use torch.linalg.cholesky_ex instead of torch.cholesky#1586
gpleiss merged 4 commits intomasterfrom
fast_psd_safe_chol

Conversation

@Balandat
Copy link
Copy Markdown
Collaborator

This can significantly speed up psd_safe_cholesky due to cutting out the pytorch error-handling middle man, achieving ~2,000 X reduction in wall time: pytorch/pytorch#56724 (comment).

This also allows us to add jitter to the specific batch elements for which the decomp failed (rather than indiscriminately to all).

This requires pytorch/pytorch#56724 that hasn't landed yet but will be part of 1.9. Either way, I implemented this in a backward-compatible fashion so this will work with older pytorch versions as well.

This can *significantly* speed up `psd_safe_cholesky` due to cutting out the pytorch error-handling middle man, achieving ~2,0000X speedups: pytorch/pytorch#56724 (comment)
This also allows us to add jitter to the specific batch elements for which the decomp failed (rather than idiscriminatly to all).

This requires pytorch/pytorch#56724 that hasn't landed yet but will be part of 1.9. Either way, I implemented this in a backward-compatible fashion so this will work with older pytorch versions as well.
@Balandat Balandat force-pushed the fast_psd_safe_chol branch from 332b4f8 to b434af6 Compare April 30, 2021 03:46
L = _psd_safe_cholesky(A, out=out, jitter=jitter, max_tries=max_tries)
if upper:
if out is not None:
out = out.transpose_(-1, -2)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess the assignment is kind of pointless but oh well

:attr:`max_tries` (int, optional):
Number of attempts (with successively increasing jitter) to make before raising an error.
"""
L = _psd_safe_cholesky(A, out=out, jitter=jitter, max_tries=max_tries)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could pass upper here to avoid the transposing below.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so the reason I'm doing it this way is that cholesky_ex does not support the upper arg, so I wouldn't be able to have _psd_safe_cholesky have the same signature for both cholesky_ex and cholesky internals. This ensures that _psd_safe_cholesky is functionally equivalent whether or not cholesky_ex is available or not

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I thought cholesky_ex supported the upper arg. Strange...

@Balandat
Copy link
Copy Markdown
Collaborator Author

Balandat commented May 3, 2021

@jacobrgardner can we get this in?cholesky_ex has been merged into master, and it would be really nice to use it.

@Balandat
Copy link
Copy Markdown
Collaborator Author

Balandat commented May 3, 2021

hmm looks like I need to fix a number of mock call counts in a backward compatible way.

@gpleiss gpleiss merged commit 336b333 into master May 3, 2021
@gpleiss gpleiss deleted the fast_psd_safe_chol branch May 3, 2021 22:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants