
Lower aten::_linalg_eigh #7674

Merged: tengyifei merged 1 commit into pytorch:master from tengyifei:linalg.eigh on Jul 15, 2024

Conversation

@tengyifei (Collaborator) commented Jul 12, 2024

Fixes #6017

There's an xla::SelfAdjointEig function so we lower it to that.

I discovered that the XLA implementation of eigenvalue decomposition is not as numerically stable as numpy or torch, despite passing a small tolerance and large max_iter. The unit test thus uses a hardcoded tensor value.
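The property the unit test checks can be sketched outside the C++/XLA harness. As a hedged illustration (using numpy, not the actual test), the decomposition is verified by bounding the element-wise reconstruction error of a fixed symmetric matrix:

```python
import numpy as np

# Illustration only: decompose a fixed symmetric matrix and bound the
# element-wise reconstruction error |A - V @ diag(w) @ V.T|. The real
# unit test compares XLA's output against a hardcoded tensor instead.
A = np.array([[2.0, 1.0, 0.0],
              [1.0, 2.0, 1.0],
              [0.0, 1.0, 2.0]])
w, V = np.linalg.eigh(A)            # eigenvalues ascending, V orthonormal
A_rec = V @ np.diag(w) @ V.T        # reconstruct A from the factors
err = np.max(np.abs(A - A_rec))     # worst element-wise deviation
assert err < 1e-6                   # numpy stays well under this bound
```

The XLA lowering reportedly does not reach this accuracy with its default settings, which is why the test pins an expected tensor rather than reconstructing.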

@tengyifei (Collaborator, Author) commented:

cc @vanbasten23

Comment thread: test/cpp/test_aten_xla_tensor_2.cpp (outdated)
Comment thread: torch_xla/csrc/ops/eigh.cpp (outdated)
if (!compute_v) {
  // Fallback to aten in case of `eigvalsh`, which does not compute
  // eigenvectors but requires numerically stable gradients.
  return at::native::call_fallback_fn<&xla_fallback,
Collaborator:
I understand we only need to lower torch.linalg.eigh. But in case we need to lower eigvalsh later, would "requires numerically stable gradients" be a blocker, so that we would have to fall back to aten?

tengyifei (Collaborator, Author):
So I read the PyTorch docs again and I think I misunderstood them initially. What the docs suggest is that the gradients of the eigenvectors are unstable. Therefore, if the user calls eigvalsh, they will only get eigenvalues, so the gradients will be stable. I removed this misleading comment.

If we want to support eigvalsh later the simplest way is probably discarding the eigenvectors from XLA and also figuring out what to return for the second at::Tensor tuple member.
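The relationship described above can be sketched in numpy (an illustration of the math, not the actual lowering): eigvalsh is eigh with the eigenvectors discarded, so the eigenvalues agree exactly.

```python
import numpy as np

# Illustration only: eigvalsh returns the same eigenvalues as a full
# eigh; a later lowering could compute the full decomposition and
# simply drop the eigenvector factor.
A = np.array([[4.0, 1.0],
              [1.0, 3.0]])
w_full, _V = np.linalg.eigh(A)   # full decomposition
w_only = np.linalg.eigvalsh(A)   # eigenvalues only
assert np.allclose(w_full, w_only)
```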

Comment thread: torch_xla/csrc/ops/eigh.cpp (outdated)

std::array<xla::XlaOp, 2> LowerImpl(xla::XlaOp input, bool lower) {
  auto [eigenvectors, eigenvalues] =
      xla::SelfAdjointEig(input, lower, /*max_iter=*/64, /*tol=*/1e-6);
Collaborator:
What are the max_iter and tol arguments used for?
Also, could you add a comment on why the default values were changed?

tengyifei (Collaborator, Author):
When testing I discovered that the default settings lead to very low accuracy in the reconstructed matrix (e.g. if we decompose A into A' = V @ Q @ V_T, then A and A' can differ element-wise by more than 0.1).

After looking at what JAX does, I think it is simplest to align with JAX: https://github.com/google/jax/blob/a8b425cac50c842f66f36903dfb93fe6ad5a2a5b/jax/_src/lax/linalg.py#L726. It looks like they use the same tolerance but a higher max_iter.
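To make max_iter and tol concrete: SelfAdjointEig is an iterative Jacobi-style eigensolver, where max_iter bounds the number of sweeps and tol is the stopping threshold on the remaining off-diagonal mass. A minimal numpy sketch of that idea (a hypothetical toy solver, not XLA's implementation):

```python
import numpy as np

def jacobi_eigh(A, max_iter=64, tol=1e-6):
    """Toy cyclic-Jacobi eigensolver for a symmetric matrix.

    Rotates away off-diagonal entries until their norm falls below
    tol * ||A||_F, or max_iter sweeps elapse. Returns (eigenvalues
    ascending, eigenvectors as columns), mirroring np.linalg.eigh.
    """
    A = np.array(A, dtype=float)
    n = A.shape[0]
    V = np.eye(n)
    frob = np.linalg.norm(A)
    for _ in range(max_iter):
        # Norm of the off-diagonal part; stop when small relative to A.
        off = np.sqrt(max(np.sum(A**2) - np.sum(np.diag(A)**2), 0.0))
        if off <= tol * frob:
            break
        for p in range(n - 1):
            for q in range(p + 1, n):
                if abs(A[p, q]) < 1e-30:
                    continue
                # Rotation angle that zeroes A[p, q].
                theta = 0.5 * np.arctan2(2 * A[p, q], A[q, q] - A[p, p])
                c, s = np.cos(theta), np.sin(theta)
                J = np.eye(n)
                J[p, p] = J[q, q] = c
                J[p, q] = s
                J[q, p] = -s
                A = J.T @ A @ J     # similarity transform preserves eigenvalues
                V = V @ J           # accumulate eigenvectors
    w = np.diag(A)
    order = np.argsort(w)
    return w[order], V[:, order]
```

A lower tol or higher max_iter trades compile/runtime cost for accuracy of the factors, which is the knob being discussed above.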

Commit message:

There's an xla::SelfAdjointEig function so we lower it to that.

I discovered that the XLA implementation of eigenvalue decomposition
is not as numerically stable as numpy or torch, despite passing a
small tolerance and large max_iter. The unit test thus uses a
hardcoded tensor value copied from

https://android.googlesource.com/platform/external/tensorflow/+/f2a058296dd/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc#149
@tengyifei tengyifei merged commit f975ad6 into pytorch:master Jul 15, 2024
@miladm miladm assigned miladm and tengyifei and unassigned miladm Jul 16, 2024
@miladm miladm added the lowering ATen Operation lowering label Jul 16, 2024