Nested tensor subclass support #127431
tugsbayasgalan wants to merge 20 commits into gh/tugsbayasgalan/220/base
Conversation
[ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/127431
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (6 Unrelated Failures) As of commit 4b0160f with merge base 78e40b2:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk.
BROKEN TRUNK - The following jobs failed but were also present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
UNSTABLE - The following jobs failed but were likely due to flakiness present on trunk and have been marked as unstable.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```python
sub = t.type.__tensor_unflatten__(
    transformed_tensors_dict, t.ctx, outer_size, outer_stride
)
todo = plain_meta_tensors
```
Can you just do it recursively? I don't think you'll stack overflow and I think it will be a lot easier to understand
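A minimal sketch (not the PR's or the reviewer's actual code; the helper name and callback are made up) of the recursive shape being suggested, using the traceable-wrapper-subclass protocol:

```python
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

def fakeify_recursively(t, wrap_plain_tensor):
    # Plain tensors are the leaves: wrap them (e.g. as FakeTensors) via the
    # supplied callback.
    if not is_traceable_wrapper_subclass(t):
        return wrap_plain_tensor(t)
    # Subclasses: recurse into each inner tensor, then rebuild this wrapper
    # on the way back up via __tensor_unflatten__.
    attrs, ctx = t.__tensor_flatten__()
    inner = {a: fakeify_recursively(getattr(t, a), wrap_plain_tensor) for a in attrs}
    return type(t).__tensor_unflatten__(inner, ctx, t.size(), t.stride())
```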
When we have nested tensor subclasses, we need to recurse down to access the underlying real tensors, wrap them in FakeTensor, and recursively build the nested tensor subclasses back up. I am not sure if I am passing around the SymbolicContext correctly? cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang [ghstack-poisoned]
```python
return inner_t

attr_fqn = prefix + "." + attr if prefix != "" else attr
attr_list = attr_fqn.split(".")
```
If all you're going to do to the attr_fqn is split it, why not just pass around a list
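A tiny hypothetical illustration of the suggestion (the function and arguments are made up, not the PR's code): carry the attribute path around as a list from the start instead of joining it into a dotted FQN only to split it again at the use site:

```python
def collect_attr_paths(nested_attrs, path=()):
    # nested_attrs: dict mapping attr name -> dict of its own inner attrs
    # (empty dict for leaves). Returns every path as a list of components,
    # with no string joining/splitting involved.
    paths = []
    for attr, inner in nested_attrs.items():
        new_path = list(path) + [attr]
        paths.append(new_path)
        paths.extend(collect_attr_paths(inner, new_path))
    return paths

# collect_attr_paths({"a": {"x": {}, "y": {}}, "b": {}})
# -> [['a'], ['a', 'x'], ['a', 'y'], ['b']]
```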
```python
current_context = symbolic_context.inner_contexts[attr]

current_source = AttrSource(source, attr)
new_empty_tensor = _empty_create_subclass(
```
You don't have to fix it here, but there's a somewhat prevalent antipattern in this file of doing small recursions on helper functions rather than calling all the way back to the very top level. I think it should be OK to recurse to the very top call function, and that makes things more general, since you can handle compositions of things that the small helpers don't cover. Just calling attention to this.
```diff
- sub = t.type.__tensor_unflatten__(
-     transformed_tensors_dict, t.ctx, outer_size, outer_stride
+ sub = _empty_create_subclass(
+     t, outer_size, outer_stride, symbolic_context, callback, source
```
ACKing this part, I'll let Brian do the rest
```python
# TODO: figure out how to refactor the backward properly
# so I can use aot_dispatch_subclass_wrapper() here.
if CompiledFunction.maybe_subclass_metadata is not None:
    tangents = all_args[tangents_start_idx:tangents_end_idx]
```
This seems like kind of a weird way of detecting wrong tangents, but I guess this is the best we can do?
```python
curr_start_idx = self.flat_tensor_start_idx
for attr, creation_meta in self.attrs.items():
    if creation_meta is None:
        subclass = all_args[curr_start_idx]
```
nit: the variable name `subclass` here seems misleading, since it may or may not actually be a subclass (my understanding is that if creation_meta is None, this is guaranteed to be a plain tensor).
Maybe `inner_tensor`?
```python
@@ -171,47 +171,56 @@ class SubclassCreationMeta:
    flat_tensor_start_idx: int
    # The number of tensors that live in this subclass wrapper
    arg_count: int
```
After reading the code, some invariants that I think are worth explicitly mentioning in the comments:
`arg_count` is inclusive of the arg_counts of any inner tensor subclasses: if I have a TwoTensor and both of its inner elements are TwoTensors, then the `arg_count` of the outer-most subclass will be 4.
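A toy illustration of that invariant (not the PR's SubclassCreationMeta): a wrapper's count includes every plain tensor reachable through nested wrappers:

```python
def arg_count(node):
    # node is either the string "plain" (a plain tensor slot) or a tuple of
    # child nodes (a subclass wrapping them).
    if node == "plain":
        return 1
    return sum(arg_count(child) for child in node)

# A TwoTensor whose two inner elements are themselves TwoTensors -> 4 plain tensors.
two_of_twos = (("plain", "plain"), ("plain", "plain"))
assert arg_count(two_of_twos) == 4
```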
```python
curr_start_idx += creation_meta.arg_count
inner_tensors[attr] = subclass

rebuilt = type(self.original_subclass).__tensor_unflatten__(
```
All of the indices in this reconstruction are definitely non-trivial. It would be great if we had some runtime debug-asserts we could run that would tell us if we messed up the indexing somewhere, so we get a less cryptic error if we get this wrong 🤔. I can't think of a great way to do this, though, unless we do something like save all of the shapes of the inner tensors at trace time and assert that our reconstructed inner tensors have the same shapes at runtime.
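A hedged sketch of that debug-assert idea (all names here are made up, not the PR's code): record the inner tensors' shapes at trace time and compare them against the reconstructed inner tensors at runtime, right before calling `__tensor_unflatten__`:

```python
def assert_inner_shapes_match(expected_shapes, inner_tensors):
    # expected_shapes: dict of attr name -> shape recorded at trace time.
    # inner_tensors: dict of attr name -> reconstructed tensor at runtime.
    for attr, expected in expected_shapes.items():
        actual = tuple(inner_tensors[attr].shape)
        if actual != tuple(expected):
            raise AssertionError(
                f"subclass reconstruction mismatch for {attr!r}: "
                f"expected shape {tuple(expected)}, got {actual}"
            )
```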
```python
z = x.clone().detach().requires_grad_()
z_compile = z.clone().detach().requires_grad_()

out_eager = f(x_nested, y_nested, z)
```
hmm... more out of paranoia than anything else, I'm worried about more complicated sets of inputs. The inputs to this test are something like Two(Two(plain, plain), Two(plain, plain)), plain, plain.
Some more testing ideas:
(1) Add a fourth argument that is an unbalanced TwoTensor, e.g. Two(plain, Two(plain, plain)); see the sketch after this list.
(2) Add different subclass types into the test: e.g. make one input ConstantMetadataTensor(plain), and another a TwoTensor(plain, ConstantMetadataTensor(plain)).
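A rough sketch of idea (1), assuming PyTorch's internal TwoTensor test subclass (the import path is from the test utilities and may differ; idea (2) would additionally mix in a second subclass type, which is not shown here):

```python
import torch
from torch.testing._internal.two_tensor import TwoTensor  # internal test helper

def make_unbalanced(shape):
    a, b, c = (torch.randn(shape) for _ in range(3))
    # Unbalanced nesting: Two(plain, Two(plain, plain))
    return TwoTensor(a, TwoTensor(b, c))

unbalanced_input = make_unbalanced((3, 4))
```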
bdhirsh left a comment:
Left some more nits; more tests would be great. Pre-emptively stamping!
Curious, are we landing this PR soon? It's helpful in addressing IMA issues when compiling DTensor(local=fp8). Super valuable work! Sharing my 2 cents on perf: for CPU overhead and GPU time, computing fp8 amax in eager is still faster than torch.compile (#129457).
When we have nested tensor subclasses, we need to recursively flatten/unflatten in fake tensor creation and AOTAutograd. Most of the PR is a mechanical change that converts today's single-level flatten logic to be recursive. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang Differential Revision: [D58533224](https://our.internmc.facebook.com/intern/diff/D58533224) [ghstack-poisoned]
This pull request was exported from Phabricator. Differential Revision: D58533224
Will try to land today :)
Pull Request resolved: #127431
When we have nested tensor subclasses, we need to recursively flatten/unflatten in fake tensor creation and AOTAutograd. Most of the PR is a mechanical change that converts today's single-level flatten logic to be recursive.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang
@imported-using-ghimport
Differential Revision: [D58533224](https://our.internmc.facebook.com/intern/diff/D58533224/)
ghstack-source-id: 21cebdb
@tugsbayasgalan has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@pytorchbot merge -f 'Landed internally' (Initiating merge automatically since Phabricator Diff has merged, using force because this PR might not pass merge_rules.json but landed internally)
Merge started
Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).
Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
Summary: `unwrap_tensor_subclass` is incorporated into the export stack natively after pytorch/pytorch#127431, so we can remove this workaround now.
Test Plan:
python test/quantization/test_quant_api.py
python test/integration/test_integration.py
Reviewers: Subscribers: Tasks: Tags:
Stack from ghstack (oldest at bottom):
When we have nested tensor subclasses, we need to recursively flatten/unflatten in fake tensor creation and AOTAutograd. Most of the PR is a mechanical change that converts today's single-level flatten logic to be recursive.
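For example, this is the kind of nested-subclass input the change is meant to support end to end (a rough sketch, assuming PyTorch's internal TwoTensor test subclass and the aot_eager backend; not the PR's own test code):

```python
import torch
from torch.testing._internal.two_tensor import TwoTensor  # internal test helper

def f(x, y):
    return x.sin() + y

a, b, c, d = (torch.randn(3) for _ in range(4))
nested = TwoTensor(TwoTensor(a, b), TwoTensor(c, d))  # subclass holding subclasses

out = torch.compile(f, backend="aot_eager")(nested, torch.randn(3))
```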
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang
Differential Revision: D58533224