
Advanced Indexing Part 1 -- Purely Integer Array Indexing #1588

Closed
killeent wants to merge 30 commits into pytorch:master from killeent:advanced-indexing

Conversation

@killeent
Contributor

@killeent killeent commented May 18, 2017

This PR implements "Purely Integer Array" advanced indexing semantics in PyTorch, from NumPy, as documented here: https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#purely-integer-array-indexing. It should work with gets/sets and autograd, and allows for LongTensor indexers as well.

 x = torch.arange(0, 25).view(5, 5)

  0   1   2   3   4
  5   6   7   8   9
 10  11  12  13  14
 15  16  17  18  19
 20  21  22  23  24
[torch.FloatTensor of size 5x5]

x[[0,4,3], [1,3,0]]

  1
 23
 15
[torch.FloatTensor of size 3]

Begins to address #1080.
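For readers comparing against NumPy, the target semantics can be sketched like this (a NumPy illustration of the behaviour the PR implements, not the PR's code):

```python
import numpy as np

# Purely integer array indexing: the index arrays are walked in lockstep,
# so result[i] == x[rows[i], cols[i]].
x = np.arange(25).reshape(5, 5)
rows = [0, 4, 3]
cols = [1, 3, 0]
result = x[rows, cols]
print(result)  # [ 1 23 15]
```

This matches the torch example above: positions (0,1), (4,3), and (3,0) are gathered into a length-3 result.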


@killeent force-pushed the advanced-indexing branch from ccfbaa6 to 9888e9a on May 23, 2017
@killeent force-pushed the advanced-indexing branch from 9888e9a to 0518639 on June 12, 2017
@killeent changed the title from "WIP: a subset of advanced indexing" to "Advanced Indexing Part 1 -- Purely Integer Array Indexing" on Jun 12, 2017
@killeent
Contributor Author

This is now in a reviewable state cc @fmassa.


@kentsommer

Is this indexing supposed to work on tensors that have been pushed to the gpu with .cuda()?

If it is, I encountered this issue today while testing the indexing out:

(apytorch) kent@kbox ~/Documents/misc-git-repos $ python
Python 2.7.12 (default, Nov 19 2016, 06:48:10) 
[GCC 5.4.0 20160609] on linux2
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import numpy as np
>>> x = torch.arange(0,80).view(1,10,8)
>>> x[[0], [1, 2, 3, 4, 5, 6, 7, 8, 9], [2]]

 10
 18
 26
 34
 42
 50
 58
 66
 74
[torch.FloatTensor of size 9]

>>> y = x.cuda()
>>> y[[0], [1, 2, 3, 4, 5, 6, 7, 8, 9], [2]]
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: 208 is not a device at /home/kent/Documents/misc-git-repos/pytorch/torch/lib/THC/THCGeneral.c:206
>>> y[[0], [1, 2, 3, 4, 5, 6, 7, 8, 9], [2]]
THCudaCheck FAIL file=/home/kent/Documents/misc-git-repos/pytorch/torch/lib/THC/generic/THCStorage.c line=32 error=77 : an illegal memory access was encountered
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: cuda runtime error (77) : an illegal memory access was encountered at /home/kent/Documents/misc-git-repos/pytorch/torch/lib/THC/generic/THCStorage.c:32
>>>
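For anyone puzzling over the shapes in that repro: the three index lists have shapes (1,), (9,), and (1,), and under NumPy's broadcasting rules (which this PR follows) they broadcast to a single shape-(9,) result. A minimal NumPy illustration of the expected values:

```python
import numpy as np

x = np.arange(80).reshape(1, 10, 8)
# Index arrays of shapes (1,), (9,), (1,) broadcast to shape (9,),
# so out[k] == x[0, 1 + k, 2].
out = x[[0], [1, 2, 3, 4, 5, 6, 7, 8, 9], [2]]
print(out)  # [10 18 26 34 42 50 58 66 74]
```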

@killeent
Contributor Author

@kentsommer Yes in theory it should work for CUDA. Let me take a look.

Member

@fmassa fmassa left a comment

I did a first pass and it looks good for the most part.
I think there is a problem with the autograd part that would be good to have fixed (in the worst case, you can check whether grad_output is volatile, meaning no grad of grad, and raise an error if it is not).
Also, for the more general case x[:, [1, 2], [2]], the linear index array might grow quite a bit; we would need to check whether that has a performance impact.
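To make the size concern concrete, here is a rough sketch (my own illustration, not the PR's implementation) of how the linear index array for x[:, [1, 2], [2]] is the sliced dimension crossed with the broadcast advanced indices:

```python
import numpy as np

x = np.arange(5 * 4 * 3).reshape(5, 4, 3)
idx1 = np.array([1, 2])   # advanced index on dim 1
idx2 = np.array([2])      # advanced index on dim 2, broadcasts against idx1
# The ':' slice is expanded to every position of dim 0, so the flat index
# array has shape (5, 2): slice length times broadcast index length.
rows = np.arange(x.shape[0])[:, None]        # shape (5, 1)
flat = rows * (4 * 3) + idx1 * 3 + idx2      # linear offsets, shape (5, 2)
assert (x.reshape(-1)[flat] == x[:, idx1, idx2]).all()
```

The flat array grows multiplicatively with each sliced dimension, which is the growth fmassa is flagging.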


@killeent
Contributor Author

@kentsommer This should now work with CUDA. The issue was that we were trying to use CPU Tensors to index into the CUDA Tensor; now we copy the indices onto the GPU before calling the index_* functions. cc @fmassa: this means that all index calculation is still being done on the CPU :).

I will take a look at grad of grad tomorrow.
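The fix described above can be sketched roughly like this (my own duck-typed illustration of the idea, not the PR's code; to_same_device is a hypothetical helper name):

```python
# Sketch of the CUDA fix: before dispatching to the index_* functions,
# make sure every indexer tensor lives on the same device as the tensor
# being indexed, since mixing CPU index tensors with a CUDA tensor was
# what triggered the illegal memory access above.
def to_same_device(tensor, indexers):
    if getattr(tensor, "is_cuda", False):
        device = tensor.get_device()
        return [ix if ix.is_cuda else ix.cuda(device) for ix in indexers]
    return indexers
```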

@killeent killeent force-pushed the advanced-indexing branch from 4cfa35e to a808771 Compare June 22, 2017 17:20
@killeent killeent force-pushed the advanced-indexing branch from a808771 to 7794c92 Compare June 22, 2017 17:59
@killeent
Contributor Author

killeent commented Jun 22, 2017

Okay, I have now moved the advanced_index_add call into Variable. This required adding a new function to `_functions/tensor.py` that handles forward and backward for this method, including adding a new function `advanced_index_select` for the backward implementation. I don't understand enough about autograd to have full confidence in this implementation, but it passes the minimal tests I have added.

cc @fmassa, @gchanan
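The forward/backward pairing can be sketched in NumPy terms (my own illustration of the intended relationship, not the code in _functions/tensor.py):

```python
import numpy as np

# Sketch: advanced-index selection forward is a gather, and its backward
# scatter-ADDS grad_output into a zero tensor, so duplicate indices
# accumulate gradient rather than overwrite it.
def adv_index_forward(x, rows, cols):
    return x[rows, cols]                        # gather

def adv_index_backward(grad_out, shape, rows, cols):
    grad_in = np.zeros(shape)
    np.add.at(grad_in, (rows, cols), grad_out)  # unbuffered scatter-add
    return grad_in
```

With duplicated indices, e.g. rows = [0, 0], both gradient contributions land on the same element; plain assignment would silently drop one of them, which is why an add-style backward is needed.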

Contributor

@apaszke apaszke left a comment

I glanced over autograd and it looks ok to me

if ctx.advanced_indexing:
    grad_input._advanced_index_add(ctx.index, grad_output)
else:
    grad_input[ctx.index] = grad_output


def consec(size, start=1):
    sequence = torch.ones(int(torch.Tensor(size).prod(0)[0])).cumsum(0)
    sequence.add_(start - 1)
    return sequence.resize_(*size)


assert not ctx.needs_input_grad[1]
if ctx.needs_input_grad[2]:
    ctx.adv_index = adv_index
ctx.mark_dirty(tensor1)


@soumith
Collaborator

soumith commented Jun 22, 2017

this is now merged into master!

@vadimkantorov
Contributor

vadimkantorov commented Jun 28, 2017

Just updated my PyTorch. x[[0,4,3], [1,3,0]] from the OP works, but x[[0,4,3]] doesn't. Is that expected?

x[[0,4,3], :] works, as does x[torch.LongTensor([0,4,3])].

>>> x[[0,4,3]]
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: indexing a tensor with an object of type list. The only supported types are integers, slices, numpy scalars and torch.LongTensor or torch.ByteTensor as the only argument.
>>> x[torch.LongTensor([0,4,3])]

  0   1   2   3   4
 20  21  22  23  24
 15  16  17  18  19
[torch.FloatTensor of size 3x5]

>>> x[[0,4,3], :]

  0   1   2   3   4
 20  21  22  23  24
 15  16  17  18  19
[torch.FloatTensor of size 3x5]

@soumith
Collaborator

soumith commented Jun 28, 2017

Yes, it's a todo in the code. It will be fixed in a follow-up PR.

@ngimel
Collaborator

ngimel commented Jun 29, 2017

@killeent, I have a CUDA test failing when I build from master (python test_cuda.py TestCuda.test_advancedindex). It's failing in assert_set_eq when the indexer has dupes; I've added some printouts there:

tensor
  0   1   2   3   4
  5   6   7   8   9
 10  11  12  13  14
 15  16  17  18  19
indexer [slice(None, None, None), [0, 1, 1, 2, 2]]
val
 16  12   0  14  10
 19  17   2   7   4
  5   9   6   1  15
 13  11   3   8  18
[torch.cuda.DoubleTensor of size 4x5 (GPU 0)]

pyt 
 16  12  14   3   4
 19  17   7   8   9
  5   9   1  13  14
 13  11   8  18  19
[torch.cuda.DoubleTensor of size 4x5 (GPU 0)]
 numt 
 16   0  10   3   4
 19   2   4   8   9
  5   6  15  13  14
 13   3  18  18  19

Apparently, in the correct output (numt) the second of the duped columns should be copied to tensor from val, but in my failing case (pyt) the first one is. Can you point me to where in the backend this happens? It passes on most architectures, so it's not an outright bug, but it is failing on one.
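For context, the expected "last write wins" behaviour with duplicate indices can be seen in NumPy (a minimal illustration, not the failing test itself):

```python
import numpy as np

# NumPy processes duplicated assignment indices in order, so the LAST
# write wins; this is why the reference output (numt) keeps the second
# of the duped columns from val.
a = np.zeros(4)
a[[0, 1, 1, 2, 2]] = [10, 20, 30, 40, 50]
# a[1] ends as 30 (not 20), and a[2] as 50 (not 40)
```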

@vadimkantorov
Contributor

In addition to x[[0,4,3]], torch.arange(0, 25)[[0,4,3]] doesn't work either, and the workaround torch.arange(0, 25)[[0,4,3], :] isn't possible for lack of dimensions. (I guess this one may be on the TODO list as well, but since I stumbled upon it, I'll leave it here for completeness.)
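Until the plain-list case is supported, the 1-D gather can presumably still be written with an explicit LongTensor, as in the 2-D snippet above. The NumPy equivalent of what torch.arange(0, 25)[[0,4,3]] should eventually return:

```python
import numpy as np

v = np.arange(0, 25)
# 1-D integer-array indexing. The torch spelling that should already work
# is torch.LongTensor-based indexing, e.g.
#   torch.arange(0, 25)[torch.LongTensor([0, 4, 3])]
# (an assumption based on the LongTensor example earlier in the thread).
picked = v[[0, 4, 3]]
print(picked)  # [0 4 3]
```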

houseroad added a commit to houseroad/pytorch that referenced this pull request Nov 20, 2018
houseroad added a commit to houseroad/pytorch that referenced this pull request Nov 29, 2018
facebook-github-bot pushed a commit that referenced this pull request Nov 30, 2018
zasdfgbnm pushed a commit that referenced this pull request Apr 26, 2022
…er (#1588)

Fixes #75622

1. Instead of getting max/min_value for reduction init value, we go with (-)infinity instead so we can properly preserve inf inputs;
2. Adding inf/(-)inf/nan for float value.
3. Adding aten::amin in nvfuser (@kevinstephano @rdspring1 for review)
Pull Request resolved: #75646
Approved by: https://github.com/rdspring1, https://github.com/kevinstephano, https://github.com/ngimel
jagadish-amd pushed a commit to jagadish-amd/pytorch that referenced this pull request Oct 24, 2024
pytorch#1588)

…d in reduce config (pytorch#135397)

Fixes pytorch#132964

This change optimizes torch.sum() performance by increasing
max_values_per_thread in setReduceConfig() for the ROCm platform. With a
larger value, the kernel uses fewer threadblocks, which improves
performance.

Test:
Tested on MI300x and H100; MI300x perf for the test case improved from
~1690 GByte/s to 3205 GByte/s, slightly better than H100 (3136 GByte/s).

Also tested with various other tensor sizes and saw performance
improvements there as well.

```python
import torch
from triton.testing import do_bench

x = torch.randn(2**30, device='cuda')

ms = do_bench(lambda: x.sum(dim=-1))

bandwidth_gbyte = x.numel() * x.dtype.itemsize / (10**9)

time_s = ms / 1000

bw_per_second = bandwidth_gbyte / time_s

print(bw_per_second)
```

Co-author: @carlobertolli

Pull Request resolved: pytorch#135397
Approved by: https://github.com/eqy, https://github.com/malfet


Co-authored-by: hongxyan <hongxyan@amd.com>
jagadish-amd pushed a commit to jagadish-amd/pytorch that referenced this pull request Jan 14, 2025
pytorch#1588)

…d in reduce config (pytorch#135397)

Fixes pytorch#132964

This change is to optimize torch.sum() performance by increasing
max_values_per_thread in setReduceConfig() for ROCm platform. By
increasing this parameter, it uses fewer threadblocks and improved the
performance.

Test:
Tested on MI300x and H100, and now the MI300x perf improved to
3205GByte/s from ~1690GByte/s for the test case and is slightly better
than H100 (3136GByte/s).

Also tested with other different sizes of tensors and also see perf
improvement.

```python
import torch
from triton.testing import do_bench

x = torch.randn(2**30, device='cuda')

ms = do_bench(lambda: x.sum(dim=-1))

bandwidth_gbyte = x.numel() * x.dtype.itemsize / (10**9)

time_s = ms / 1000

bw_per_second = bandwidth_gbyte / time_s

print(bw_per_second)
```

Co-author: @carlobertolli

Pull Request resolved: pytorch#135397
Approved by: https://github.com/eqy, https://github.com/malfet

Fixes #ISSUE_NUMBER

Co-authored-by: hongxyan <hongxyan@amd.com>
(cherry picked from commit 4360582)