Use torch.linalg.cholesky_ex instead of torch.cholesky #1586
Conversation
This can *significantly* speed up `psd_safe_cholesky` by cutting out the pytorch error-handling middle man, achieving ~2,000X speedups: pytorch/pytorch#56724 (comment). This also allows us to add jitter to the specific batch elements for which the decomposition failed (rather than indiscriminately to all of them). This requires pytorch/pytorch#56724, which hasn't landed yet but will be part of 1.9. Either way, I implemented this in a backward-compatible fashion so it will work with older pytorch versions as well.
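For concreteness, here is a minimal sketch (my assumptions, not the exact code in this PR) of how the `info` tensor returned by `torch.linalg.cholesky_ex` enables per-element jitter; the jitter schedule, defaults, and error handling are illustrative, and `out=`, `upper`, and the pre-1.9 fallback are omitted:

```python
import warnings

import torch


def psd_safe_cholesky(A, jitter=1e-6, max_tries=3):
    """Cholesky with per-batch-element jitter (illustrative sketch).

    ``torch.linalg.cholesky_ex`` (PyTorch >= 1.9) returns ``(L, info)`` instead of
    raising, so failed batch elements can be read off ``info`` with no try/except.
    """
    L, info = torch.linalg.cholesky_ex(A)
    if not torch.any(info):  # info == 0 everywhere: every factorization succeeded
        return L

    Aprime = A.clone()
    jitter_prev = 0.0
    for i in range(max_tries):
        jitter_new = jitter * (10 ** i)
        # info > 0 marks the batch elements whose decomposition failed; bump only
        # their diagonals (unsqueeze so the mask broadcasts along the diagonal).
        bump = ((info > 0).to(A.dtype) * (jitter_new - jitter_prev)).unsqueeze(-1)
        Aprime.diagonal(dim1=-2, dim2=-1).add_(bump)
        jitter_prev = jitter_new
        warnings.warn(f"A not p.d., adding jitter of {jitter_new}", RuntimeWarning)
        L, info = torch.linalg.cholesky_ex(Aprime)
        if not torch.any(info):
            return L
    raise RuntimeError(f"Matrix not positive definite after adding {jitter_new} jitter")
```

With a batch where only some matrices are near-singular, only those elements pick up jitter; the well-conditioned ones are factorized untouched, and no Python-level exception is thrown on the hot path.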
Force-pushed from 332b4f8 to b434af6.
    L = _psd_safe_cholesky(A, out=out, jitter=jitter, max_tries=max_tries)
    if upper:
        if out is not None:
            out = out.transpose_(-1, -2)
I guess the assignment is kind of pointless but oh well
        :attr:`max_tries` (int, optional):
            Number of attempts (with successively increasing jitter) to make before raising an error.
        """
        L = _psd_safe_cholesky(A, out=out, jitter=jitter, max_tries=max_tries)
You could pass upper here to avoid the transposing below.
So the reason I'm doing it this way is that `cholesky_ex` does not support the `upper` arg, so I wouldn't be able to have `_psd_safe_cholesky` use the same signature for both the `cholesky_ex` and `cholesky` internals. This ensures that `_psd_safe_cholesky` is functionally equivalent whether or not `cholesky_ex` is available.
Oh I thought cholesky_ex supported the upper arg. Strange...
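To make the point of this exchange concrete, here is a rough skeleton (my assumptions, not the PR's actual code; jitter and `out` handling are elided) of keeping one `_psd_safe_cholesky` signature for both backends while the public wrapper deals with `upper`:

```python
import torch

if hasattr(torch.linalg, "cholesky_ex"):

    def _psd_safe_cholesky(A, out=None, jitter=None, max_tries=None):
        # New path: cholesky_ex never raises and (at the time of this PR)
        # has no ``upper`` argument.
        L, _info = torch.linalg.cholesky_ex(A)
        return L

else:

    def _psd_safe_cholesky(A, out=None, jitter=None, max_tries=None):
        # Legacy path: torch.cholesky raises on failure.
        return torch.cholesky(A)


def psd_safe_cholesky(A, upper=False, out=None, jitter=None, max_tries=None):
    # Same call regardless of backend; ``upper`` never reaches the helper.
    L = _psd_safe_cholesky(A, out=out, jitter=jitter, max_tries=max_tries)
    if upper:
        if out is not None:
            # In the real helper ``out`` is written in place, hence the transpose here.
            out = out.transpose_(-1, -2)
        L = L.transpose(-1, -2)
    return L
```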
@jacobrgardner can we get this in?
Hmm, looks like I need to fix a number of mock call counts in a backward-compatible way.