[dtensor] avoid shape recompilations on DTensorSpec #163820
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/163820
Note: links to docs will display an error until the docs builds have completed. ✅ No Failures as of commit 87a3e3d with merge base f63d16c. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
azahed98 left a comment:
Changes lgtm. Let me know when final review is needed.
azahed98 left a comment:
Lgtm! Thanks for the change
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 job has failed; first few of them are: trunk / linux-jammy-cuda12.8-py3.10-gcc11 / test (default, 3, 5, lf.linux.g6.4xlarge.experimental.nvidia.gpu). Details for Dev Infra team: raised by workflow job.
…dtensor_shape_metadata_guard
Sorry, hold up: why is this the right thing to do?
Is the claim that the DTensorSpec sizes/strides are always equal to the outer tensor sizes/strides?
Yep, the metadata check was on the original eager DTensorSpec (the sample input), which hardcoded the static outer sizes/strides. I think the dynamo ShapeEnv guards should be enough to handle checking any size/stride constraints.
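To make the failure mode concrete, here is a minimal, hypothetical repro sketch (single-rank gloo setup; the port, shapes, and logging knob are my own choices, and exact guard behavior depends on the PyTorch build): before this change, each new outer shape could fail the subclass metadata guard and force a recompile even under `dynamic=True`.

```python
import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, distribute_tensor

# Single-rank setup so this runs on one machine (hypothetical port).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))

torch._logging.set_logs(recompiles=True)  # print guard-failure reasons

@torch.compile(dynamic=True)
def f(x):
    return x + 1

# Before the fix: each new outer shape could trip the
# TENSOR_SUBCLASS_METADATA_MATCH guard, because the eager DTensorSpec's
# static sizes/strides were baked into it, triggering a recompile.
for n in (8, 16, 32):
    x = distribute_tensor(torch.randn(n, 4), mesh, [Replicate()])
    f(x)

dist.destroy_process_group()
```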
    raise RuntimeError("Unsupported tensor type!")

    @classmethod
    def __metadata_guard__(cls, orig, other):
Nit: add a return type annotation.
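For what it's worth, a self-contained sketch of the shape of that method with the annotation, using stand-in types (the field names and the comparison set are assumptions based on this thread, not the exact merged code):

```python
from dataclasses import dataclass
from typing import Tuple

# Stand-ins for DeviceMesh / Placement / dtype so the sketch is runnable;
# the real DTensorSpec lives in torch.distributed.tensor.
@dataclass
class SpecSketch:
    mesh: object
    placements: Tuple[object, ...]
    dtype: object

    @classmethod
    def __metadata_guard__(cls, orig: "SpecSketch", other: "SpecSketch") -> bool:
        # Match mesh, placements, and dtype, but deliberately skip
        # sizes/strides: those are already covered by dynamo's ShapeEnv
        # guards, and baking static values in here forces recompiles.
        return (
            orig.mesh == other.mesh
            and orig.placements == other.placements
            and orig.dtype == other.dtype
        )
```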
…dtensor_shape_metadata_guard
…github.com/pytorch/pytorch into pianpwk/dtensor_shape_metadata_guard
This feels incomplete. When I trace a DTensor with dynamic shapes, do the DTensorSpec sizes/strides become symbolic? Because if they don't, that seems like a problem. And if they do become symbolic, why would we end up guarding on them?
I just double-checked, and the DTensorSpec does contain SymInts during tracing.
I'm admittedly not that familiar with how the guards are constructed, but I was going off #152963, "additional work" point 1: allegedly the SymInts are undesirable for constructing the guards and cause recompiles. @bdhirsh, could you provide more context?
My understanding is that the DTensorSpec that contains SymInts is the one in the DTensor wrapped by dynamo (appearing at the top of the dynamo graph), so fake tensor prop is done dynamically, and that's all good. But the problem is that the tensor subclass metadata guards (TENSOR_SUBCLASS_METADATA_MATCH) are installed against the pre-wrapped, eager-mode DTensor, so we were previously just guarding on the original static shapes instead of on SymInts. I'm not sure if we can just install guards against the wrapped DTensor instead?
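To make the mechanism concrete, the guard-side check presumably reduces to something like the dispatch below (a hedged sketch, not dynamo's actual implementation):

```python
def metadata_matches(saved_attr: object, runtime_attr: object) -> bool:
    # Hedged sketch of a TENSOR_SUBCLASS_METADATA_MATCH-style check.
    hook = getattr(type(saved_attr), "__metadata_guard__", None)
    if hook is not None:
        # The metadata type opts out of plain equality, e.g. to ignore
        # sizes/strides the way DTensorSpec does after this PR.
        return hook(saved_attr, runtime_attr)
    # Default: structural equality, which for DTensorSpec previously
    # dragged in the eager tensor_meta's static sizes/strides.
    return saved_attr == runtime_attr
```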
It is a smell to me that …
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
skips DTensorSpec.sizes/strides in metadata guard checks

Pull Request resolved: pytorch#163820
Approved by: https://github.com/azahed98
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @ezyang @msaroufim @dcci