[dtensor] avoid shape recompilations on DTensorSpec#163820

Closed

pianpwk wants to merge 7 commits into main from pianpwk/dtensor_shape_metadata_guard

Conversation

Contributor

@pianpwk pianpwk commented Sep 25, 2025

skips DTensorSpec.sizes/strides in metadata guard checks

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @ezyang @msaroufim @dcci
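As a hedged illustration of what "skipping sizes/strides" means here (a toy stand-in class, not the real DTensorSpec):

```python
# Illustrative sketch only: SpecLike is a hypothetical stand-in for DTensorSpec.
from dataclasses import dataclass

@dataclass
class SpecLike:
    mesh: str          # stands in for the DeviceMesh
    placements: tuple  # e.g. ("Shard(0)",)
    sizes: tuple       # outer tensor sizes: may differ across calls
    strides: tuple     # outer tensor strides: may differ across calls

def metadata_guard(orig: SpecLike, other: SpecLike) -> bool:
    # Compare only the fields that must match for a compiled graph to be
    # reusable; sizes/strides are left to dynamo's ShapeEnv guards, so a
    # shape change alone no longer fails the metadata guard.
    return orig.mesh == other.mesh and orig.placements == other.placements

a = SpecLike("mesh2d", ("Shard(0)",), (8, 16), (16, 1))
b = SpecLike("mesh2d", ("Shard(0)",), (32, 16), (16, 1))  # same layout, new shape
print(metadata_guard(a, b))  # True: the shape difference alone passes the guard
```

The key design choice is which fields are load-bearing for graph reuse: mesh and placements must match, while shapes are handled symbolically elsewhere.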

@pytorch-bot

pytorch-bot Bot commented Sep 25, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/163820

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 87a3e3d with merge base f63d16c:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot Bot added ciflow/inductor oncall: distributed Add this issue/PR to distributed oncall triage queue labels Sep 25, 2025
@pianpwk pianpwk added the release notes: distributed (dtensor) release notes category label Sep 25, 2025
@pianpwk pianpwk changed the title [WIP] avoid shape recompile on dtensor [WIP][dtensor] avoid shape recompilations on DTensorSpec Sep 25, 2025
Contributor

@azahed98 azahed98 left a comment


Changes lgtm. Let me know when final review is needed.

@pianpwk pianpwk changed the title [WIP][dtensor] avoid shape recompilations on DTensorSpec [dtensor] avoid shape recompilations on DTensorSpec Sep 25, 2025
@pianpwk pianpwk marked this pull request as ready for review September 25, 2025 18:03
@pianpwk pianpwk requested review from azahed98 and bdhirsh September 25, 2025 18:04
Contributor

@azahed98 azahed98 left a comment


Lgtm! Thanks for the change

@pianpwk
Contributor Author

pianpwk commented Sep 25, 2025

@pytorchbot merge

@pytorch-bot pytorch-bot Bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 25, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / linux-jammy-cuda12.8-py3.10-gcc11 / test (default, 3, 5, lf.linux.g6.4xlarge.experimental.nvidia.gpu)

Details for Dev Infra team (raised by workflow job)

@ezyang
Contributor

ezyang commented Sep 30, 2025

Sorry hold up, why is this the right thing to do?

@ezyang
Contributor

ezyang commented Sep 30, 2025

Is the claim that DTensorSpec size/stride are always equal to the outer tensor size/stride?

@pianpwk
Contributor Author

pianpwk commented Sep 30, 2025

Is the claim that DTensorSpec size/stride are always equal to the outer tensor size/stride?

Yep, the metadata check was on the original eager DTensorSpec (from the sample input), which hardcoded the static outer sizes/strides.

I think the dynamo ShapeEnv guards should be enough to handle checking any size/stride constraints?

Comment thread on torch/distributed/tensor/_api.py (Outdated)
raise RuntimeError("Unsupported tensor type!")

@classmethod
def __metadata_guard__(cls, orig, other):
Contributor


Nit: add a return type annotation.

@ezyang
Contributor

ezyang commented Oct 1, 2025

This feels incomplete. When I trace a DTensor with dynamic shapes, does the DTensorSpec size/stride become symbolic? Because if it doesn't, then that seems like a problem. And if they do become symbolic, then why would we end up guarding on them?

@azahed98
Contributor

azahed98 commented Oct 1, 2025

When I trace a DTensor with dynamic shapes, does the DTensorSpec size/stride become symbolic?

I just double checked, and the DTensorSpec does contain symints during tracing.

And if they do become symbolic, then why would we end up guarding on them?

I'm admittedly not that familiar with how the guards are constructed, but I was going off #152963 ("additional work", point 1): allegedly the SymInts are undesirable for constructing the guards and cause recompiles. @bdhirsh could you provide more context?

@pianpwk
Contributor Author

pianpwk commented Oct 2, 2025

My understanding is that the DTensorSpec containing SymInts is the one in the DTensor wrapped by dynamo (appearing at the top of the dynamo graph), so fake tensor propagation is done dynamically, and that's all good.

But the problem is that the tensor subclass metadata guards (TENSOR_SUBCLASS_METADATA_MATCH) are installed against the pre-wrapped, eager mode DTensor, so we were previously just guarding on the original static shapes, instead of SymInts.

I'm not sure if we can just install guards against the wrapped DTensor instead?
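A toy sketch of the mismatch described above (hypothetical names; not dynamo's actual guard machinery):

```python
# Hypothetical illustration of why guarding on the eager spec pins shapes.
class Sym:
    """Toy stand-in for a SymInt: a size unknown until runtime."""
    def __init__(self, name):
        self.name = name
    def __repr__(self):
        return self.name

eager_sizes = (8, 16)           # concrete ints from the eager sample input
traced_sizes = (Sym("s0"), 16)  # what dynamo traces with dynamic shapes

def static_guard(new_sizes):
    # A metadata guard built from the eager spec compares new inputs
    # against the literal (8, 16), even though the traced graph itself
    # is polymorphic in s0.
    return new_sizes == eager_sizes

print(static_guard((8, 16)))   # True
print(static_guard((32, 16)))  # False, so a recompile would be triggered
```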

@ezyang
Contributor

ezyang commented Oct 2, 2025

It is a smell to me that __tensor_flatten__ on DTensor returns stuff that shouldn't actually be compared against. Maybe it all works out but I would worry there are other places where we assume we can test against the metadata directly that cause problems.
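For context, a simplified sketch of the flatten contract being discussed (a toy class with made-up fields, not the real DTensor):

```python
# Toy tensor-subclass flattening sketch (hypothetical, not the real torch API).
class ToyDTensor:
    def __init__(self, local_tensor, spec):
        self._local_tensor = local_tensor
        self._spec = spec  # includes mesh/placements and, notably, sizes/strides

    def __tensor_flatten__(self):
        # Returns inner tensor attribute names plus opaque metadata. Guards
        # like TENSOR_SUBCLASS_METADATA_MATCH compare that metadata across
        # calls, which is how shape-bearing spec fields caused recompiles.
        return ["_local_tensor"], (self._spec,)

spec = ("mesh2d", ("Shard(0)",), (8, 16))  # made-up spec tuple for the demo
t = ToyDTensor("local_shard", spec)
attrs, meta = t.__tensor_flatten__()
print(attrs, meta[0][2])  # ['_local_tensor'] (8, 16)
```

The smell ezyang points at: the shapes in `meta` are part of what gets compared, even though they are not semantically part of the subclass's identity.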

@pianpwk
Contributor Author

pianpwk commented Oct 3, 2025

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
skips DTensorSpec.sizes/strides in metadata guard checks

Pull Request resolved: pytorch#163820
Approved by: https://github.com/azahed98
@github-actions github-actions Bot deleted the pianpwk/dtensor_shape_metadata_guard branch November 3, 2025 02:17

Labels

ciflow/inductor
ciflow/trunk (Trigger trunk jobs on your pull request)
Merged
oncall: distributed (Add this issue/PR to distributed oncall triage queue)
release notes: distributed (dtensor) (release notes category)



4 participants