index: fix index of 0-element tensor by 0-element tensor.#7113
index: fix index of 0-element tensor by 0-element tensor.#7113
index: fix index of 0-element tensor by 0-element tensor.#7113Conversation
lezcano
left a comment
There was a problem hiding this comment.
We have a meta implementation of the meta of index_Tensor at https://github.com/pytorch/pytorch/blob/ee6cb6daa173896f8ea1876266a19775aaa4f610/torch/_meta_registrations.py#L2983
Perhaps it's better to replicate its behaviour here, as it's quite tricky.
I assume we cannot use this information here, even though dynamo knows about it, right?
|
I think that I'm basically doing the same thing that this part of the meta implementation is doing. The only thing is that, at this point, we have already preprocessed the inputs such that:
I think the problem is two-fold:
|
lezcano
left a comment
There was a problem hiding this comment.
Sounds reasonable. Thank you for looking into this!
|
Thank you, Mario. I think I will wait for @JackCaoG review, too, before merging this PR. |
|
Let me take a look today |
This PR adds support for indexing a 0-element tensor with a 0-element tensor index. It also adds a fast path whenever there are 0-element tensor indices. If found, we know that we will return a 0-element tensor.
In summary, here are the steps of this fast path:
indicesandstart_dimstart_dimstart_dim + len(indices)For further details, see PyTorch's implementation of output shape computation.
cc @miladm @JackCaoG @lezcano