You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Opening this so I can discuss with @albanD
I built a proof of concept of an in place API for an nn.Module that allows us to save and load a torch.compiled model with no issues https://github.com/msaroufim/mlsys-experiments/blob/main/save-compiled-model.py
So users can run` model.compile()` and then run `torch.save(model, "model.pt")` and `torch.load(model, "model.pt)` with no issues unlike the rather strange current suggestion we give to users which is `opt_mod = torch.compile(mod); torch.save(mod, "model.pt")`
Right now I'm trying to extend this to work for nn.modules more generally
TODO: Failing tests
* [x] torch.jit.load -> issue was because of aliasing `__call__` to `_call_impl`, _call_impl used to be skipped when now it lo longer is so expanded the skip check. I added an explicit `torch.jit.load()` test now which @davidberard98 suggested
* [x] functorch seems to be a flake - ran locally and it worked `pytest functorch/test_eager_transforms.py`
* [x] a test infra flake - `test_testing.py::TestImports::test_no_mutate_global_logging_on_import_path_functorch`
* [x] It seems like I broke inlining in dynamo though `python -m pytest test/dynamo/test_dynamic_shapes.py -k test_issue175` chatting with Voz about it but still not entirely sure how to fix - found a workaround after chatting with @yanboliang
* [x] `pytest test/dynamo/test_modules.py` and `test/dynamo/test_dynamic_shapes` `test/dynamo/test_misc.py` seem to be failing in CI but trying it out locally they all pass tests passed with 0 failures
* [x] `pytest test/profiler/test_profiler_tree.py ` these tests have ProfilerTrees explicitly printed and will now break if __call__ is not in tree - ran with `EXPECT_ACCEPT=1`
* [x] `pytest test/test_torch.py::TestTorch::test_typed_storage_deprecation_warning` a flake, ran this locally and it works fine
* [x] I reverted my changes to `_dynamo/nn_module.py` since it looks like @wconstab is now directly handling `_call_impl` there but this is triggering an infinite inlining which is crashing
* [x] Tried out to instead override `__call__`, python doesnt like this though #97565 (comment)
Pull Request resolved: #97565
Approved by: https://github.com/aaronenyeshi, https://github.com/albanD, https://github.com/voznesenskym
0 commit comments