fix _clone_meta stride computation for torch.preserve_format #161400
morrison-turnansky wants to merge 4 commits into pytorch:main
Conversation
See artifacts and rendered test results at hud.pytorch.org/pr/161400
Note: Links to docs will display an error until the docs builds have been completed.
❌ 4 New Failures as of commit ba184c0 with merge base 69a25f6.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label "topic: not user facing"
I was directed to include you for review. Thank you in advance, @zou3519
zou3519 left a comment
Thank you for the PR. I think the semantics of preserve_format are a bit more complicated, though. Could you take a look please?
Force-pushed 28a91d8 to 4b0a80a
Force-pushed 4b0a80a to 70ca3c0
@zou3519 I updated the behavior. Following the docs, I separated the dense and non-dense cases and added a test for each. Please let me know if you would like any additional changes.
```python
if torch._prims_common.is_non_overlapping_and_dense(input):
    strides = input.stride()
else:
    strides = input.contiguous().stride()
```
I'm not completely sure this is correct. Also, assuming it is correct, we should update utils.compute_elementwise_output_strides -- other operators run into the same problem.
This is the function we use to compute strides when the input is not "non_overlapping_and_dense": https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/TensorIterator.cpp#L1276.
If we think it is exactly input.contiguous().stride(), then we should add some unit tests to check that.
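For intuition, the stride rule the linked TensorIterator code applies can be sketched in plain Python. This is an illustrative simplification only (it ignores broadcasting, size-1 dims, and stride ties; the function name is made up, not torch's API):

```python
def elementwise_output_strides(shape, in_strides):
    # Sketch: order dims from smallest to largest input stride,
    # then lay the output out densely in that order, so the
    # output preserves the input's memory order without its gaps.
    ndim = len(shape)
    perm = sorted(range(ndim), key=lambda d: in_strides[d])
    out = [0] * ndim
    acc = 1
    for d in perm:
        out[d] = acc
        acc *= shape[d]
    return tuple(out)

# Transposed-and-sliced input of shape (4, 5) with strides (2, 16):
print(elementwise_output_strides((4, 5), (2, 16)))  # (1, 4)
# Already-contiguous input keeps row-major strides:
print(elementwise_output_strides((4, 5), (5, 1)))   # (5, 1)
```

Note that for the non-dense input the result, (1, 4), differs from plain row-major (5, 1): the dimension with the smaller input stride stays the fastest-varying one.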
EDIT: .contiguous() calls exactly that function, so it is actually the correct thing. So my ask is that we should update utils.compute_elementwise_output_strides to call

```python
if torch._prims_common.is_non_overlapping_and_dense(input):
    strides = input.stride()
else:
    strides = input.contiguous().stride()
```
@zou3519 Thank you for checking the correct behavior. I updated utils.compute_elementwise_output_strides to call

```python
if torch._prims_common.is_non_overlapping_and_dense(input):
    strides = input.stride()
else:
    strides = input.contiguous().stride()
```

exactly when one tensor is given.
For the case of multiple tensors, it was not clear how to adapt this without more substantial changes: there is quite a bit of logic in compute_elementwise_output_logical_to_physical_perm. Let me know if you would like me to update the behavior there as well.
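As a toy illustration of why the multi-tensor case is harder (hypothetical values, not taken from the PR): two operands can disagree on memory order, so no single output permutation preserves both layouts, and the perm logic has to break the tie:

```python
# Two same-shaped (4, 5) operands with conflicting layouts:
a_strides = (5, 1)   # row-major
b_strides = (1, 4)   # column-major

# Ordering dims from fastest- to slowest-varying gives opposite
# permutations for the two operands:
perm_a = sorted(range(2), key=lambda d: a_strides[d])
perm_b = sorted(range(2), key=lambda d: b_strides[d])
print(perm_a, perm_b)  # [1, 0] [0, 1]
```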
No, let's fix the one tensor case first.
After we get this merged, I can look into the multiple tensor case if you'd like.
…tensor when exactly one tensor is given
Force-pushed 5d36a87 to ba184c0
```python
if len(tensors) == 1:
    if torch._prims_common.is_non_overlapping_and_dense(tensors[0]):
        return tensors[0].stride()
    else:
        return tensors[0].contiguous().stride()
```
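The is_non_overlapping_and_dense gate used here can be approximated in plain Python for intuition. This is a sketch under simplified assumptions (it skips the size-0 and overlap edge cases torch's real check handles):

```python
def is_non_overlapping_and_dense(shape, strides):
    # Sketch: sorted from fastest- to slowest-varying dim, each
    # stride must equal the product of the sizes of all
    # faster-varying dims, i.e. the layout has no gaps.
    dims = sorted(range(len(shape)), key=lambda d: strides[d])
    expected = 1
    for d in dims:
        if shape[d] == 1:
            continue  # size-1 dims place no constraint
        if strides[d] != expected:
            return False
        expected *= shape[d]
    return True

print(is_non_overlapping_and_dense((8, 10), (1, 8)))  # True  (transpose)
print(is_non_overlapping_and_dense((4, 5), (2, 16)))  # False (strided slice)
```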
@laithsakka @bobrenjc93 any dynamic shapes issues around implementing this like this? Otherwise, if the tests pass, I will assume we can ship this.
I took a look at the failing tests and am still figuring it out. It is an issue introduced by the change that shows up with inductor.
```python
def test_clone_not_memory_dense():
    def foo():
        x = torch.randn(10, 8).t()[::2, ::2]
        y = x.clone()
        return y

    y = foo()
    assert y.stride() == (1, 4)
    print("uncompiled")
    y = torch.compile(foo, backend="eager")()
    print("eager")
    assert y.stride() == (1, 4)
    y = torch.compile(foo, backend="aot_eager")()
    print("aot_eager")
    assert y.stride() == (1, 4)
    y = torch.compile(foo, backend="inductor")()
    print("inductor")
    print(y.stride())
    assert y.stride() == (1, 4)
```
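To see where the (1, 4) expectation comes from, the view chain's shape/stride bookkeeping can be traced by hand in plain Python (a sketch; the helper names are illustrative, not torch API):

```python
def t(shape, strides):
    # Transposing a 2-D view swaps both shape and strides.
    return shape[::-1], strides[::-1]

def every_other(shape, strides):
    # x[::2, ::2] halves each dim (rounding up) and doubles each stride.
    return (tuple((s + 1) // 2 for s in shape),
            tuple(st * 2 for st in strides))

shape, strides = (10, 8), (8, 1)          # torch.randn(10, 8) is row-major
shape, strides = t(shape, strides)        # (8, 10), (1, 8)
shape, strides = every_other(shape, strides)
print(shape, strides)  # (4, 5) (2, 16)
```

The resulting strides (2, 16) make dim 0 the fastest-varying one, so a clone that preserves the input's stride order produces (1, 4), whereas plain row-major contiguous strides for shape (4, 5) would be (5, 1); that gap is what the failing inductor case exposes.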
Hi @morrison-turnansky, thank you for the contribution! Since this issue is UBN we would like to act on it ASAP, so I have worked on top of your PR to generalize it and try to address this failure -- see #163017. Let's monitor signals there to see if that generalization works or has other unexpected failures, and push this across the finish line :)
@Lucaskabela Looks like the CI is passing on your new PR. Thanks for the help on this.
Of course, and thank you for all the work on this! I have your commits included in that PR for attribution, and it was extremely valuable since you had already implemented the matching semantics for the one-tensor case here :)
```python
return ()

if len(tensors) == 1:
    if torch._prims_common.is_non_overlapping_and_dense(tensors[0]):
```
We should probably rename is_non_overlapping_and_dense to is_non_overlapping_and_dense_or_false. The current behavior is fine, it's a reasonable definition of the unbacked semantics, but it would be good to ensure inductor behaves the same way with unbacked inputs.
@laithsakka Thank you for the note. These commits were merged in PR #163017; however, that PR did not fully capture all the correct behavior. When the follow-up work is implemented, I will make a note of this naming suggestion.
Continuation of work from #161400 and #163017. Updating stride semantics for `clone_meta` and the underlying function, `compute_elementwise_output_strides`. Pull Request resolved: #164252. Approved by: https://github.com/Lucaskabela
Fixes #161010
Fixed the stride issue for cloning a meta tensor.