
Add torch.linalg.cholesky_ex without checking for errors by default #56724

Closed
IvanYashchuk wants to merge 23 commits into pytorch:master from IvanYashchuk:cholesky-check-errors

Conversation

@IvanYashchuk
Collaborator

@IvanYashchuk commented Apr 22, 2021

The new function has the following signature: `cholesky_ex(Tensor input, *, bool check_errors=False) -> (Tensor L, Tensor infos)`. When `check_errors=True`, an error is thrown if the decomposition fails; with `check_errors=False`, responsibility for checking the decomposition falls on the user.

When `check_errors=False`, there are no host-device memory transfers for checking the values of the `info` tensor.
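For illustration, a minimal usage sketch of the new function (shapes and the SPD setup are arbitrary):

```python
import torch

# Batch of symmetric positive-definite matrices.
A = torch.randn(4, 3, 3)
A = A @ A.transpose(-2, -1) + 3 * torch.eye(3)

L, infos = torch.linalg.cholesky_ex(A)  # check_errors=False by default
if infos.any():  # user-side check; on CUDA, only this line forces a sync
    failed = infos.nonzero().flatten().tolist()
    raise RuntimeError(f"decomposition failed for batch elements {failed}")
```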

Rewrote the internal code for `torch.linalg.cholesky`. Added `cholesky_stub` dispatch. `linalg_cholesky` is now implemented via calls to `linalg_cholesky_ex`.

Resolves #57032.

Ref. #34272, #47608, #47953

Rewrote the internal code for `torch.linalg.cholesky`. Added `cholesky_stub` dispatch. `linalg_cholesky` is implemented to call `linalg_cholesky_ex`. `torch.linalg.cholesky_ex` suppresses checking of LAPACK error codes by default and returns them, putting the responsibility for checking the error codes on the user.
@IvanYashchuk added the `module: linear algebra` label Apr 22, 2021
@IvanYashchuk requested a review from mruberry April 22, 2021 19:53
@facebook-github-bot
Contributor

facebook-github-bot commented Apr 22, 2021

💊 CI failures summary and remediations

As of commit c73bf24 (more details on the Dr. CI page):



🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_windows_vs2019_py36_cuda10.1_test2 (1/1)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

AssertionError: False is not true : Scalars fai...ith rtol=1.3e-06 and atol=1e-05 is only 1.4278052!
======================================================================
FAIL [1.536s]: test_cudnn_multiple_threads_same_device (__main__.TestCuda)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_utils.py", line 439, in wrapper
    fn(*args, **kwargs)
  File "test_cuda.py", line 2505, in test_cudnn_multiple_threads_same_device
    (2048 - test_iters) * (2048 - test_iters))
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_utils.py", line 1397, in assertEqual
    super().assertTrue(result, msg=self._get_assert_msg(msg, debug_msg=debug_msg))
AssertionError: False is not true : Scalars failed to compare as equal! Comparing 1666681.0 and 1098304 gives a difference of 568377.0, but the allowed difference with rtol=1.3e-06 and atol=1e-05 is only 1.4278052!

----------------------------------------------------------------------
Ran 160 tests in 77.444s

FAILED (failures=1, skipped=68)

Generating XML reports...
Generated XML report: test-reports\dist-gloo\test_cuda\TEST-TestCuda-20210501103252.xml
Generated XML report: test-reports\dist-gloo\test_cuda\TEST-TestCudaComm-20210501103252.xml
Traceback (most recent call last):

❄️ 1 failure tentatively classified as flaky

but reruns have not yet been triggered to confirm:

See CircleCI build pytorch_linux_xenial_py3_clang5_asan_test2 (1/1)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun) ❄️

May 01 13:49:30 unknown file: Failure
May 01 13:49:30 [       OK ] Kernel.Softmax2D (26 ms)
May 01 13:49:30 [ RUN      ] Kernel.Softmax3D
May 01 13:49:30 [       OK ] Kernel.Softmax3D (139 ms)
May 01 13:49:30 [ RUN      ] Kernel.Softmax4D
May 01 13:49:30 [       OK ] Kernel.Softmax4D (195 ms)
May 01 13:49:30 [ RUN      ] Kernel.ConstantTensors
May 01 13:49:30 [       OK ] Kernel.ConstantTensors (22 ms)
May 01 13:49:30 [ RUN      ] Kernel.ConstantTensorsNonContiguous
May 01 13:49:30 [       OK ] Kernel.ConstantTensorsNonContiguous (20 ms)
May 01 13:49:30 [ RUN      ] Kernel.RunFast
May 01 13:49:30 unknown file: Failure
May 01 13:49:30 C++ exception with description "SimpleIREvaluator::call_raw is not implemented yet" thrown in the test body.
May 01 13:49:30 [  FAILED  ] Kernel.RunFast (5 ms)
May 01 13:49:30 [----------] 14 tests from Kernel (622 ms total)
May 01 13:49:30 
May 01 13:49:30 [----------] 140 tests from LoopNest
May 01 13:49:30 [ RUN      ] LoopNest.ExprSimple01
May 01 13:49:30 [       OK ] LoopNest.ExprSimple01 (1 ms)
May 01 13:49:30 [ RUN      ] LoopNest.ExprLower01
May 01 13:49:30 [       OK ] LoopNest.ExprLower01 (0 ms)
May 01 13:49:30 [ RUN      ] LoopNest.ExprSimple02

This comment was automatically generated by Dr. CI.

@mruberry
Collaborator

Docs failure is real:

Apr 22 21:30:47 Experimental Functions
Apr 22 21:30:47 ---------
Apr 22 21:30:47 /var/lib/jenkins/workspace/docs/source/linalg.rst:47: WARNING: Title underline too short.

named tuple list needs an update (so many lists is a pain, sorry):

test_native_functions_yaml - TestNamedTupleAPI

Traceback (most recent call last):
  File "test_namedtuple_return_api.py", line 46, in test_native_functions_yaml
    'only allowlisted operators are allowed to have named return type, got ' + name)
  File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_utils.py", line 1371, in assertEqual
    super().assertTrue(result, msg=self._get_assert_msg(msg, debug_msg=debug_msg))
AssertionError: False is not true : Scalars failed to compare as equal! Comparing 2 and 1 gives a difference of 1, but the allowed difference with rtol=0 and atol=0 is only 0!
only allowlisted operators are allowed to have named return type, got linalg_cholesky_ex

@mruberry
Collaborator

cc @Balandat, @ngimel, @xwang233

linalg_cholesky_out_info(input, L, info);
}

if (check_errors) {
Collaborator

so if someone calls torch.linalg.cholesky_ex(..., check_errors=True) and there is an error then the error message would say torch.linalg.cholesky...?

That's not a big deal but it'd probably be easy to pass the calling function's string here?
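For illustration, a hypothetical sketch (in Python, not the actual C++ checks) of the "pass the caller's name" pattern being suggested; the helper name and message are assumptions:

```python
import torch

# Hypothetical helper; `fn_name` lets shared checks report the user-facing name.
def check_cholesky_errors(infos: torch.Tensor, fn_name: str) -> None:
    if infos.any():
        bad = infos.nonzero().flatten().tolist()
        raise RuntimeError(f"{fn_name}: factorization failed for batch element(s) {bad}")
```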


std::tuple<Tensor&, Tensor&> linalg_cholesky_ex_out(const Tensor& input, bool check_errors, Tensor& L, Tensor& info) {
squareCheckInputs(input);
checkSameDevice("torch.linalg.cholesky_ex", L, input, "L");
Collaborator

Because cholesky is implemented as a call to cholesky_ex, this would show the name cholesky_ex if a call to cholesky hit an error, right? Also see the comment below for when a call to cholesky_ex would show the name cholesky.

Error messages aren't the end of the world but do you think we should bother taking a string for the function name? Alternatively we could just always use torch.linalg.cholesky, even when someone calls torch.linalg.cholesky_ex

Collaborator Author

I agree that we could just use torch.linalg.cholesky.
Here torch.linalg.cholesky performs the same checks itself, differing only in the function name and in using "result" instead of "L". This duplicates the checks, but that shouldn't add overhead.

Collaborator Author

So the error messages of torch.linalg.cholesky are not changed in this PR. The tests would need to be modified if the error messages changed.

CompositeExplicitAutograd: linalg_cholesky_ex

- func: linalg_cholesky.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
- func: linalg_cholesky_ex.L(Tensor self, *, bool check_errors=False, Tensor(a!) L, Tensor(b!) info) -> (Tensor(a!) L, Tensor(b!) info)
Collaborator

.out?

Collaborator

Jumping ahead of @ezyang here: could we look at making these structured? If that seems challenging we could look at doing it in a follow-up PR.

Contributor

I don't see a CPU, CUDA dispatch key here so this might be an alias style kernel that may not work with structured. You should add this to the binder of cases to solve with the transpiler.

Collaborator Author

@IvanYashchuk Apr 26, 2021

Is CPU, CUDA dispatch key preferred over CompositeExplicitAutograd if the dispatch stubs are used?
Found the answer in the aten/native/README.md:

Note: kernels which call
DispatchStub should NOT be registered as CompositeExplicitAutograd, as
DispatchStub only works for CPU, CUDA

I haven't looked at porting to structured in detail, so I can't estimate how challenging it is. I was planning to port linalg module functions to structured after the 1.9 branch cut.

Comment thread: test/test_linalg.py (Outdated)
@skipCUDAIfNoMagmaAndNoCusolver
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
def test_cholesky_with_info(self, device, dtype):
Collaborator

test_cholesky_ex? That will make it easier for people to find the test(s) associated with the operator

Comment thread: test/test_linalg.py
from torch.testing._internal.common_utils import random_hermitian_pd_matrix

def run_test(n, batch):
A = random_hermitian_pd_matrix(n, *batch, dtype=dtype, device=device)
Collaborator

@mruberry Apr 23, 2021

I'm curious about these tests because we already have a lot of tests for cholesky, and cholesky_ex and cholesky are "near aliases" of each other; should we just refactor the cholesky tests so that test_cholesky_foo and test_cholesky_ex_foo both call a common helper, but pass a different torch function? Then there can be separate tests that validate their divergent behavior on infos
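For illustration, a minimal sketch of the suggested refactor (hypothetical method names, written as methods of the existing TestLinalg class; not the actual test-suite code):

```python
import torch
from torch.testing._internal.common_utils import random_hermitian_pd_matrix

# Hypothetical shared helper plus two thin wrappers; `cholesky_fn` is the
# torch function under test.
def _test_cholesky_shared(self, device, dtype, cholesky_fn):
    A = random_hermitian_pd_matrix(3, 2, dtype=dtype, device=device)
    out = cholesky_fn(A)
    L = out[0] if isinstance(out, tuple) else out  # cholesky_ex returns (L, info)
    self.assertEqual(L @ L.conj().transpose(-2, -1), A)

def test_cholesky(self, device, dtype):
    self._test_cholesky_shared(device, dtype, torch.linalg.cholesky)

def test_cholesky_ex(self, device, dtype):
    self._test_cholesky_shared(device, dtype, torch.linalg.cholesky_ex)
```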

Collaborator Author

I completely agree. I think it would take a lot of energy to refactor the tests right now. If those functions were aliases, the task would of course be easy. I'd prefer to do it in a follow-up PR, after the 1.9 branch cut, if that's okay.

Comment thread: torch/linalg/__init__.py
True
""")

cholesky_ex = _add_docstr(_linalg.linalg_cholesky_ex, r"""
Collaborator

I like that we're documenting this completely separately from linalg.cholesky because it's an experimental function

Collaborator

@mruberry Apr 23, 2021

That said, we should probably notify the reader early that this function is "experimental" and that it may change in a future PyTorch release

I realize it's already in an "experimental" section, but it's common to search functions directly in the docs

Collaborator Author

I added a warning, is it okay?

.. warning:: This function is "experimental" and it may change in a future PyTorch release.

Not sure on the use of quotes though.

Comment thread: test/test_linalg.py (Outdated)
@skipCUDAIfNoMagmaAndNoCusolver
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
def test_cholesky_with_info_non_pd(self, device, dtype):
Collaborator

This test is very cool

Comment thread: docs/source/linalg.rst
lstsq
householder_product

Experimental Functions
Collaborator

Good idea

@mruberry
Collaborator

Hey @IvanYashchuk! This is awesome!

I made a few comments about error messages, a docs tweak, a possible test refactor, and using structured kernels. I don't think any of them are blocking; we just need to focus on fixing the tests, and any additional tweaks are a bonus. We can pursue the docs as part of our docs review, for example.

@lezcano would you like to sanity check the impl, too?

Let's also give @ezyang an opportunity to look at the native_functions.yaml changes

@Balandat
Contributor

This is great @IvanYashchuk. Let me see if using this and checking the info resolves the glacial speed issues reported in #34272 (comment).

@facebook-github-bot
Contributor

@Balandat has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

DEFINE_DISPATCH(cholesky_stub);

void linalg_cholesky_out_info(const Tensor& input, const Tensor& result, const Tensor& info) {
TORCH_INTERNAL_ASSERT(input.dim() >= 2);
Collaborator

as in the other PR, consider making these TORCH_INTERNAL_ASSERT_DEBUG_ONLY

checkLinalgCompatibleDtype("torch.linalg.cholesky_ex", L, input, "L");
checkSameDevice("torch.linalg.cholesky_ex", info, input, "info");
ScalarType info_output_type = ScalarType::Int;
checkLinalgCompatibleDtype("torch.linalg.cholesky_ex", info.scalar_type(), info_output_type);
Collaborator

we could require that infos are of the correct type and error out if they aren't; that would slightly simplify the code.

Collaborator Author

I think info should follow the same behavior that is expected for all other "out" tensors. It should be possible to provide a tensor with a different but castable dtype, tensors with zero elements are resized, etc. Is that right?

Collaborator

Not exactly. We are limiting type promotion behavior to pointwise functions and reductions (#56356). Existing additional functionality is grandfathered in, but we should avoid adding new instances. So linalg functions casting to the output exist today, but info is a new addition and should not be promoted.
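For illustration, a usage sketch of the out= variant with a preallocated int32 info tensor (per the comment above, info should be int32 up front rather than cast; shapes here are arbitrary):

```python
import torch

A = torch.eye(3).expand(4, 3, 3).contiguous()  # trivially positive definite
L = torch.empty_like(A)
info = torch.empty(4, dtype=torch.int32)       # int32, matching LAPACK's info
torch.linalg.cholesky_ex(A, out=(L, info))
assert not info.any()                          # all factorizations succeeded
```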

// if upper=true we need to transpose the self tensor
if (upper) {
// self.transpose_(-2, -1);
self_data = self.transpose(-2, -1).data_ptr<scalar_t>();
Collaborator

this is a dangling data pointer, or rather, it would be a dangling pointer if it were not guaranteed to be the same as self.data_ptr(), which raises the question of what this conditional achieves.

Collaborator Author

Looking at it now, this portion of the code is wrong and I am removing it. Thank you!
The original intention was to solve #57032. That was wrong, and the issue is now resolved with changes to the cholesky_helper_magma function.

@@ -1659,65 +1672,89 @@ AT_ERROR("cholesky: MAGMA library not found in "
// which concisely is equal to batch_size % batch_limit
if (batch_size % batch_limit != 0) {
magmaCholeskyBatched<scalar_t>(
Collaborator

nit, but consider bringing this into the upper loop; the only changes needed are mini_idx < batch_size in the loop statement and nbatches = std::min(batch_limit, batch_size - mini_idx) as the argument to magmaCholeskyBatched.
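For illustration, a Python sketch of the control flow being suggested (the real code is C++ calling magmaCholeskyBatched; this only shows how a single loop covers the full chunks and the smaller trailing chunk):

```python
def chunked_batches(batch_size, batch_limit):
    mini_idx = 0
    while mini_idx < batch_size:
        # The last chunk may be smaller than batch_limit.
        nbatches = min(batch_limit, batch_size - mini_idx)
        yield mini_idx, nbatches  # e.g. run magmaCholeskyBatched on this slice
        mini_idx += batch_limit

print(list(chunked_batches(10, 4)))  # [(0, 4), (4, 4), (8, 2)]
```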

Tensor _cholesky_helper_cuda(const Tensor& self, bool upper) {
auto info_shape = IntArrayRef(
self.sizes().cbegin(), self.sizes().cend() - 2); // self.shape[:-2]
Tensor self_working_copy = cloneBatchedColumnMajor(self);
Collaborator

can you please leave a comment here noting that clone is called because legacy cholesky doesn't clone beforehand?

@ailzhang added the `triaged` label Apr 23, 2021
@Balandat
Contributor

Balandat commented Apr 23, 2021

This is great. In this notebook I made a quick and dirty update of the psd_safe_cholesky helper we're using in gpytorch that automatically adds jitter upon failure of cholesky: test_cholesky_ex.ipynb.zip

For a small matrix example this gets the runtime down from 123 ms to 76 µs. Yes, those are the correct units; this is a speedup of ~2,000X.

cc @jacobrgardner, @gpleiss, @dme65, @vishwakftw
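For reference, a minimal sketch of a jitter-on-failure Cholesky along these lines (the function name mirrors psd_safe_cholesky, but the retry policy and parameters are illustrative, not the gpytorch implementation):

```python
import torch

def psd_safe_cholesky(A, max_tries=3, jitter=1e-6):
    # Add jitter only to the batch elements whose factorization failed.
    L, info = torch.linalg.cholesky_ex(A)
    if not info.any():
        return L
    A = A.clone()
    eye = torch.eye(A.shape[-1], dtype=A.dtype, device=A.device)
    for i in range(max_tries):
        failed = info > 0                      # mask over the batch dimensions
        A[failed] += jitter * (10 ** i) * eye  # escalate jitter on each retry
        L, info = torch.linalg.cholesky_ex(A)
        if not info.any():
            return L
    raise RuntimeError("matrix is not positive definite, even with jitter")
```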

@IvanYashchuk force-pushed the cholesky-check-errors branch from 454f89b to a075e78 on April 28, 2021 10:49
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ cholesky ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template<typename scalar_t>
static void apply_cholesky(Tensor& self, bool upper, std::vector<int64_t>& infos) {
Collaborator Author

This function is moved to the BatchLinearAlgebraKernel.cpp file.

}
squareCheckInputs(self);

auto raw_cholesky_output = at::_cholesky_helper(self, upper);
Collaborator Author

Removed at::_cholesky_helper; it is replaced with cholesky_stub.

@IvanYashchuk
Collaborator Author

@mruberry, I updated this pull request. Most of the changes in this PR are rewrites of the internal code for cholesky, needed to expose the info tensor and to use dispatch stubs instead of an aten:: helper.

torch.cholesky, torch.linalg.cholesky, and torch.linalg.cholesky_ex have their own error checks so that the user-facing function name is used.
I removed at::_cholesky_helper; it is replaced with cholesky_stub.
Fixed the bug in torch.cholesky (#57032) and added a test for it.

@facebook-github-bot
Contributor

@Balandat has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Balandat added a commit to cornellius-gp/gpytorch that referenced this pull request Apr 30, 2021
This can *significantly* speed up `psd_safe_cholesky` by cutting out the pytorch error-handling middleman, achieving ~2,000X speedups: pytorch/pytorch#56724 (comment)
This also allows us to add jitter to the specific batch elements for which the decomp failed (rather than indiscriminately to all).

This requires pytorch/pytorch#56724, which hasn't landed yet but will be part of 1.9. Either way, I implemented this in a backward-compatible fashion so this will work with older pytorch versions as well.
Collaborator

@mruberry left a comment

@ngimel and I had a chance to take a look, and this is awesome!

It just needs to be rebased.

@xwang233, what's the status on cusolverDnXpotrfBatched, by the way?

@xwang233
Collaborator

@mruberry For cusolverDnXpotrfBatched, it has been fixed internally and will be released soon in the upcoming cuda releases! I'll keep you guys updated when it's ready. Thanks.

@IvanYashchuk
Collaborator Author

@mruberry, thanks!
I resolved merge conflicts.

@facebook-github-bot
Contributor

@Balandat has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@mruberry merged this pull request in 75a2a92.

gpleiss pushed a commit to cornellius-gp/gpytorch that referenced this pull request May 3, 2021
This can *significantly* speed up `psd_safe_cholesky` by cutting out the pytorch error-handling middleman, achieving ~2,000X speedups: pytorch/pytorch#56724 (comment)
This also allows us to add jitter to the specific batch elements for which the decomp failed (rather than indiscriminately to all).

This requires pytorch/pytorch#56724, which hasn't landed yet but will be part of 1.9. Either way, I implemented this in a backward-compatible fashion so this will work with older pytorch versions as well.
krshrimali pushed a commit to krshrimali/pytorch that referenced this pull request May 19, 2021
Add torch.linalg.cholesky_ex without checking for errors by default (pytorch#56724)

Summary:
The new function has the following signature: `cholesky_ex(Tensor input, *, bool check_errors=False) -> (Tensor L, Tensor infos)`. When `check_errors=True`, an error is thrown if the decomposition fails; with `check_errors=False`, responsibility for checking the decomposition falls on the user.

When `check_errors=False`, we don't have host-device memory transfers for checking the values of the `info` tensor.

Rewrote the internal code for `torch.linalg.cholesky`. Added `cholesky_stub` dispatch. `linalg_cholesky` is implemented using calls to `linalg_cholesky_ex` now.

Resolves pytorch#57032.

Ref. pytorch#34272, pytorch#47608, pytorch#47953

Pull Request resolved: pytorch#56724

Reviewed By: ngimel

Differential Revision: D27960176

Pulled By: mruberry

fbshipit-source-id: f05f3d5d9b4aa444e41c4eec48ad9a9b6fd5dfa5
gpleiss pushed a commit to cornellius-gp/linear_operator that referenced this pull request May 23, 2022
This can *significantly* speed up `psd_safe_cholesky` by cutting out the pytorch error-handling middleman, achieving ~2,000X speedups: pytorch/pytorch#56724 (comment)
This also allows us to add jitter to the specific batch elements for which the decomp failed (rather than indiscriminately to all).

This requires pytorch/pytorch#56724, which hasn't landed yet but will be part of 1.9. Either way, I implemented this in a backward-compatible fashion so this will work with older pytorch versions as well.
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 25, 2026
Add torch.linalg.cholesky_ex without checking for errors by default (pytorch#56724)

Summary:
The new function has the following signature: `cholesky_ex(Tensor input, *, bool check_errors=False) -> (Tensor L, Tensor infos)`. When `check_errors=True`, an error is thrown if the decomposition fails; with `check_errors=False`, responsibility for checking the decomposition falls on the user.

When `check_errors=False`, we don't have host-device memory transfers for checking the values of the `info` tensor.

Rewrote the internal code for `torch.linalg.cholesky`. Added `cholesky_stub` dispatch. `linalg_cholesky` is implemented using calls to `linalg_cholesky_ex` now.

Resolves pytorch#57032.

Ref. pytorch#34272, pytorch#47608, pytorch#47953

Pull Request resolved: pytorch#56724

Reviewed By: ngimel

Differential Revision: D27960176

Pulled By: mruberry

fbshipit-source-id: f05f3d5d9b4aa444e41c4eec48ad9a9b6fd5dfa5

Labels

cla signed · Merged · module: linear algebra · open source · triaged

Projects

None yet

Development

Successfully merging this pull request may close these issues.

torch.cholesky with upper=True is wrong for batched CUDA inputs