Fix tensor subclass + dynamic shapes in torch.compile + aot autograd#125941
guilhermeleobas wants to merge 75 commits into gh/guilhermeleobas/48/base
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125941
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit c9db0c5 with merge base 3b0f393. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Fixes issue: #124619

**Changes**

This PR addresses a bug in tensor subclasses and symbolic execution. Most of the changes are in the … An extra field (…) …

**Test plan**

Add tests for two different subclasses: … The set of tests is composed of functions that return a mix of subclasses.
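To illustrate the idea behind the fix (recursively unwrapping a subclass while also carrying its symbolic outer sizes/strides, so dynamic shapes survive the round trip), here is a hedged, self-contained sketch. `MockSubclass` and `flatten_with_symints` are hypothetical stand-ins, not the actual PyTorch classes; strings like `"s0"` stand in for `SymInt`s.

```python
# Hedged sketch, NOT the real AOTAutograd code: a subclass is flattened into
# its plain inner tensors plus the symbolic ints describing its outer view,
# recursing into nested subclasses.
from dataclasses import dataclass
from typing import Any, List, Tuple


@dataclass
class MockSubclass:
    # hypothetical stand-in for a torch.Tensor subclass wrapping inner tensors
    inner: List[Any]
    outer_size: Tuple[Any, ...]
    outer_stride: Tuple[Any, ...]


def flatten_with_symints(x: Any) -> List[Any]:
    """Return plain leaves plus the symbolic ints for each wrapper level."""
    out: List[Any] = []
    if isinstance(x, MockSubclass):
        # record outer size/stride first so they can be replayed on unflatten
        out.extend(x.outer_size)
        out.extend(x.outer_stride)
        for t in x.inner:
            out.extend(flatten_with_symints(t))  # recurse into nested wrappers
    else:
        out.append(x)  # plain leaf tensor (here just a placeholder value)
    return out


# nested wrapper: a subclass whose inner list contains another subclass
inner = MockSubclass(inner=["a", "b"], outer_size=("s0",), outer_stride=(1,))
outer = MockSubclass(inner=[inner, "c"], outer_size=("s0", "s1"), outer_stride=("s1", 1))
flat = flatten_with_symints(outer)
# → ['s0', 's1', 's1', 1, 's0', 1, 'a', 'b', 'c']
```

The point of the sketch is only that the symbolic sizes/strides travel with the plain tensors, which is what allows compiled code to reconstruct a dynamically-shaped subclass.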
What exactly is the algorithmic strategy here? |
@guilhermeleobas the new recursive size/stride handling looks mostly good to me. I was still a bit worried about edge cases, so I tried a few tests locally and got a failure involving nested subclasses. (I wanted to stress test a bit, so I used a version of … .) It fails for me with: …
Oops. It was just one small mistake where I mistyped the name of a variable. Should be good now.
I just noticed @IvanKobzarev's PR (#138498) does a micro-optimization on …
Ah yes, that would be great. @IvanKobzarev has been looking into subclass runtime overhead, and it would be nice if we could avoid this PR making it much worse.
@IvanKobzarev, did you use any code to benchmark #138498? If so, can you share it with me? |
```diff
             if subclass_metas is None:
-                xs_inner.extend(get_plain_tensors(typing.cast(Tensor, x)))
+                get_plain_tensors(typing.cast(Tensor, x), out_append_list=xs_inner)
```

```diff
-    subclass: Tensor, out_append_list: Optional[List[Tensor]] = None
-) -> List[Tensor]:
+    subclass: Tensor, out_append_list: Optional[List[Union[Tensor, int, SymInt]]] = None
+) -> List[Union[Tensor, int, SymInt]]:
```
hmm, the type signature here is a bit confusing, since we never actually append ints/SymInts to the list in this function. I guess you needed this because in the out_append_list= case, the list we pass in might have symints in it already?
Instead, what do you think of: just refactoring this function to always accept an output list to append to, and mandating that anybody using this API must pass in their own list (from a quick grep there are only 2 call sites of this function, both within AOTAutograd)
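The refactor being suggested (the function always appends into a caller-supplied list instead of optionally returning a fresh one) can be sketched in plain Python. This is a hedged illustration of the pattern, not the actual `get_plain_tensors` code; `get_plain_items` and the list/tuple check are hypothetical stand-ins for "is a wrapper subclass".

```python
from typing import Any, List


def get_plain_items(x: Any, out: List[Any]) -> None:
    """Caller always supplies the output list; the function only appends.

    This avoids the ambiguous Optional[...] signature: there is exactly one
    list, owned by the caller, and the recursion appends into it in place.
    """
    if isinstance(x, (list, tuple)):  # stand-in for "x is a wrapper subclass"
        for item in x:
            get_plain_items(item, out)
    else:
        out.append(x)


# usage: the caller owns the list and may pre-seed it (e.g. with symints)
result: List[Any] = ["s0"]  # pre-existing symbolic size already in the list
get_plain_items([1, [2, 3]], result)
# result == ["s0", 1, 2, 3]
```

With this shape, the confusing `Optional[List[Union[Tensor, int, SymInt]]]` parameter goes away: the element type of the list is whatever the caller put in it, and the function's own contract is simply "append plain leaves".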
bdhirsh left a comment
left a few more comments, but otherwise I think this is ready to land. Thanks for all the hard work!
Hi, at the moment I use … for profiling, plus manual averaging of time.time_ns() deltas accumulated in a global variable inside unwrap_tensor_subclasses() in runtime_wrappers.py. |
@pytorchbot merge |
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
@guilhermeleobas can that call be removed now? I think it's still there, with a note that it can be removed after this PR is closed.
Oh @mlazos, our current hypothesis is that this context manager was only needed because there were some tests that did …
Awesome, I will try that.
Stack from ghstack (oldest at bottom):
- outer_size/outer_stride (#133337)