
fix _clone_meta stride computation for torch.preserve_format #161400

Closed
morrison-turnansky wants to merge 4 commits into pytorch:main from morrison-turnansky:issue-161010-dynamo-stride-clone

Conversation

@morrison-turnansky
Collaborator

Fixes #161010

Fixed the stride computation when cloning a meta tensor.
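
For context, the kind of mismatch described in #161010 can be reproduced with a sketch like the following (an illustrative dense-but-non-contiguous input, not the exact snippet from the issue; strides are printed rather than asserted):

import torch

def f(x):
    # clone() defaults to memory_format=torch.preserve_format
    return x.clone(memory_format=torch.preserve_format)

# Dense but non-contiguous input: a transpose of a contiguous buffer.
x = torch.randn(4, 6).t()

eager = f(x)
compiled = torch.compile(f, backend="aot_eager")(x)

# Eager clone preserves the input's strides for dense inputs; before this fix,
# the meta function used by the compile stack could report different strides.
print("input:   ", x.stride())
print("eager:   ", eager.stride())
print("compiled:", compiled.stride())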

@pytorch-bot

pytorch-bot Bot commented Aug 25, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/161400

Note: Links to docs will display an error until the docs builds have been completed.

❌ 4 New Failures

As of commit ba184c0 with merge base 69a25f6:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@morrison-turnansky
Collaborator Author

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot Bot added the topic: not user facing label Aug 25, 2025
@soulitzer soulitzer added the triaged label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Aug 27, 2025
@morrison-turnansky
Collaborator Author

I was directed to include you for review. Thank you in advance, @zou3519

zou3519
zou3519 previously requested changes Sep 2, 2025
Contributor

Thank you for the PR. I think the semantics of preserve_format are a bit more complicated. Could you take a look, please?

@morrison-turnansky morrison-turnansky force-pushed the issue-161010-dynamo-stride-clone branch 3 times, most recently from 28a91d8 to 4b0a80a on September 4, 2025 18:49
@morrison-turnansky morrison-turnansky force-pushed the issue-161010-dynamo-stride-clone branch from 4b0a80a to 70ca3c0 on September 4, 2025 19:08
@morrison-turnansky
Collaborator Author

@zou3519 I updated the behavior. Following the docs, I separated the dense and non-dense cases and added a test for each. Please let me know if you would like any additional changes.
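
For readers following along, here is a quick illustration of the dense vs. non-dense distinction being used in the diff below; the helper is the one referenced there, and the example tensors are just illustrative:

import torch
from torch._prims_common import is_non_overlapping_and_dense

# Dense but non-contiguous: a transpose still covers its storage with no gaps,
# so its strides can be copied to a fresh allocation.
print(is_non_overlapping_and_dense(torch.randn(10, 8).t()))            # True

# Non-dense: a strided slice skips elements and leaves holes in the storage
# it views, so its strides cannot simply be reused for the clone.
print(is_non_overlapping_and_dense(torch.randn(10, 8).t()[::2, ::2]))  # False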

Comment thread torch/_prims/__init__.py Outdated
Comment on lines +697 to +700
if torch._prims_common.is_non_overlapping_and_dense(input):
    strides = input.stride()
else:
    strides = input.contiguous().stride()
Contributor

I'm not completely sure this is correct. Also, assuming that it is correct, we should update utils.compute_elementwise_output_strides -- other operators run into the same problem.

This is the function we use to compute strides when the input is not "non_overlapping_and_dense": https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/TensorIterator.cpp#L1276.

If we think it is exactly input.contiguous().stride(), then we should add some unit tests to check that.
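
A sketch of the kind of unit test being suggested here; it compares the proposed formula against eager clone rather than assuming either side, since eager behavior is the reference:

import torch

def compare_non_dense_clone_strides(x):
    # Eager clone (default memory_format=torch.preserve_format) is the
    # reference behavior the meta/prim path should match.
    reference = x.clone().stride()
    proposed = x.contiguous().stride()
    print(f"eager clone: {reference}  contiguous(): {proposed}  "
          f"match: {reference == proposed}")

# A non-dense input: transposed then sliced with a step, so the view has holes.
compare_non_dense_clone_strides(torch.randn(10, 8).t()[::2, ::2])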

Contributor

@zou3519 zou3519 Sep 9, 2025

EDIT: .contiguous() calls exactly that function, so it is actually the correct thing. So my ask is that we should update utils.compute_elementwise_output_strides to call

if torch._prims_common.is_non_overlapping_and_dense(input):
    strides = input.stride()
else:
    strides = input.contiguous().stride()

Collaborator Author

@morrison-turnansky morrison-turnansky Sep 12, 2025

@zou3519 Thank you for checking the correct behavior. I updated utils.compute_elementwise_output_strides to call

if torch._prims_common.is_non_overlapping_and_dense(input):
    strides = input.stride()
else:
    strides = input.contiguous().stride()

exactly when a single tensor is given.
For the case of multiple tensors, it was not clear how to adapt this without more substantial changes. There is quite a bit of logic in compute_elementwise_output_logical_to_physical_perm. Let me know if you want me to update the behavior there as well.
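
To illustrate why the multi-tensor case is more involved: with several inputs the output layout has to be resolved from competing input layouts, which is what the perm logic handles. A small probe of the eager behavior being matched, printing rather than asserting since the resolution rule lives in compute_elementwise_output_logical_to_physical_perm:

import torch

# Two elementwise inputs with conflicting layouts: one channels-last, one
# contiguous. TensorIterator decides the output layout from both operands,
# and the meta path has to reproduce that decision.
a = torch.randn(2, 3, 4, 5).to(memory_format=torch.channels_last)
b = torch.randn(2, 3, 4, 5)

print("a strides:    ", a.stride())
print("b strides:    ", b.stride())
print("(a+b) strides:", (a + b).stride())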

Contributor

No, let's fix the one tensor case first.

Collaborator Author

@morrison-turnansky morrison-turnansky Sep 12, 2025

After we get this merged, I can look into the multiple tensor case if you'd like.

@morrison-turnansky morrison-turnansky force-pushed the issue-161010-dynamo-stride-clone branch from 5d36a87 to ba184c0 on September 12, 2025 16:47
Comment on lines +658 to +662
if len(tensors) == 1:
    if torch._prims_common.is_non_overlapping_and_dense(tensors[0]):
        return tensors[0].stride()
    else:
        return tensors[0].contiguous().stride()
Contributor

@laithsakka @bobrenjc93 any dynamic-shapes issues with implementing it like this? Otherwise, if the tests pass, I will assume we can ship this.

Collaborator Author

I took a look at the failing tests and am still figuring it out. It is an issue introduced by the change, and it shows up with the inductor backend.

import torch

def test_clone_not_memory_dense():
    def foo():
        x = torch.randn(10, 8).t()[::2, ::2]
        y = x.clone()
        return y
    y = foo()
    assert y.stride() == (1, 4)
    print("uncompiled")
    y = torch.compile(foo, backend="eager")()
    print("eager")
    assert y.stride() == (1, 4)
    y = torch.compile(foo, backend="aot_eager")()
    print("aot_eager")
    assert y.stride() == (1, 4)
    y = torch.compile(foo, backend="inductor")()
    print("inductor")
    print(y.stride())
    assert y.stride() == (1, 4)

@zou3519 zou3519 dismissed their stale review September 12, 2025 17:07

outdated

@Lucaskabela
Contributor

Hi @morrison-turnansky, thank you for the contribution! Since this issue is UBN we would like to act on it ASAP, so I have worked on top of your PR to generalize it and try to address this failure; see #163017. Let's monitor signals there to see whether that generalization works or has other unexpected failures, and push this across the finish line :)

@morrison-turnansky
Collaborator Author

@Lucaskabela Looks like the CI is passing on your new PR. Thanks for the help on this.

@Lucaskabela
Contributor

Of course, and thank you for all the work on this! I have included your commits in that PR for attribution, and your work was extremely valuable since you had already implemented the matching semantics for the single-tensor case here :)

return ()

if len(tensors) == 1:
    if torch._prims_common.is_non_overlapping_and_dense(tensors[0]):
Contributor

We probably should rename is_non_overlapping_and_dense to is_non_overlapping_and_dense_or_false. It's fine; it is a reasonable definition of the unbacked semantics, but it would probably be good to ensure inductor behaves similarly with unbacked inputs.
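
For readers unfamiliar with the naming convention being proposed: an "_or_false" predicate answers False when the property cannot be decided (for example because a size is unbacked) instead of forcing a guard. Below is a hypothetical, self-contained sketch of that convention; none of the names are actual PyTorch internals:

from typing import Optional

def try_is_dense(sizes: list[Optional[int]], strides: list[Optional[int]]) -> Optional[bool]:
    # None models an unbacked (data-dependent) value the compiler cannot
    # inspect without inserting a guard.
    if any(v is None for v in list(sizes) + list(strides)):
        return None
    expected = 1
    for size, stride in sorted(zip(sizes, strides), key=lambda p: p[1]):
        if stride != expected:
            return False
        expected *= size
    return True

def is_dense_or_false(sizes, strides) -> bool:
    # "_or_false": an undecidable answer is reported as False rather than
    # guarding on the unbacked symbol.
    return bool(try_is_dense(sizes, strides))

print(is_dense_or_false([4, 5], [5, 1]))     # True: decidable and dense
print(is_dense_or_false([4, 5], [16, 2]))    # False: decidable, not dense
print(is_dense_or_false([None, 5], [5, 1]))  # False: unbacked, cannot decide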

Collaborator Author

@laithsakka Thank you for the note; these commits were merged in PR #163017. However, that PR did not fully capture all of the correct behavior. When the follow-up work is implemented, I will make a note of this naming suggestion.

@morrison-turnansky morrison-turnansky deleted the issue-161010-dynamo-stride-clone branch October 3, 2025 17:42
pytorchmergebot pushed a commit that referenced this pull request Jan 6, 2026
Continuation of work from #161400 and #163017.

Updating stride semantics for `clone_meta` and the underlying function, `compute_elementwise_output_strides`.

Pull Request resolved: #164252
Approved by: https://github.com/Lucaskabela
krastogi-in pushed a commit to krastogi-in/pytorch that referenced this pull request Jan 9, 2026
Continuation of work from pytorch#161400 and pytorch#163017.

Updating stride semantics for `clone_meta` and the underlying function, `compute_elementwise_output_strides`.

Pull Request resolved: pytorch#164252
Approved by: https://github.com/Lucaskabela

Labels

open source
topic: not user facing (topic category)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

torch.compile doesn't preserve stride with clone(memory_format=torch.preserve_format)

6 participants