BUG: validate contraction axes in tensordot#30521
Conversation
tensordot performs a tensor contraction where each axis corresponds to a distinct summation index. Duplicate contraction axes (e.g. axes=([1, 1], [0, 0])) are mathematically invalid but currently fail later during an internal transpose, raising ValueError: axes don't match array. This PR adds early validation to reject duplicate axes explicitly, aligning the behavior with the definition of tensor contraction and other NumPy axis-handling APIs. Current behavior: Supplying duplicate contraction axes (e.g. axes=([1, 1], [0, 0])) causes tensordot to fail later during an internal transpose, raising ValueError: axes don't match array. Expected behavior: Duplicate contraction axes should be rejected explicitly, since tensor contraction requires distinct summation axes, and a clear input-validation error should be raised before any internal reshaping or transposition.
|
Exact case I used
|
|
Hello, following up with a gentle reminder on this PR. All checks are passing, and I’m open to any discussion or feedback. Thank you for your consideration. |
|
Seems fine to me. Can you update the docstring to indicate that repeated entries in Also, IMO the explanation about |
|
@ngoldbaum Thanks for the review! Yes, agreed ,I’ll update the docstring to explicitly state that repeated entries in axes are not allowed. I’ll also try to clarify the explanatory note by adding a couple of small concrete examples to make the contraction semantics easier to follow. I’ll push an update shortly. |
|
@ngoldbaum |
BUG: validate contraction axes in tensordot (#30521)
tensordot performs a tensor contraction where each axis corresponds to a distinct summation index. Duplicate contraction axes (e.g. axes=([1, 1], [0, 0])) are mathematically invalid but currently fail later during an internal transpose, raising ValueError: axes don't match array. This PR adds early validation to reject duplicate axes explicitly, aligning the behavior with the definition of tensor contraction and other NumPy axis-handling APIs.
Current behavior:
Supplying duplicate contraction axes (e.g. axes=([1, 1], [0, 0])) causes tensordot to fail later during an internal transpose, raising ValueError: axes don't match array.
Expected behavior:
Duplicate contraction axes should be rejected explicitly, since tensor contraction requires distinct summation axes, and a clear input-validation error should be raised before any internal reshaping or transposition.