Fix torch._numpy advanced indexing to match NumPy when indices are separated #157676
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/157676

✅ No Failures as of commit d7dd51b with merge base 2022588. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
Pretty much the exact opposite, I believe.

@soumith can you post the prompts
ezyang left a comment:

This appears to not be right. I'm also not sure whether it's better to keep working on this PR or to re-prompt.
Force-pushed from 5628934 to 3b91c18 (Compare)
uhh, I think this is actually ready for review now.
combined_transcripts.html.zip

You'd want to focus on the last 3 conversations, because once I gave it the NumPy advanced indexing spec (from the docs), it finally started hammering on the right code.
@ezyang added two commits that now ground the implementation in NumPy's exact original implementation, and add some more edge cases to cover subtle things missed before.

This makes me happy :)
torch/_numpy/_ndarray.py (Outdated)

```python
FANCY = 16
BOOL = 32
SCALAR_ARRAY = 64
BOOL_0D = FANCY | 128
```
I was going to suggest auto() but with BOOL_0D ehhh better not
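To illustrate the point: if these flags were an `enum.IntFlag`, `auto()` would hand out the next free power of two, but `BOOL_0D` is a combination of an existing flag and a new bit, so explicit values read more clearly. A minimal sketch (member names and values taken from the diff above; the `IndexType` class shape is an assumption):

```python
import enum

class IndexType(enum.IntFlag):
    # Explicit values, copied from the diff above. auto() would only
    # assign fresh powers of two, which doesn't fit BOOL_0D: it is a
    # *combination* (FANCY | 128), not a standalone bit.
    FANCY = 16
    BOOL = 32
    SCALAR_ARRAY = 64
    BOOL_0D = FANCY | 128
```

One consequence of the combined value: anything classified as `BOOL_0D` still tests positive for `FANCY` via a bitwise check.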
torch/_numpy/_ndarray.py (Outdated)

```python
BOOL_0D = FANCY | 128

def _classify_index(idx):
```
hmmm... this is prepare_index I guess?

I guess prepare_index is the entire state machine

You should ask Claude to add types
torch/_numpy/_ndarray.py (Outdated)

```python
# Handle tensors with dtype/ndim checks
if isinstance(idx, torch.Tensor):
    if idx.dtype == torch.bool:
        return IndexType.BOOL_0D if idx.ndim == 0 else IndexType.BOOL
```
You detect BOOL_0D but this is never actually used in the body
Unfortunately, it's hard to tell whether Claude has done it correctly, because it has taken extensive liberties in changing the logic, whereas as a dumb reviewer I just wanted to see whether it had done a straight port of the NumPy C code to PyTorch correctly. Maybe a prompt along those lines would work?
@ezyang I had it generate a side-by-side of the NumPy C code and the code that it added, so that it's way easier for you to compare and review.

and lastly, I made it add types as you had asked
I am not sure all of this is needed. The added tests definitely, but I think the core issue here is that PyTorch treats the last `0` as a slice, while NumPy treats it as advanced indexing. It should suffice to change that `0` to `[0]` when other advanced indexing is present. To explain what I mean:

```
>>> import numpy as np
>>> import torch
>>> np_x = np.arange(2*3*5*7).reshape(2, 3, 5, 7)
>>> py_x = torch.arange(2*3*5*7).reshape(2, 3, 5, 7)
```

The following should show the current difference:

```
>>> py_x[:, [1, 0, 2, 2], :, 0].shape
torch.Size([2, 4, 5])
>>> np_x[:, [1, 0, 2, 2], :, 0].shape
(4, 2, 5)
```

But that's because in the presence of advanced indexing, NumPy treats the `0` as advanced indexing too, i.e. as `[0]`:

```
>>> np_x[:, [1, 0, 2, 2], :, [0]].shape
(4, 2, 5)
```

PyTorch already imitates NumPy's behavior if we pass `[0]`:

```
>>> py_x[:, [1, 0, 2, 2], :, [0]].shape
torch.Size([4, 2, 5])
```

Here is a more complex advanced indexing example just to spot check that PyTorch is indeed doing the same as NumPy:

```
>>> np_x[:, [1, 0, 2, 2], :, [[5], [3], [1], [2], [4], [0]]].shape
(6, 4, 2, 5)
>>> py_x[:, [1, 0, 2, 2], :, [[5], [3], [1], [2], [4], [0]]].shape
torch.Size([6, 4, 2, 5])
```
…behavior

the core issue here is that `PyTorch` treats the last `0` as a slice, while `Numpy` treats it as advanced indexing. It should suffice to change that `0` to `[0]` **when other advanced indexing is present**. PyTorch is already producing the same output shape as Numpy, it just doesn't know that `0` means `[0]`. To explain what I mean:

```
>>> import numpy as np
>>> import torch
>>> np_x = np.arange(2*3*5*7).reshape(2, 3, 5, 7)
>>> py_x = torch.arange(2*3*5*7).reshape(2, 3, 5, 7)
```

The following should show the current difference:

```
>>> py_x[:, [1, 0, 2, 2], :, 0].shape
torch.Size([2, 4, 5])
>>> np_x[:, [1, 0, 2, 2], :, 0].shape
(4, 2, 5)
```

But that's because in the presence of advanced indexing, `Numpy` treats the `0` as advanced indexing too, i.e. as `[0]`:

```
>>> np_x[:, [1, 0, 2, 2], :, [0]].shape
(4, 2, 5)
```

PyTorch already imitates Numpy's behavior if we pass `[0]`:

```
>>> py_x[:, [1, 0, 2, 2], :, [0]].shape
torch.Size([4, 2, 5])
```

Here is a more complex advanced indexing example just to spot check that PyTorch is indeed doing the same as Numpy:

```
>>> np_x[:, [1, 0, 2, 2], :, [[5], [3], [1], [2], [4], [0]]].shape
(6, 4, 2, 5)
>>> py_x[:, [1, 0, 2, 2], :, [[5], [3], [1], [2], [4], [0]]].shape
torch.Size([6, 4, 2, 5])
```

Add comprehensive tests for torch._numpy advanced indexing

Tests verify NumPy compatibility for complex indexing patterns including:
- Single and multiple separated advanced indices
- Adjacent vs separated index behavior
- Edge cases with negative indices and broadcasting
- Both getitem and setitem operations with scalars and arrays

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
@manuelcandales is right (I don't know how long he's spent on advanced indexing to get to this; I can still barely wrap my head around NumPy's advanced indexing logic). I've just pushed a commit that makes the bugfix trivial. I did keep all the test additions though, which seems prudent.
```python
# Note: Do NOT convert boolean scalars (True/False) as they have special meaning in NumPy
converted = []
for idx in index:
    if isinstance(idx, int) and not isinstance(idx, bool):
```
The same should be done for integers wrapped in a tensor (i.e. a zero-dimensional tensor holding an integer)

fixed and added test cases
torch/_numpy/_ndarray.py (Outdated)

```python
# Check if there's any advanced indexing (lists or multi-dimensional tensors)
has_advanced = any(
    isinstance(idx, list) or (isinstance(idx, torch.Tensor) and idx.ndim > 0)
```
It turns out you also need to add isinstance(idx, bool) to the conditions for advanced indexing. I have created issue #158134 to document this edge case and the discrepancy between NumPy/PyTorch. Add tests from that issue's description here as well.

Actually, you need to check for True/False in either form (boolean or tensor). So you also need to check for 0-dimensional tensors of boolean dtype.

tuples are also allowed, and trigger advanced indexing just like lists do

We need to capture lists of things more generally. I think it should be isinstance(idx, Sequence) instead of isinstance(idx, list).

So, to summarize, I think the advanced indexing condition should be:

```python
isinstance(idx, Sequence) or isinstance(idx, bool)
or (isinstance(idx, torch.Tensor) and (idx.dtype == torch.bool or idx.ndim > 0))
```

ndarrays are also allowed, but those are converted to tensors before you call this function.
@pytorchbot merge

This PR has pending changes requested. Please address the comments and update the PR before merging.

@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…ndexing separation (#158297)

Fixes #141563

In NumPy, an ellipsis always acts as a separator between advanced indices, even when the ellipsis doesn't actually match any dimensions. In PyTorch an empty ellipsis doesn't cause a separation. This leads to differing behavior between NumPy and PyTorch in this edge case.

This difference in behavior leads to a bug when using torch.compile:

```python
>>> import numpy as np
>>> f = lambda x: x[:,(0,1),...,(0,1)].shape
>>> a = np.ones((3, 4, 5))
>>> f(a)
(2, 3)
>>> torch.compile(f)(a)
(3, 2)
```

Similarly to #157676, this PR doesn't change PyTorch's behavior, but it fixes the translation layer, ensuring torch._numpy compatibility with NumPy. I am marking this PR as fixing #141563, even though PyTorch behavior isn't modified.

Notice that there are still some other bugs in PyTorch's advanced indexing that need to be fixed (mainly regarding proper accounting of dimensions when multidimensional boolean masks are present). But those need to be fixed at the ATen operator level. Examples:

- #71673
- #107699
- #158125

Pull Request resolved: #158297
Approved by: https://github.com/soumith
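NumPy alone is enough to see the separator effect described in that commit message. A quick self-contained check (using lists, which behave the same as the tuples in the example):

```python
import numpy as np

a = np.ones((3, 4, 5))

# The ellipsis matches zero dimensions here, yet NumPy still treats it as
# a separator between the two advanced indices, so their broadcast
# dimension (length 2) moves to the front of the result.
separated = a[:, [0, 1], ..., [0, 1]].shape

# With the advanced indices adjacent, the broadcast dimension stays in
# place instead.
adjacent = a[:, [0, 1], [0, 1]].shape

print(separated, adjacent)  # (2, 3) (3, 2)
```

The two shapes differ only in dimension order, which is exactly the discrepancy torch.compile reproduced before the fix.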
Written with Claude Code.
Fixes #157569
Fixes #158134
NumPy and PyTorch handle advanced indexing differently when advanced indices are separated by slices (e.g., arr[:, [0], :, 0]). PyTorch uses "outer" indexing, placing result dimensions in their original positions, while NumPy uses "vectorized" indexing, moving advanced index dimensions to the front.
This adds _numpy_style_advanced_indexing() to detect separated advanced indices and transpose results to match NumPy's dimension ordering, ensuring torch._numpy maintains compatibility with NumPy's indexing behavior.
Fixes cases like:
cc @mruberry @rgommers @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames