Read out real strides from compilation result, rather than real args #105010

ezyang wants to merge 3 commits into gh/ezyang/2222/base
Conversation
This prefigures a refactor that will move the backward compilation to entirely ahead of time, so I need to extract these strides some other way. Straight from the compiler's mouth will do it. Signed-off-by: Edward Z. Yang <ezyang@meta.com> [ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/105010

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures (as of commit e2a8ecd). This comment was automatically generated by Dr. CI and updates every 15 minutes.
```python
# compiler to aot_autograd
# Per output, what the compiler specified stride of the output is,
# or None if no stride is known
self.output_strides: Optional[List[Optional[List[int]]]] = None
```
**Reviewer:** Right now we only do layout optimization when dynamic shapes are disabled, but we are working on lifting that restriction (blocked on split-reduction support). Should we set the type here to account for SymInt?
**ezyang:** Yeah, this can have SymInts; will amend.
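A minimal sketch of what the amended annotation could look like once strides may be symbolic. The `SymInt` class below is a self-contained stand-in for `torch.SymInt`, and the alias name `OutputStrides` is illustrative, not part of the PR:

```python
from typing import List, Optional, Union

class SymInt:
    """Illustrative stand-in for torch.SymInt, so this sketch runs standalone."""
    def __init__(self, node):
        self.node = node

# Per output: a list of (possibly symbolic) strides, or None if the
# compiler reported no stride for that output (e.g. a non-tensor output).
OutputStrides = Optional[List[Optional[List[Union[int, SymInt]]]]]

strides: OutputStrides = [[4, 1], None]
```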
```python
# Return the output strides to the caller via TracingContext
assert len(context.output_strides) == 0
for out in graph.graph_outputs:
    if hasattr(out, "layout"):
```
**Reviewer:** Do we see cases where `out` does not have a `layout` attribute? Does this happen when a SymInt is returned?
**ezyang:** Yes, a SymInt can be returned. Maybe there is something better than `hasattr` to do here, but a lot of inductor code is written this way. Hard to be more precise without typing 👹
**ezyang:** OK, turns out I can't, because strides in inductor are `sympy.Symbol`, not SymInt. So I am just going ahead and storing only the hints here. When you fix this to do permutations instead, you can probably make this a little less fragile.
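The "storing the hints" idea above can be sketched in miniature: a symbolic stride carries a concrete hint, and the context records the hints rather than the symbolic expressions. The `Sym` class and `get_hints` helper below are hypothetical illustrations, not inductor's actual types:

```python
class Sym:
    """Tiny stand-in for a symbolic stride that remembers its concrete hint."""
    def __init__(self, hint):
        self.hint = hint

def get_hints(strides):
    """Concretize a list of possibly-symbolic strides to plain ints."""
    return [s.hint if isinstance(s, Sym) else int(s) for s in strides]
```

So a stride list like `[Sym(32), 4, 1]` would be recorded as `[32, 4, 1]`.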
**Reviewer:** Is it that we cannot do permutations instead because of some non-'dense' activations? This happens for some real models.
**ezyang:** There are a few ways we could do this. If we only ever change layout on dense tensors, we can make `output_strides` be None unless we changed the stride. Then the permutation is always defined.
**Reviewer:**

> If we only ever change layout on dense tensors, ...

I think this cannot be guaranteed in inductor right now. A non-dense tensor's layout may change because of a layout change in its upstream tensors, and inductor does not force the eager stride for that non-dense tensor because of the code here: https://github.com/pytorch/pytorch/blob/main/torch/_inductor/graph.py#L688 . I suspect the algorithm to restride a non-dense tensor would be tricky.
**ezyang:** OK, I will need to think this through carefully. I'll start a doc.
**ezyang:** Upon further reflection, all of this is moot once we have compiled backwards, so let's not touch it unless it's causing someone problems.
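For the dense case discussed in this thread, the permutation really is recoverable from the strides alone. A rough sketch of that recovery (illustrative, not inductor code; only well-defined when the strides are a permutation of some contiguous layout):

```python
def stride_permutation(strides):
    """Return the dim order (outermost-first) implied by a stride list.

    Only well-defined for dense tensors, where the strides are a
    permutation of the contiguous strides for some ordering of the dims.
    Ties (e.g. size-1 dims) are broken by dim index.
    """
    return sorted(range(len(strides)), key=lambda d: (-strides[d], d))
```

For a contiguous shape `(2, 3, 4)` with strides `(12, 4, 1)` this yields `[0, 1, 2]`; for a channels-last shape `(2, 3, 4, 5)` with strides `(60, 1, 15, 3)` it yields `[0, 2, 3, 1]`.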
```python
if tc is None:
    yield None
    return
old_output_strides = tc.output_strides
```
**Reviewer:** Just want to raise a concern about nested compilation, where we may need a stack for `output_strides`. But I cannot come up with a realistic example, so the current implementation is probably good enough.
**ezyang:** This implicitly is a stack, via `old_output_strides`!
**Reviewer:** Ah, right, that's the whole point of the context manager after all... haha
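The save/restore pattern this thread refers to can be sketched as a context manager: each nested entry saves the caller's value in a local, so the locals of the nested activations form an implicit stack. Names here are illustrative, not the actual TracingContext API:

```python
from contextlib import contextmanager

class TracingContext:
    """Simplified stand-in: only the field this sketch needs."""
    def __init__(self):
        self.output_strides = None

@contextmanager
def report_output_strides(tc):
    # Save the caller's value, install a fresh list, restore on exit.
    # Nested uses form an implicit stack via the saved locals.
    old_output_strides = tc.output_strides
    tc.output_strides = []
    try:
        yield tc.output_strides
    finally:
        tc.output_strides = old_output_strides
```

Exiting an inner `with` restores the outer list, so nesting never clobbers the enclosing compilation's strides.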
```python
continue

# Comparing ph_arg.stride() with real_arg.stride() directly may
if forward_saved_for_backwards_strides is None:
```
**Reviewer:** Moving this check out of the for loop may be slightly faster.
**ezyang:** I was trying to save myself an indent 😂 This only runs during compilation, so I don't think it matters much.
torch/_functorch/aot_autograd.py (outdated)

```python
if real_stride is None:
    continue

assert _get_hints(real_stride) == all_args[i].stride(), f"{real_stride} {all_args[i].stride()}"
```
**Reviewer:** I assume you will remove this soon, since the main point of the change is that we won't have access to `all_args`?
**ezyang:** Yup, this is just for CI here.
```python
forward_saved_for_backwards_strides = fwd_output_strides[
    CompiledFunction.metadata.tensors_saved_for_backwards_slice
]
```
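The slice above selects, out of all forward outputs, just the strides of the tensors saved for backward. A minimal illustration of that indexing, with hypothetical values (a real graph would have its slice computed from the AOTAutograd metadata):

```python
# Forward produced 4 outputs; suppose outputs 2..3 are saved for backward.
fwd_output_strides = [[12, 4, 1], None, [4, 1], [1]]
tensors_saved_for_backwards_slice = slice(2, 4)

# Selects only the saved-for-backward entries, preserving order.
forward_saved_for_backwards_strides = fwd_output_strides[
    tensors_saved_for_backwards_slice
]
```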
… real args"

This prefigures a refactor that will move the backward compilation to entirely ahead of time, so I need to extract these strides some other way. Straight from the compiler's mouth will do it. I can't easily get the information via the return result of `fw_compiler` without changing the calling convention, so instead I smuggle it via TracingContext. TracingContext may be None when we are compiling patterns for the joint graph pattern matcher.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 [ghstack-poisoned]
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…05251) Currently all information about the dependencies of ghstack PRs (e.g. #105010) is stripped away: https://github.com/pytorch/pytorch/blob/c984885809194e0a807b3f5543450fae4dfa841a/.github/scripts/trymerge.py#L1077-L1078

This PR adds this information back in a more compact form. All dependencies (PR numbers) of each PR in a ghstack are recorded. The resulting commit message will look like this (the last line is new):

> Mock title (#123)
>
> Mock body text
> Pull Request resolved: #123
> Approved by: https://github.com/Approver1, https://github.com/Approver2
> ghstack dependencies: #1, #2

### Testing

Unit tests.

### Note

Re: `# type: ignore[assignment]` in unit tests. I did my due diligence to find alternatives. Unfortunately mypy [doesn't](python/mypy#6713) support this [way of patching methods](https://docs.python.org/3/library/unittest.mock-examples.html#mock-patching-methods), and the alternatives are either extremely verbose or don't work for this case. I decided it's not worth the effort (since the problem is limited to the unit tests).

Pull Request resolved: #105251
Approved by: https://github.com/huydhn
Stack from ghstack (oldest at bottom):
This prefigures a refactor that will move the backward compilation
to entirely ahead of time, so I need to extract these strides some
other way. Straight from the compiler's mouth will do it.
I can't easily get the information via the return result of `fw_compiler` without changing the calling convention, so instead I smuggle it via TracingContext. TracingContext may be None when we are compiling patterns for the joint graph pattern matcher.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @ngimel @yf225 @chenyang78 @kadeng @muchulee8