
Fix torch._numpy advanced indexing to match NumPy when indices are separated #157676

Closed
soumith wants to merge 4 commits into main from np_indexing_fix
Conversation

@soumith
Collaborator

@soumith soumith commented Jul 6, 2025

Written with Claude Code.

Fixes #157569
Fixes #158134

NumPy and PyTorch handle advanced indexing differently when advanced indices are separated by slices (e.g., arr[:, [0], :, 0]). PyTorch uses "outer" indexing, placing result dimensions in their original positions, while NumPy uses "vectorized" indexing, moving advanced index dimensions to the front.

This adds _numpy_style_advanced_indexing() to detect separated advanced indices and transpose results to match NumPy's dimension ordering, ensuring torch._numpy maintains compatibility with NumPy's indexing behavior.

Fixes cases like:

  • arr[:, [0], :, 0] now returns shape (1, 5, 7) instead of (5, 1, 7)
  • arr[:, [0, 1], :, 0] now returns shape (2, 5, 7) instead of (5, 2, 7)
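For reference, NumPy's rule can be checked standalone (a sketch using an arbitrary (2, 3, 5, 7) array, so the shapes differ from the bullets above, which assume a different array shape):

```python
import numpy as np

# Arbitrary 4-D array for illustration.
arr = np.arange(2 * 3 * 5 * 7).reshape(2, 3, 5, 7)

# The advanced indices [0] and 0 are separated by the slice ':', so NumPy
# broadcasts them together and moves the resulting dimension to the front.
print(arr[:, [0], :, 0].shape)     # (1, 2, 5)
print(arr[:, [0, 1], :, 0].shape)  # (2, 2, 5)

# When the advanced indices are adjacent, the broadcast dimension stays
# in the position of the first advanced index instead.
print(arr[:, [0, 1], 0, :].shape)  # (2, 2, 7)
```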

cc @mruberry @rgommers @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames

@pytorch-bot

pytorch-bot bot commented Jul 6, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/157676

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit d7dd51b with merge base 2022588:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@soumith soumith added module: numpy Related to numpy support, and also numpy compatibility of our operators topic: not user facing topic category module: dynamo labels Jul 6, 2025
@francois-rozet

NumPy uses "outer" indexing placing result dimensions in original positions, while PyTorch uses "vectorized" indexing moving advanced index dimensions to the front.

Pretty much the exact opposite, I believe.

@soumith soumith force-pushed the np_indexing_fix branch from 5b9e7e7 to 84e2e49 on July 6, 2025 at 19:27
@ezyang
Contributor

ezyang commented Jul 7, 2025

@soumith can you post the prompts

ezyang previously requested changes Jul 7, 2025
Contributor

@ezyang ezyang left a comment

This appears to not be right. I'm also not sure if it's right to keep attempting to work on this PR or re-prompt.

@soumith soumith force-pushed the np_indexing_fix branch 3 times, most recently from 5628934 to 3b91c18 on July 8, 2025 at 02:17
@soumith
Collaborator Author

soumith commented Jul 8, 2025

uhh, i think this is actually ready for review now.

@soumith soumith requested a review from ezyang July 8, 2025 02:24
@soumith soumith force-pushed the np_indexing_fix branch from 3b91c18 to 954b872 on July 8, 2025 at 02:55
@soumith
Collaborator Author

soumith commented Jul 8, 2025

Claude Code prompts are here: combined_transcripts.html.zip

You'd want to focus on the last 3 conversations, because once I gave the numpy advanced indexing spec (from the docs), it finally started hammering the right code.

@soumith soumith requested a review from ezyang July 10, 2025 01:36
@soumith
Collaborator Author

soumith commented Jul 10, 2025

@ezyang added two commits that now ground the implementation in NumPy's exact original implementation, and add some more edge cases to cover some subtle things missed before.
The implementation becomes a bit more verbose because it matches NumPy's state machine implementation exactly, but that's the cost of grounding, which seems fine and acceptable.
Take a look again.

@ezyang
Contributor

ezyang commented Jul 10, 2025

The implementation becomes a bit more verbose because it's matching numpy's state machine implementation exactly, but that's the cost of grounding, which seems fine and acceptable.

This makes me happy :)

```python
FANCY = 16
BOOL = 32
SCALAR_ARRAY = 64
BOOL_0D = FANCY | 128
```
Contributor

I was going to suggest auto() but with BOOL_0D ehhh better not

```python
BOOL_0D = FANCY | 128


def _classify_index(idx):
```
Contributor

hmmm... this is prepare_index I guess?

Contributor

I guess prepare_index is the entire state machine

Contributor

You should ask Claude to add types

```python
# Handle tensors with dtype/ndim checks
if isinstance(idx, torch.Tensor):
    if idx.dtype == torch.bool:
        return IndexType.BOOL_0D if idx.ndim == 0 else IndexType.BOOL
```
Contributor

You detect BOOL_0D but this is never actually used in the body

@ezyang
Contributor

ezyang commented Jul 10, 2025

Unfortunately, it's hard to tell if Claude has done it correctly because it has taken extensive liberties in changing the logic, where, as a dumb reviewer, I just wanted to see if it had done a straight port of the NumPy C code to PyTorch correctly. Maybe a prompt along those lines would work?

@soumith
Collaborator Author

soumith commented Jul 10, 2025

@ezyang I had it generate a side-by-side of the NumPy C code and the code that it added, so that it's way easier for you to compare and review.
Then I also made it generate a side-by-side of the refactor of the C-transpiled code to the Pythonified code.
Gist is here: https://gist.github.com/soumith/a13ee05e6b461771062ec91e78e2c22a

@soumith
Collaborator Author

soumith commented Jul 10, 2025

and lastly, I made it add types as you had asked

@soumith soumith requested a review from ezyang July 10, 2025 13:46
@manuelcandales
Contributor

I am not sure all of this is needed. The added tests definitely, but I think the core issue here is that PyTorch treats the last `0` as a slice, while NumPy treats it as advanced indexing. It should suffice to change that `0` to `[0]` when other advanced indexing is present. PyTorch is already producing the same output shape as NumPy, it just doesn't know that `0` means `[0]`.
To explain what I mean:

```
>>> import numpy as np
>>> import torch
>>> np_x = np.arange(2*3*5*7).reshape(2, 3, 5, 7)
>>> py_x = torch.arange(2*3*5*7).reshape(2, 3, 5, 7)
```

The following should show the current difference:

```
>>> py_x[:, [1, 0, 2, 2], :, 0].shape
torch.Size([2, 4, 5])
>>> np_x[:, [1, 0, 2, 2], :, 0].shape
(4, 2, 5)
```

But that's because in the presence of advanced indexing, NumPy treats the `0` as advanced indexing too, i.e., as `[0]`:

```
>>> np_x[:, [1, 0, 2, 2], :, [0]].shape
(4, 2, 5)
```

PyTorch already imitates NumPy's behavior if we pass `[0]`:

```
>>> py_x[:, [1, 0, 2, 2], :, [0]].shape
torch.Size([4, 2, 5])
```

Here is a more complex advanced indexing example, just to spot check that PyTorch is indeed doing the same as NumPy:

```
>>> np_x[:, [1, 0, 2, 2], :, [[5], [3], [1], [2], [4], [0]]].shape
(6, 4, 2, 5)
>>> py_x[:, [1, 0, 2, 2], :, [[5], [3], [1], [2], [4], [0]]].shape
torch.Size([6, 4, 2, 5])
```
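The claimed equivalence of `0` and `[0]` is easy to verify on the NumPy side alone (a standalone sketch, so it runs without torch):

```python
import numpy as np

x = np.arange(2 * 3 * 5 * 7).reshape(2, 3, 5, 7)

# With another advanced index present, the trailing bare 0 participates
# in broadcasting exactly like [0], so both spellings give the same result.
a = x[:, [1, 0, 2, 2], :, 0]
b = x[:, [1, 0, 2, 2], :, [0]]
print(a.shape, b.shape)      # (4, 2, 5) (4, 2, 5)
print(np.array_equal(a, b))  # True
```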

…behavior


Add comprehensive tests for torch._numpy advanced indexing

Tests verify NumPy compatibility for complex indexing patterns including:
- Single and multiple separated advanced indices
- Adjacent vs separated index behavior
- Edge cases with negative indices and broadcasting
- Both getitem and setitem operations with scalars and arrays

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
@soumith
Collaborator Author

soumith commented Jul 11, 2025

@manuelcandales is right (I don't know how long he's spent on advanced indexing to get to this; I can still barely wrap my head around NumPy's advanced indexing logic).

I've just pushed a commit that makes the bugfix trivial. I did keep all the test additions though, which seems prudent.

```python
# Note: Do NOT convert boolean scalars (True/False) as they have special meaning in NumPy
converted = []
for idx in index:
    if isinstance(idx, int) and not isinstance(idx, bool):
```
Contributor

The same should be done for integers wrapped in a tensor (i.e., a zero-dimensional tensor holding an integer)

Collaborator Author

fixed and added test cases

@soumith soumith requested a review from manuelcandales July 11, 2025 17:32

```python
# Check if there's any advanced indexing (lists or multi-dimensional tensors)
has_advanced = any(
    isinstance(idx, list) or (isinstance(idx, torch.Tensor) and idx.ndim > 0)
```
Contributor

It turns out you also need to add isinstance(idx, bool) to the conditions for advanced indexing.
I have created issue #158134 to document this edge case, and the discrepancy between Numpy/PyTorch.
Add tests from that issue's description here as well.

Contributor

Actually, you need to check for True/False in either form (boolean or tensor). So, you also need to check for 0-dimensional tensors of boolean dtype.

Contributor

tuples are also allowed, and trigger advanced indexing just like lists do

Contributor

@manuelcandales manuelcandales Jul 11, 2025

We need to capture lists of things more generally. I think it should be isinstance(idx, Sequence) instead of isinstance(idx, list)

So, to summarize. I think the advanced indexing conditions should be:

```python
isinstance(idx, Sequence) or isinstance(idx, bool)
    or (isinstance(idx, torch.Tensor) and (idx.dtype == torch.bool or idx.ndim > 0))
```

ndarrays are also allowed, but those are converted to tensors before you call this function.
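As a runnable sketch of this summarized predicate (the helper name is hypothetical, and np.ndarray stands in for torch.Tensor so the snippet runs without torch):

```python
from collections.abc import Sequence

import numpy as np


def triggers_advanced_indexing(idx) -> bool:
    """Hypothetical helper mirroring the predicate summarized above.

    np.ndarray stands in for torch.Tensor so this sketch runs standalone.
    """
    # True/False act as 0-d boolean masks, so they trigger advanced indexing;
    # checked before the int case since bool is a subclass of int.
    if isinstance(idx, bool):
        return True
    # Lists and tuples (any Sequence) of indices trigger advanced indexing.
    # Strings are Sequences too, so exclude them explicitly in this sketch.
    if isinstance(idx, Sequence) and not isinstance(idx, str):
        return True
    # Arrays: boolean dtype of any ndim, or any non-scalar (ndim > 0) array.
    if isinstance(idx, np.ndarray):
        return bool(idx.dtype == np.bool_ or idx.ndim > 0)
    return False


print(triggers_advanced_indexing([0, 1]))          # True
print(triggers_advanced_indexing((0, 1)))          # True
print(triggers_advanced_indexing(True))            # True
print(triggers_advanced_indexing(np.array(True)))  # True  (0-d bool mask)
print(triggers_advanced_indexing(np.array(3)))     # False (0-d int scalar)
print(triggers_advanced_indexing(slice(None)))     # False
```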

Collaborator Author

done.

@soumith
Collaborator Author

soumith commented Jul 11, 2025

@pytorchbot merge

@pytorch-bot

pytorch-bot bot commented Jul 11, 2025

This PR has pending changes requested. Please address the comments and update the PR before merging.

@soumith soumith dismissed ezyang’s stale review July 12, 2025 00:03

already addressed

@soumith
Collaborator Author

soumith commented Jul 12, 2025

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jul 12, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request Jul 16, 2025
…ndexing separation (#158297)

Fixes #141563

In NumPy, an ellipsis always acts as a separator between advanced indices, even when the ellipsis doesn't actually match any dimensions. In PyTorch an empty ellipsis doesn't cause a separation. This leads to differing behavior between Numpy and PyTorch in this edge case.

This difference in behavior leads to a bug when using torch.compile:
```python
>>> import numpy as np
>>> f = lambda x: x[:,(0,1),...,(0,1)].shape
>>> a = np.ones((3, 4, 5))
>>> f(a)
(2, 3)
>>> torch.compile(f)(a)
(3, 2)
```
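The eager-NumPy side of this edge case can be checked directly (a standalone sketch):

```python
import numpy as np

a = np.ones((3, 4, 5))

# The ellipsis matches zero dimensions here, yet it still separates the
# two advanced indices, so the broadcast dimension (2,) moves to the front.
print(a[:, (0, 1), ..., (0, 1)].shape)  # (2, 3)

# Without the ellipsis the advanced indices are adjacent, and the
# broadcast dimension stays in place.
print(a[:, (0, 1), (0, 1)].shape)       # (3, 2)
```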

Similarly to #157676, this PR doesn't change PyTorch's behavior, but it fixes the translation layer, ensuring torch._numpy compatibility with NumPy. I am marking this PR as fixing #141563, even though PyTorch behavior isn't modified.

Notice that there are still some other bugs in PyTorch's advanced indexing, that need to be fixed (mainly regarding proper accounting of dimensions when multidimensional boolean masks are present). But those need to be fixed at the ATen operator level. Examples:
- #71673
- #107699
- #158125

Pull Request resolved: #158297
Approved by: https://github.com/soumith
@github-actions github-actions bot deleted the np_indexing_fix branch August 12, 2025 02:18

Labels

ciflow/trunk (Trigger trunk jobs on your pull request), Merged, module: dynamo, module: numpy (Related to numpy support, and also numpy compatibility of our operators), topic: not user facing (topic category)


Development

Successfully merging this pull request may close these issues.

Discrepancy between Numpy and PyTorch advanced indexing
torch.compile with numpy code differs from numpy's behavior

5 participants