Lower aten::_linalg_eigh#7674
Conversation
|
cc @vanbasten23 |
| if (!compute_v) { | ||
| // Fallback to aten in case of `eigvalsh`, which does not compute | ||
| // eigenvectors but requires numerically stable gradients. | ||
| return at::native::call_fallback_fn<&xla_fallback, |
There was a problem hiding this comment.
I understand we only need to lower torch.linalg.eigh. But in case we need to lower eigvalsh later, would requires numerically stable gradients be a blocker so we have to fall back to aten?
There was a problem hiding this comment.
So I read the PyTorch docs again and I think I misunderstood it initially. What PyTorch doc suggests is that the gradients of the eigenvectors are unstable. Therefore, if the user calls eigvalsh, they will only get eigenvalues and thus the gradients will be stable. I removed this misleading comment.
If we want to support eigvalsh later the simplest way is probably discarding the eigenvectors from XLA and also figuring out what to return for the second at::Tensor tuple member.
|
|
||
| std::array<xla::XlaOp, 2> LowerImpl(xla::XlaOp input, bool lower) { | ||
| auto [eigenvectors, eigenvalues] = | ||
| xla::SelfAdjointEig(input, lower, /* max_iter */ 64, /* tol */ 1e-6); |
There was a problem hiding this comment.
What are the max_iter and tol (for opaque) used for?
Also, could you add a comment on why changing the default value?
There was a problem hiding this comment.
When testing I discovered that the default settings lead to a very low accuracy in the reconstructed matrix (e.g. let's say we decompose A to A' = V @ Q @ V_T, then A and A' have a difference in elements above 0.1.
After looking at what JAX does I think it can be simpler to align with JAX: https://github.com/google/jax/blob/a8b425cac50c842f66f36903dfb93fe6ad5a2a5b/jax/_src/lax/linalg.py#L726. Looks like they use the same tolerance but higher max_iter.
There's an xla::SelfAdjointEig function so we lower it to that. I discovered that the XLA implementation of eigenvalue decomposition is not as numerically stable as numpy or torch, despite passing a small tolerance and large max_iter. The unit test thus uses a hardcoded tensor value copied from https://android.googlesource.com/platform/external/tensorflow/+/f2a058296dd/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc#149
Fixes #6017
There's an xla::SelfAdjointEig function so we lower it to that.
I discovered that the XLA implementation of eigenvalue decomposition is not as numerically stable as numpy or torch, despite passing a small tolerance and large max_iter. The unit test thus uses a hardcoded tensor value.