Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/146799
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit 4ea9ce5 with merge base 91c4bf3. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Attention! native_functions.yaml was changed. If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs: one that adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.
To add the ciflow label: this helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.
out.tril_();
upper ? out.transpose_(ndim - 2, ndim - 1) : out;
This will silently alter the stride structure of out if upper == true. It would be better as upper ? out.triu_() : out.tril_().
That's not the same. The kernel does the decomposition in the lower part of the matrix. If you do out.triu_() instead of out.tril_() -> transpose, then you get the upper part of the matrix, which isn't really the correct output.
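A quick CPU sketch of this point (illustration only, not the MPS kernel itself): the upper factor is the transpose of the lower factor, whereas triu of the lower factor keeps only its diagonal.

```python
import torch

# Build a symmetric positive-definite matrix
A = torch.rand(4, 4)
A = A @ A.T + 4 * torch.eye(4)

L = torch.linalg.cholesky(A)              # lower factor
U = torch.linalg.cholesky(A, upper=True)  # upper factor

# The upper factor is the transpose of the lower one ...
torch.testing.assert_close(U, L.mT)

# ... while triu of the lower factor keeps only its diagonal, not U
assert not torch.allclose(L.triu(), U)
```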
Do you have some stride assumptions in the kernel, or is it stride-agnostic? If it is stride-agnostic, then the kernel could be run on the transposed variant.
It assumes that the input is row-major (contiguous).
out can be provided externally as column-major. What would happen in this case?
I printed the data pointer inside the MPS function and outside in Python:
import torch
out = torch.rand(3, 3, 3, device="mps").permute(2, 1, 0)
x = torch.rand(3, 3, 3, device="mps")
x = x.mT @ x
data_ptr = out.data_ptr()
print(f"0x{data_ptr:x}") # lowercase hex
torch.linalg.cholesky(x, out=out)
print(f"0x{out.data_ptr():x}")
Yields:
0x10a4d68d0
0x10fb19150
0x10a4d68d0
The first one is printed from Python, the second one from C++ before launching the kernel, and the third one again from Python. So yeah, confirmed.
As per https://github.com/pytorch/pytorch/pull/146799/files#r1952464144, this is expected. Sorry for the confusion. But we should have issues when out is contiguous and upper=True it seems.
No issues from what I checked:
import torch
out = torch.rand(3, 3, 3, device="mps").permute(2, 1, 0)
x = torch.rand(3, 3, 3, device="mps")
x = x.mT @ x
data_ptr = out.data_ptr()
print(f"0x{data_ptr:x}") # lowercase hex
print(out.stride())
print(out.is_contiguous())
res1 = torch.linalg.cholesky(x, out=out, upper=True)
res2 = torch.linalg.cholesky(x.cpu(), out=out.cpu(), upper=True)
print(f"0x{out.data_ptr():x}")
print(out.stride())
torch.testing.assert_close(res1.cpu(), res2)
0x113f70cc0
(1, 3, 9)
False
0x114f3a510
0x113f70cc0
(1, 3, 9)
@Isalia20 , could you remove permute so that out is contiguous? In the Meta function, as per your modification, out is re-used only if it is contiguous.
Ah I see the issue now:
0x10bc7b840
(9, 3, 1)
True
0x10bc7b840
0x10bc7b840
(9, 1, 3)
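A minimal CPU repro of the stride flip shown above (shapes assumed to match the batched 3×3 case): the in-place transpose_ changes the strides the caller observes, without touching the data.

```python
import torch

out = torch.empty(3, 3, 3)          # contiguous batched 3x3 buffer
assert out.stride() == (9, 3, 1)    # row-major (C-contiguous)

# what `upper ? out.transpose_(ndim - 2, ndim - 1) : out` amounts to:
out.transpose_(-2, -1)              # in place, no data copy

assert out.stride() == (9, 1, 3)    # the caller now sees flipped strides
assert not out.is_contiguous()
```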
// L
- auto L_strides = at::native::batched_matrix_contiguous_strides(A_shape, /*f-contig*=*/true);
+ auto L_strides = at::native::batched_matrix_contiguous_strides(A_shape, /*f-contig*=*/A.device().type() != at::kMPS);
The MPS kernel assumes a row-major layout for the matrix where it does the decomposition.
Can the kernel be made row-major/col-major agnostic so as to preserve consistency across backends?
I'll take a look ~next week to see if I can make it work for col-major, so we don't need to make it row-major for MPS only. But why do we want to preserve consistency across backends? A lot of ops on MPS use a row-major layout and require a contiguous call before the tensor is passed to an MPS kernel.
In linalg, LAPACK seems to be the source of truth, and it is written in Fortran, where col-major is the standard layout :(
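The two layouts differ only in strides. A small CPU check (illustrative; the printed stride is what the CPU LAPACK path is expected to produce, not a guarantee):

```python
import torch

n = 4
c = torch.empty(n, n)         # C-contiguous (row-major): strides (n, 1)
f = torch.empty(n, n).mT      # F-contiguous view (col-major): strides (1, n)
assert c.stride() == (n, 1)
assert f.stride() == (1, n)

# On CPU, cholesky goes through LAPACK; the factor reconstructs the
# input regardless of its memory layout
A = torch.eye(n) * 4.0
L = torch.linalg.cholesky(A)
print(L.stride())             # column-major strides expected on the LAPACK path
torch.testing.assert_close(L @ L.mT, A)
```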
I believe we can re-use the kernel without that much code change (i.e. no need to make it stride-agnostic for now). In the Meta function we request C-contiguous when upper=False and F-contiguous when upper=True for MPS. Then we only need to remove the line upper ? out.transpose_(...) : out (and probably replace it with upper ? out.triu_() : out.tril_()), or something along these lines. This should resolve the issue with out for now, before the kernel is adapted for better memory accesses in column-major mode...
I've tried it, but I'm afraid it doesn't work. I'll address this in a follow-up PR with the kernel change for column-major mode, rather than going down the rabbit hole now for a temporary fix.
Thanks, I'll address the comments a little later today.
output_cpu = torch.linalg.cholesky_ex(input_cpu, upper=upper)
output_mps = torch.linalg.cholesky_ex(input_mps, upper=upper)
Let us also check that info is the same since its behavior is altered?
output_cpu and output_mps are tuples of L and info tensors, so assertEqual compares both of them. Do you mean adding a separate test where info might be >1?
Yes, when erroring on non-psd inputs :)
I'll do it a bit later today and also adapt the error message
Added better error message
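A sketch of such a test (CPU shown for illustration; the matrix below is just a hypothetical non-PSD input): cholesky_ex reports the failure through info rather than raising, unless check_errors=True.

```python
import torch

A = torch.eye(3)
A[2, 2] = -1.0                        # not positive-definite

L, info = torch.linalg.cholesky_ex(A)
assert info.item() > 0                # LAPACK convention: the order of the
                                      # leading minor that is not positive-definite

# With check_errors=True, an error is raised instead
try:
    torch.linalg.cholesky_ex(A, check_errors=True)
except torch.linalg.LinAlgError as e:
    print("raised:", e)
```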
@pytorchbot merge -f "MPS is green"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
PR #145701 didn't have the experimental version of cholesky. This PR adds that version.