[primTorch] Adds contiguous and expand references #79820
Conversation
❌ 2 new failures as of commit a5ac174 (more details on the Dr. CI page). Both failures were recognized by patterns and do not appear to be due to upstream breakages.
torch/_prims/utils.py (outdated):

```python
shape = list(shape)
shape[1], shape[-1] = shape[-1], shape[1]
```
This is not correct; you need to permute the shape elements, like `shape[1], shape[2], shape[3] = shape[2], shape[3], shape[1]`.
(And how are the tests passing? Even for (2, 2, 2, 2) they shouldn't. Oh, I guess it's because the tests don't actually test strides.)
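The swap-versus-permutation distinction matters whenever H ≠ W. Here is a minimal sketch (helper names are hypothetical, for illustration only) contrasting the NCHW → NHWC permutation the reviewer describes with the plain swap from the patch under review:

```python
def to_channels_last_shape(shape):
    # Hypothetical helper: permute a 4-D (N, C, H, W) shape into
    # the (N, H, W, C) order used for channels-last strides
    n, c, h, w = shape
    return [n, h, w, c]

def swapped_shape(shape):
    # The swap from the patch under review: only exchanges dims 1 and -1
    shape = list(shape)
    shape[1], shape[-1] = shape[-1], shape[1]
    return shape

# For a square spatial extent the two happen to agree, which is why
# a (2, 2, 2, 2) test cannot tell them apart
print(to_channels_last_shape((2, 2, 2, 2)) == swapped_shape((2, 2, 2, 2)))  # True
print(to_channels_last_shape((2, 3, 4, 5)))  # [2, 4, 5, 3]
print(swapped_shape((2, 3, 4, 5)))           # [2, 5, 4, 3]
```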
torch/_prims/utils.py (outdated):

```python
shape = list(shape)
shape[1], shape[-1] = shape[-1], shape[1]
strides = list(make_contiguous_strides_for(shape))
```
There's this weird inconsistency in eager for 0-element tensors: `contiguous` leaves the strides of size-0 dimensions (and those following them) at whatever the stride of the previous dimension was, whereas for channels last that stride (and all subsequent ones) are zeroed:

```python
In [41]: a = torch.empty(2, 2, 2, 0, memory_format=torch.channels_last)
In [42]: a.stride()
Out[42]: (0, 1, 0, 2)
In [43]: a = torch.empty(2, 2, 2, 0)
In [44]: a.stride()
Out[44]: (4, 2, 1, 1)
```

So using this approach would produce results that differ from eager. Another eager inconsistency is that empty tensors are always contiguous, but have to follow precise striding requirements to be channels_last contiguous, and we should probably fix that.
Great catch -- I redid the implementation to capture this caveat and I'll add some additional testing
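The inconsistency discussed above can be reproduced without torch. The following is a pure-Python sketch (function names are hypothetical) of the two stride rules as described: contiguous strides accumulate sizes as if size-0 dimensions were size 1, while channels-last strides multiply raw sizes, zeroing the strides of every dimension outside a size-0 one:

```python
def contiguous_strides(shape):
    # Eager-style contiguous strides: a size-0 dimension contributes
    # as if it were size 1, so outer strides are never zeroed
    strides = []
    multiplier = 1
    for l in reversed(shape):
        strides.append(multiplier)
        multiplier *= max(l, 1)
    return tuple(reversed(strides))

def channels_last_2d_strides(shape):
    # Eager-style channels-last strides for a 4-D (N, C, H, W) shape:
    # walk the dims in NHWC order multiplying raw sizes, so a size-0
    # dimension zeroes the strides of all dims outside it
    strides = [0] * 4
    multiplier = 1
    for idx in (1, 3, 2, 0):  # C, W, H, N
        strides[idx] = multiplier
        multiplier *= shape[idx]
    return tuple(strides)

# Reproduces the eager behavior quoted in the review thread
print(contiguous_strides((2, 2, 2, 0)))        # (4, 2, 1, 1)
print(channels_last_2d_strides((2, 2, 2, 0)))  # (0, 1, 0, 2)
```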
torch/_prims/utils.py (outdated):

```python
# Short-circuits for tensors with zero or one elements, which
# are trivially non-overlapping and "dense"
if a.numel() < 2:
```
Do we need this special case, or does it follow from `is_contiguous=True`?
Yep, we can remove this special case.
torch/_prims/utils.py (outdated):

```python
shape = list(shape)
shape[1], shape[-1] = shape[-1], shape[1]
```
Same here (and the same for strides in both places, modulo the inverse permutation).
```python
# Short-circuits for tensors of rank one, which are
# non-overlapping and "dense" if their stride is one
if a.ndim == 1:
```
This also follows from `is_contiguous(a)` being True, so you'd never get here.
torch/_prims/utils.py (outdated):

```python
strides = []
for idx in (-1, -2, 0):
    l = shape[idx]
    if l != 0:
```
The `if` and `else` branches are the same?
torch/_prims/utils.py (outdated):

```python
multiplier = shape[1]
strides = []
for idx in (-1, -2, 0):
    l = shape[idx]
    if l != 0:
        strides.append(multiplier)
    else:
        strides.append(multiplier)
        # NOTE: intentionally divergence from make_contiguous_strides_for
        # This is consistent with eager
        multiplier = l * multiplier
```
Suggested change:

```diff
-multiplier = shape[1]
-strides = []
-for idx in (-1, -2, 0):
-    l = shape[idx]
-    if l != 0:
-        strides.append(multiplier)
-    else:
-        strides.append(multiplier)
-        # NOTE: intentionally divergence from make_contiguous_strides_for
-        # This is consistent with eager
-        multiplier = l * multiplier
+multiplier = 1
+strides = [0] * 4
+for idx in (1, -1, -2, 0):
+    # NOTE: intentionally divergence from make_contiguous_strides_for
+    # This is consistent with eager
+    strides[idx] = multiplier
+    multiplier *= shape[idx]
```
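As a sanity check, the suggested loop produces correct channels-last strides for non-square shapes and for the zero-element case quoted earlier. A runnable sketch, with the loop wrapped in a hypothetical helper for illustration:

```python
def suggested_channels_last_strides(shape):
    # The suggested loop for a 4-D (N, C, H, W) shape: visit dims in
    # C, W, H, N order, accumulating a running multiplier
    multiplier = 1
    strides = [0] * 4
    for idx in (1, -1, -2, 0):
        # Multiplying raw sizes (no max with 1) intentionally diverges
        # from make_contiguous_strides_for; this matches eager's
        # channels-last behavior for size-0 dimensions
        strides[idx] = multiplier
        multiplier *= shape[idx]
    return strides

print(suggested_channels_last_strides((2, 3, 4, 5)))  # [60, 1, 15, 3]
print(suggested_channels_last_strides((2, 2, 2, 0)))  # [0, 1, 0, 2]
```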
torch/_prims/utils.py (outdated):

```python
strides.insert(2, 1)
result = tuple(reversed(strides))
return result
```
Suggested change:

```diff
-strides.insert(2, 1)
-result = tuple(reversed(strides))
-return result
+return strides
```
```python
yield SampleInput(make_arg((6, 6), noncontiguous=True))

# channels last 2D
yield SampleInput(make_arg((2, 2, 2, 2)), kwargs={'memory_format': torch.channels_last})
```
These are the same as clone's?
Yep -- combining them -- great catch.
@pytorchbot merge -g

@pytorchbot successfully started a merge job. Check the current status here.

@pytorchbot merge

@pytorchbot successfully started a merge job. Check the current status here.
Hey @mruberry. |
Summary: I also filed issues while creating this PR. This PR...

**Filed issues**
- `test_python_ref_executor` often fails due to concrete arg mismatch #79818
- #80154

**prims**
- Fixes prims.squeeze when called with an unsorted list of dimensions
- Removes the clone prim

**refs**
- adds contiguous
- adds expand
- updates clone to call empty_like and copy_to
- updates empty to accept a memory format
- updates empty_like to accept a memory_format

**utils**
- adds helper functions for working with memory formats and channels last tensors, in particular

**tests**
- removes unused clamp sample input functions (mooted by clamp's new reference inputs)
- extends the reference inputs for clone to include different memory formats
- creates reference inputs for contiguous
- xfails operators that depend on clone (including clone) on `test_python_ref` (see issues)

Pull Request resolved: #79820
Approved by: https://github.com/ngimel
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/8740c68c41c563de42b011cef42b3de7690e9446
Reviewed By: DanilBaibak
Differential Revision: D37781899
Pulled By: mruberry
fbshipit-source-id: 7c422f216844ac9c3c95ccc97d320faaafd88b0b
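At its core, the expand reference added by this PR comes down to computing broadcast strides: dimensions that are newly prepended, or grown from size 1, get stride 0 so that every index along them aliases the same storage. A rough pure-Python sketch of that stride computation (the helper name and the plain-tuple shape/stride representation are illustrative, not the PR's actual code):

```python
def expand_strides(shape, strides, target_shape):
    # Illustrative helper: the strides an expand of a tensor with
    # `shape`/`strides` to `target_shape` would produce
    offset = len(target_shape) - len(shape)
    out = []
    for i, target in enumerate(target_shape):
        if i < offset:
            out.append(0)  # newly prepended dimension: stride 0
        elif shape[i - offset] == 1 and target != 1:
            out.append(0)  # size-1 dimension broadcast up: stride 0
        else:
            out.append(strides[i - offset])  # dimension kept as-is
    return tuple(out)

# Expanding a (3, 1) tensor with strides (1, 1) to (2, 3, 4)
print(expand_strides((3, 1), (1, 1), (2, 3, 4)))  # (0, 1, 0)
```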
Add Reference:
- nll_loss

Depends on:
- expand #79820
- advanced indexing

Pull Request resolved: #81128
Approved by: https://github.com/mruberry