Support torch.linalg.trace #62714
Conversation
💊 CI failures summary and remediations

As of commit 9622b99 (more details on the Dr. CI page):

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

| Job | Action |
|---|---|
| Run clang-format | 🔁 rerun |

This comment was automatically generated by Dr. CI.
Hey @asi1024, just checking in on this PR because it's still marked as "draft." Is it ready for a review?

@mruberry Sorry for my late response. I will mark "ready for review" after adding tests and documentation!
IvanYashchuk left a comment

Hey, @asi1024, thanks for your contribution! I left a suggestion to use the CompositeImplicitAutograd dispatch key, which would allow us to remove the _backward function and trim down unnecessary code. After that, I think the PR should be good to go.
> CPU, CUDA: linalg_trace

This line is overwritten by CompositeExplicitAutograd: linalg_trace. The code in the linalg_trace function is device-independent, so CPU/CUDA specialization is not needed here and CompositeExplicitAutograd is the correct choice of dispatch key.
On second thought, using CompositeImplicitAutograd should be better; then the backward function is not needed.
Could you please remove this entry from native_functions.yaml?
Most of the backward functions in PyTorch are placed in torch/csrc/autograd/FunctionsManual.cpp and torch/csrc/autograd/FunctionsManual.h, so let's move linalg_trace_backward out of ReduceOps.cpp.
Is this change needed in this PR?
The function `prod(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor` should be autogenerated with
Codecov Report

```
@@            Coverage Diff             @@
##           master   #62714      +/-   ##
==========================================
+ Coverage   66.37%   66.46%   +0.08%
==========================================
  Files         738      727      -11
  Lines       94170    93581     -589
==========================================
- Hits        62510    62200     -310
+ Misses      31660    31381     -279
```
lezcano left a comment

Left a few points regarding docs / testing.
What diagonal? Given that we have the parameter offset, this should probably read:

"Returns the sum of the elements of a diagonal."

followed by an explanation of how the offset parameter chooses a diagonal.
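For instance, a hedged illustration using torch.diagonal, which follows the same offset convention (positive offsets select superdiagonals, negative ones subdiagonals); the example values are mine, not from the PR:

```python
import torch

A = torch.arange(9.).reshape(3, 3)
torch.diagonal(A, offset=0).sum()   # main diagonal: 0 + 4 + 8 -> tensor(12.)
torch.diagonal(A, offset=1).sum()   # first superdiagonal: 1 + 5 -> tensor(6.)
torch.diagonal(A, offset=-1).sum()  # first subdiagonal: 3 + 7 -> tensor(10.)
```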
The call to torch.from_numpy here and a few lines below is not necessary, as assertEqual is able to compare tensors and NumPy arrays. What's more, not calling it is often faster. Same below.
This might work without all the explicit castings by simply doing

```diff
- xn = np.array(x.cpu().numpy()).reshape(shape)
- yn = np.trace(xn, axis1=-2, axis2=-1)
+ yn = np.trace(x.cpu(), axis1=-2, axis2=-1)
```
Same below.
@lezcano Thank you for your reviews! Could you take another look?

@lezcano Now all CIs have passed! PTAL!
Yeah, this looks good to me. I just found yet another issue (ugh, sorry for that!). I believe this is the last missing thing! :D
@lezcano The CI failures look unrelated to this PR. Could you take another look?

As mentioned, this LGTM. We now just need to wait for @mruberry to have a look. He's been a bit busy lately, but let's hope he finds some time soon :)
```python
trace = _add_docstr(_linalg.linalg_trace, r"""
trace(input, *, offset=0, out=None) -> Tensor

Computes the trace of a matrix.
```
This is really well written.
```python
inputs = (
    ((S, S), 0),
    ((S, M), 0),
    ((S, S), 1),
```
Add a sample with a negative offset, and a comment explaining the format of these tuples.
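Something along these lines, perhaps (a sketch; the added tuple and the comment are illustrative, not from the PR):

```python
# Each entry is (shape, offset): the shape of the input matrix and the
# diagonal to sum over (negative offsets select subdiagonals).
inputs = (
    ((S, S), 0),
    ((S, M), 0),
    ((S, S), 1),
    ((S, S), -1),
)
```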
```python
sample_inputs_func=sample_inputs_linalg_slogdet,
decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],),
OpInfo('linalg.trace',
       ref=np.trace,
```
```python
an = torch.from_numpy(np.tensordot(np.zeros((), dtype=np.float32), np.zeros((), dtype=np.float32), 0))
self.assertEqual(a, an)


def test_linalg_trace(self, device):
```
```python
def sample_inputs_linalg_trace(self, device, dtype, requires_grad, **kwargs):
    inputs = (
        ((S, S), 0),
```
What happens on empty tensors? We should probably add a case for them. What about a batched sample input too, e.g. (S, S, S)?
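For example (hypothetical additions, not from the PR):

```python
inputs = (
    ((S, S), 0),
    ((0, 0), 0),     # empty matrix; the sum of an empty diagonal should be 0
    ((S, S, S), 0),  # batched input, reducing over the last two dimensions
)
```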
The trace function implemented in this PR returns different values from numpy.trace for 3-dim inputs. numpy.trace reduces with axis1=0, axis2=1, whereas the array API specifies reducing with axis1=-2, axis2=-1.
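A quick illustration of the mismatch (my own example, not from the PR):

```python
import numpy as np

x = np.arange(18).reshape(2, 3, 3)

# NumPy's default reduces over the first two axes: it sums x[i, i, :],
# returning one value per trailing index.
print(np.trace(x))                      # [12 14 16]

# Array API semantics: treat x as a batch of matrices and reduce over
# the last two axes, one trace per batch element.
print(np.trace(x, axis1=-2, axis2=-1))  # [12 39]
```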
Is it possible to compare them against a lambda with our same defaults that calls into np.trace?
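Something like this, presumably, to be passed as the OpInfo's ref (a sketch, not code from the PR):

```python
import numpy as np

# Reference with the same defaults as torch.linalg.trace: reduce over
# the last two axes instead of NumPy's default axis1=0, axis2=1.
def trace_ref(x, offset=0):
    return np.trace(x, offset=offset, axis1=-2, axis2=-1)
```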
```python
def test_linalg_trace(self, device):
    inputs = [
        {'shape': (1, 1), 'offsets': [0]},
        {'shape': (10, 1), 'offsets': [0, -9]},
```
What happens if offset is an absurd number, like 100?
RuntimeError will be raised if offset is out of range. I will add a test for this case!
> RuntimeError will be raised if offset is out of range. I will add a test for this case!

Make it an ErrorInput.
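A minimal sketch of what that could look like, assuming the ErrorInput and SampleInput helpers from torch.testing._internal; the function name and error regex are placeholders:

```python
def error_inputs_linalg_trace(op_info, device, **kwargs):
    t = torch.ones((3, 3), device=device)
    # An offset far outside the matrix should raise a RuntimeError.
    yield ErrorInput(SampleInput(t, kwargs={'offset': 100}),
                     error_type=RuntimeError,
                     error_regex='offset')
```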
```python
def test_linalg_trace(self, device):
    inputs = [
        {'shape': (1, 1), 'offsets': [0]},
```
Adding the empty case here would be interesting, too.
```yaml
  device_check: NoCheck
  device_guard: False

- func: linalg_trace.out(Tensor self, *, int offset=0, Tensor(a!) out) -> Tensor(a!)
```
```cpp
// see https://github.com/pytorch/pytorch/pull/47305,
Tensor linalg_trace(const Tensor& self, int64_t offset) {
  TORCH_CHECK(self.dim() >= 2,
              "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
```
The user-documented name is input (per your docs below), and these warnings should start with the name of the operation, like this:

torch.linalg.trace(): input should have at least...
It might be nice to change the user-facing name of this argument to A, which is the name we use throughout torch.linalg.
@mruberry Updated tests. PTAL!

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Fixes #62255 (cc @mruberry, @rgommers, @emcastillo, @kmaehashi)

This PR adds support for torch.linalg.trace, for compatibility with NumPy's interface and the Python array API standard.

TODO:
cc @jianyuh @nikitaved @pearu @mruberry @walterddr @IvanYashchuk @xwang233 @lezcano