Fix index truncation in argmin/max for large tensors #33310
peterbell10 wants to merge 3 commits into pytorch:master from
Conversation
Force-pushed from 09bb2b3 to f2b4178
💊 CircleCI build failures summary and remediations

As of commit c26c2f6:

Detailed failure analysis: one may explore the probable reasons each build failed interactively on the Dr. CI website.

🕵️ 1 new failure recognized by patterns. The following build failures do not appear to be due to upstream breakage.
Nifty. @VitalyFedyunin do you think you can take full reviewership on this or should I look?
Or maybe Natalia is the best reviewer here. |
facebook-github-bot
left a comment
@VitalyFedyunin has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Test failure looks genuine, but for some reason it's the CPU version that failed:
ngimel
left a comment
Approving, modulo failing test. Do you need help with debugging it?
It looks like there are places where the CPU reduction also splits iterators. I'll see if the same fix helps there.
Force-pushed from 4fe198f to c26c2f6
I applied the fix to the CPU reduction, but it turns out the actual cause here was just that I was assigning
facebook-github-bot
left a comment
@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
I hope #33370 doesn't interfere with this one.

#33370 is unrelated.
Are there any consequences of this mixed 64-bit/32-bit indexing? E.g. is it faster to do argmax across a contiguous small chunk vs. a discontiguous small chunk that may require 64-bit indexing?
This PR only introduces 1 extra add per reduced element in each kernel, so I don't expect performance to change much.
There is a potential cost from splitting the operation over multiple kernel launches; it will limit the parallelism available within each kernel. For large contiguous chunks that's inconsequential but, in the extreme, I suppose a slice
However, I'd note that this has always been the case for
So in theory there is a tradeoff between multiple kernel launches that do 32-bit indexing and a single kernel launch that must do 64-bit indexing?
In theory, yes. In practice I don't think it matters much. 32-bit indexing covers ~2GB of memory, so you could only have 10-20 tensor splits before you've covered all the memory in a top-of-the-line Quadro card.
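For intuition, the arithmetic behind that "10-20 splits" estimate can be sketched as follows (a rough sketch assuming, as the comment does, that each 32-bit-indexable split covers at most 2**31 bytes; the constant and helper name are illustrative, not PyTorch code):

```python
# Each split is limited to what a signed 32-bit offset can address: 2**31 bytes.
SPLIT_BYTES = 2**31

def min_splits(tensor_bytes):
    # Ceiling division: how many 32-bit-addressable splits cover the tensor.
    return -(-tensor_bytes // SPLIT_BYTES)

# A hypothetical card with 24 GiB of memory filled by a single tensor:
print(min_splits(24 * 2**30))  # -> 12 splits, in line with the estimate above
```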
Summary:
Fixes #32863 (together with #33310 for the `TensorIterator` reductions).

This adds 64-bit indexed kernels for `THC_reduceDimIndex` and uses `THCTensor_canUse32BitIndexMath` to switch between the two at runtime.

I have a test for this locally but haven't included it here because `max` is much slower than `argmax`, to the point where the test takes several minutes to call max on just one `2**32` element tensor. That seems excessive, even for a slow test, but I can push it if preferred.

Pull Request resolved: #33405
Differential Revision: D20010769
Pulled By: ezyang
fbshipit-source-id: a8a86f662598d5fade4d90448436418422c699a3
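To illustrate the failure mode these 64-bit kernels avoid, here is a minimal sketch (a pure-Python emulation of two's-complement truncation; `index_as_int32` is a hypothetical helper, not a PyTorch function):

```python
# When a kernel does its index math in 32-bit signed ints, a flat index
# past 2**31 - 1 wraps around and the returned argmax index is garbage.
def index_as_int32(flat_index):
    # Emulate storing a 64-bit value in a signed 32-bit int (two's complement).
    wrapped = flat_index & 0xFFFFFFFF
    return wrapped - 2**32 if wrapped >= 2**31 else wrapped

ok = 2**31 - 1       # largest index that survives 32-bit index math
too_big = 2**31 + 5  # a valid index in a tensor with more than 2**31 elements

assert index_as_int32(ok) == ok
assert index_as_int32(too_big) != too_big  # wrapped to a negative value
```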
Summary:
Fixes the `TensorIterator` parts of pytorch#32863 (THC is still broken).

`TensorIterator::split` now keeps track of the `view_offsets` into the full tensor range. With this, I can take the base offset for the reduced dimension and translate partial results from the sub-iter into the index range of the full tensor. This happens only once for each intermediate result, so we should still benefit from the performance of 32-bit indexing in loops.

Pull Request resolved: pytorch#33310
Differential Revision: D19906136
Pulled By: ngimel
fbshipit-source-id: 3372ee4b8d5b115a53be79aeafc52e80ff9c490b
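The split-and-translate scheme described in the summary can be sketched in miniature (illustrative names only; the real implementation lives in C++ inside `TensorIterator`, not in Python):

```python
# Each chunk is small enough for fast "32-bit" local index math; the local
# argmax is translated by the chunk's base offset into the full index range
# exactly once per chunk, mirroring the view_offsets translation above.
def chunked_argmax(values, chunk_len):
    best_val = None
    best_idx = -1
    for base in range(0, len(values), chunk_len):
        chunk = values[base:base + chunk_len]
        # Local argmax uses only small, chunk-relative indices.
        local = max(range(len(chunk)), key=chunk.__getitem__)
        if best_val is None or chunk[local] > best_val:
            best_val = chunk[local]
            best_idx = base + local  # translate once per partial result
    return best_idx

print(chunked_argmax([3, 1, 4, 1, 5, 9, 2, 6], chunk_len=3))  # -> 5
```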
Summary:
Closes gh-39060.

The `TensorIterator` splitting is based on `can_use_32bit_indexing`, which assumes 32-bit signed ints, so we can get away with just 2**31 as the axis length. Also tested on an old commit that I can reproduce the test failure on just a 1d tensor, overall quartering the memory requirement for the test.

https://github.com/pytorch/pytorch/blob/4c7d81f8479bce320cc11d1eb3adaf8ab0b90099/aten/src/ATen/native/TensorIterator.cpp#L879

For reference, the test was first added in gh-33310.

Pull Request resolved: #40036
Differential Revision: D22068690
Pulled By: ezyang
fbshipit-source-id: 83199fd31647d1ef106b08f471c0e9517d3516e3
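The sizing argument in that summary can be stated numerically; a small sketch (assuming the signed 32-bit limit mentioned above; `forces_split` is a hypothetical helper, not the actual `can_use_32bit_indexing`):

```python
# The splitting threshold assumes signed 32-bit ints, so the smallest axis
# length that forces a split is 2**31, i.e. one past INT32_MAX.
INT32_MAX = 2**31 - 1

def forces_split(axis_len):
    return axis_len > INT32_MAX

assert not forces_split(INT32_MAX)  # still fits 32-bit indexing
assert forces_split(2**31)          # minimal length that triggers a split
```

This is why the test can use a 2**31-long axis rather than 2**32, quartering the memory needed to exercise the split path.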