Skip to content

cat: improve error handling and error messages.#9548

Merged
ysiraichi merged 4 commits intomasterfrom
ysiraichi/better-error-cat
Aug 11, 2025
Merged

cat: improve error handling and error messages.#9548
ysiraichi merged 4 commits intomasterfrom
ysiraichi/better-error-cat

Conversation

@ysiraichi
Copy link
Copy Markdown
Collaborator

This PR refactors the tensor_methods::cat implementation by improving its error message, and returning a status type value.

Key Changes:

  • Make tensor_methods::cat return StatusOr<absl_nonnull XLATensorPtr>
  • Improve error message on incompatible tensor shapes
  • Crash on torch.cat with no tensors

Before:

Traceback (most recent call last):
  File "scratch.py", line 8, in <module>
    print(torch.cat([x, y]))
          ^^^^^^^^^^^^^^^^^
RuntimeError: Check failed: xla::ShapeUtil::CompatibleIgnoringElementType(shapes.back(), tensor_shape): f32[1]{0} vs. f32[4]{0} (at torch_xla/csrc/tensor_methods.cpp:1185)

Exception raised from operator& at torch_xla/csrc/runtime/tf_logging.cpp:26 (most recent call first):

After:

Traceback (most recent call last):
  File "scratch.py", line 8, in <module>
    print(torch.cat([x, y]))
          ^^^^^^^^^^^^^^^^^
RuntimeError: cat(): cannot concatenate tensors of shape f32[1,1] with f32[9,4] at dimension 0. Expected shapes to be equal (except at dimension 0) or that either of them was a 1D empty tensor of size (0,).

Status Propagation Trace:
    From: cat at torch_xla/csrc/tensor_methods.cpp:1190 (error: cat(): cannot concatenate tensors of shape f32[1,1] with f32[9,4] at dimension 0. Expected shapes to be equal (except at dimension 0) or that either of them was a 1D empty tensor of size (0,).)

Exception raised from MaybeThrow at torch_xla/csrc/status.cpp:128 (most recent call first):

Copy link
Copy Markdown
Collaborator

@zhanyong-wan zhanyong-wan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

@ysiraichi ysiraichi merged commit 38e0f03 into master Aug 11, 2025
23 of 24 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants