Skip to content

Commit bf52d57

Browse files
Mark Saroufim and pytorchmergebot
authored and committed
torch.save/load torch.compiled models (#97565)
Opening this so I can discuss with @albanD I built a proof of concept of an in place API for an nn.Module that allows us to save and load a torch.compiled model with no issues https://github.com/msaroufim/mlsys-experiments/blob/main/save-compiled-model.py So users can run `model.compile()` and then run `torch.save(model, "model.pt")` and `torch.load(model, "model.pt")` with no issues unlike the rather strange current suggestion we give to users which is `opt_mod = torch.compile(mod); torch.save(mod, "model.pt")` Right now I'm trying to extend this to work for nn.modules more generally TODO: Failing tests * [x] torch.jit.load -> issue was because of aliasing `__call__` to `_call_impl`, _call_impl used to be skipped when now it no longer is so expanded the skip check. I added an explicit `torch.jit.load()` test now which @davidberard98 suggested * [x] functorch seems to be a flake - ran locally and it worked `pytest functorch/test_eager_transforms.py` * [x] a test infra flake - `test_testing.py::TestImports::test_no_mutate_global_logging_on_import_path_functorch` * [x] It seems like I broke inlining in dynamo though `python -m pytest test/dynamo/test_dynamic_shapes.py -k test_issue175` chatting with Voz about it but still not entirely sure how to fix - found a workaround after chatting with @yanboliang * [x] `pytest test/dynamo/test_modules.py` and `test/dynamo/test_dynamic_shapes` `test/dynamo/test_misc.py` seem to be failing in CI but trying it out locally they all pass tests passed with 0 failures * [x] `pytest test/profiler/test_profiler_tree.py` these tests have ProfilerTrees explicitly printed and will now break if __call__ is not in tree - ran with `EXPECT_ACCEPT=1` * [x] `pytest test/test_torch.py::TestTorch::test_typed_storage_deprecation_warning` a flake, ran this locally and it works fine * [x] I reverted my changes to `_dynamo/nn_module.py` since it looks like @wconstab is now directly handling `_call_impl` there but this is triggering an infinite
inlining which is crashing * [x] Tried out to instead override `__call__`, python doesnt like this though #97565 (comment) Pull Request resolved: #97565 Approved by: https://github.com/aaronenyeshi, https://github.com/albanD, https://github.com/voznesenskym
1 parent 2f95380 commit bf52d57

6 files changed

Lines changed: 204 additions & 86 deletions

File tree

test/dynamo/test_compile.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Owner(s): ["module: dynamo"]
2+
3+
import os
4+
import tempfile
5+
import unittest
6+
7+
import torch
8+
from torch._dynamo.testing import CompileCounter
9+
10+
11+
class ToyModel(torch.nn.Module):
    """Minimal Linear+ReLU module used to exercise in-place compilation."""

    def __init__(self):
        super().__init__()
        # Attribute names are load-bearing: the state_dict round-trip test
        # expects keys "linear.*".
        self.linear = torch.nn.Linear(10, 10)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        hidden = self.linear(x)
        return self.relu(hidden)
19+
20+
21+
class InPlaceCompilationTests(unittest.TestCase):
    """Tests for the in-place ``nn.Module.compile()`` API and serialization.

    Each test resets dynamo first so compile counters and caches from other
    tests cannot leak in.
    """

    def test_compilation(self):
        """model.compile() routes the forward through the given backend once."""
        torch._dynamo.reset()
        model = ToyModel()
        cnt = CompileCounter()
        model.compile(backend=cnt)
        x = torch.randn(10, 10)
        model(x)
        self.assertEqual(cnt.frame_count, 1)

    def test_overwrite_call_impl(self):
        """compile() installs _compiled_call_impl on the module."""
        torch._dynamo.reset()
        model = ToyModel()
        # assertIsNone/assertIsNotNone give clearer failure messages than
        # assertTrue(x is None) and report the offending value on failure.
        self.assertIsNone(model._compiled_call_impl)
        model.compile()
        self.assertIsNotNone(model._compiled_call_impl)

    def test_save(self):
        """torch.save/torch.load round-trips a compiled module."""
        torch._dynamo.reset()
        model = ToyModel()
        model.compile()
        model(torch.randn(1, 10))

        with tempfile.TemporaryDirectory() as tmpdirname:
            path = os.path.join(tmpdirname, "model.pt")
            torch.save(model, path)
            loaded_model = torch.load(path)
            # Loaded module must still be callable (compiled wrapper is
            # stripped by __getstate__, plain forward remains).
            loaded_model(torch.randn(1, 10))

    def test_state_dict_save(self):
        """state_dict of a compiled module loads into a fresh module."""
        torch._dynamo.reset()
        model = ToyModel()
        model.compile()
        model(torch.randn(1, 10))
        with tempfile.TemporaryDirectory() as tmpdirname:
            path = os.path.join(tmpdirname, "model.pt")
            torch.save(model.state_dict(), path)
            loaded_model = ToyModel()
            loaded_model.load_state_dict(torch.load(path))
            loaded_model(torch.randn(1, 10))

    def test_jit_save(self):
        """A compiled module can still be scripted, jit-saved, and reloaded."""
        torch._dynamo.reset()
        model = ToyModel()
        model.compile()
        model(torch.randn(1, 10))
        scripted_model = torch.jit.script(model)
        with tempfile.TemporaryDirectory() as tmpdirname:
            path = os.path.join(tmpdirname, "model.pt")
            torch.jit.save(scripted_model, path)
            loaded_model = torch.jit.load(path)
            loaded_model(torch.randn(1, 10))

test/profiler/test_profiler_tree.py

Lines changed: 88 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,10 @@ def test_profiler_experimental_tree(self):
292292
autograd::engine::evaluate_function: torch::autograd::AccumulateGrad
293293
torch::autograd::AccumulateGrad
294294
aten::detach
295-
detach"""
295+
detach
296+
cudaGetDeviceCount
297+
cudaGetDeviceCount
298+
cudaGetDeviceProperties"""
296299
)
297300

298301
@ProfilerTree.test
@@ -542,87 +545,95 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
542545
aten::empty
543546
aten::fill_
544547
nn.Module: MyModule_0
545-
<built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX>
546-
test_profiler_tree.py(...): forward
547-
nn.Module: ReLU_0
548-
<built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX>
549-
torch/nn/modules/activation.py(...): forward
550-
torch/nn/functional.py(...): relu
551-
<built-in function _has_torch_function_unary>
552-
<built-in method relu of type object at 0xXXXXXXXXXXXX>
553-
aten::relu
554-
aten::clamp_min
555-
nn.Module: Linear_0
556-
<built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX>
557-
torch/nn/modules/linear.py(...): forward
558-
torch/nn/modules/module.py(...): __getattr__
559-
torch/nn/modules/module.py(...): __getattr__
560-
<built-in function linear>
561-
aten::linear
562-
aten::t
563-
aten::transpose
564-
aten::as_strided
565-
aten::matmul
566-
aten::unsqueeze
567-
aten::as_strided
568-
aten::mm
569-
aten::resolve_conj
570-
aten::resolve_conj
571-
aten::resolve_conj
572-
aten::squeeze_
573-
aten::as_strided_
574-
aten::add_
575-
nn.Module: ReLU_1
576-
<built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX>
577-
torch/nn/modules/activation.py(...): forward
578-
torch/nn/functional.py(...): relu
579-
<built-in function _has_torch_function_unary>
580-
<built-in method relu of type object at 0xXXXXXXXXXXXX>
581-
aten::relu
582-
aten::clamp_min
548+
torch/nn/modules/module.py(...): _call_impl
549+
<built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX>
550+
test_profiler_tree.py(...): forward
551+
nn.Module: ReLU_0
552+
torch/nn/modules/module.py(...): _call_impl
553+
<built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX>
554+
torch/nn/modules/activation.py(...): forward
555+
torch/nn/functional.py(...): relu
556+
<built-in function _has_torch_function_unary>
557+
<built-in method relu of type object at 0xXXXXXXXXXXXX>
558+
aten::relu
559+
aten::clamp_min
560+
nn.Module: Linear_0
561+
torch/nn/modules/module.py(...): _call_impl
562+
<built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX>
563+
torch/nn/modules/linear.py(...): forward
564+
torch/nn/modules/module.py(...): __getattr__
565+
torch/nn/modules/module.py(...): __getattr__
566+
<built-in function linear>
567+
aten::linear
568+
aten::t
569+
aten::transpose
570+
aten::as_strided
571+
aten::matmul
572+
aten::unsqueeze
573+
aten::as_strided
574+
aten::mm
575+
aten::resolve_conj
576+
aten::resolve_conj
577+
aten::resolve_conj
578+
aten::squeeze_
579+
aten::as_strided_
580+
aten::add_
581+
nn.Module: ReLU_1
582+
torch/nn/modules/module.py(...): _call_impl
583+
<built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX>
584+
torch/nn/modules/activation.py(...): forward
585+
torch/nn/functional.py(...): relu
586+
<built-in function _has_torch_function_unary>
587+
<built-in method relu of type object at 0xXXXXXXXXXXXX>
588+
aten::relu
589+
aten::clamp_min
583590
<built-in method ones of type object at 0xXXXXXXXXXXXX>
584591
aten::ones
585592
aten::empty
586593
aten::fill_
587594
nn.Module: MyModule_0
588-
<built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX>
589-
test_profiler_tree.py(...): forward
590-
nn.Module: ReLU_0
591-
<built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX>
592-
torch/nn/modules/activation.py(...): forward
593-
torch/nn/functional.py(...): relu
594-
<built-in function _has_torch_function_unary>
595-
<built-in method relu of type object at 0xXXXXXXXXXXXX>
596-
aten::relu
597-
aten::clamp_min
598-
nn.Module: Linear_0
599-
<built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX>
600-
torch/nn/modules/linear.py(...): forward
601-
torch/nn/modules/module.py(...): __getattr__
602-
torch/nn/modules/module.py(...): __getattr__
603-
<built-in function linear>
604-
aten::linear
605-
aten::t
606-
aten::transpose
607-
aten::as_strided
608-
aten::matmul
609-
aten::unsqueeze
610-
aten::as_strided
611-
aten::mm
612-
aten::resolve_conj
613-
aten::resolve_conj
614-
aten::resolve_conj
615-
aten::squeeze_
616-
aten::as_strided_
617-
aten::add_
618-
nn.Module: ReLU_1
619-
<built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX>
620-
torch/nn/modules/activation.py(...): forward
621-
torch/nn/functional.py(...): relu
622-
<built-in function _has_torch_function_unary>
623-
<built-in method relu of type object at 0xXXXXXXXXXXXX>
624-
aten::relu
625-
aten::clamp_min
595+
torch/nn/modules/module.py(...): _call_impl
596+
<built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX>
597+
test_profiler_tree.py(...): forward
598+
nn.Module: ReLU_0
599+
torch/nn/modules/module.py(...): _call_impl
600+
<built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX>
601+
torch/nn/modules/activation.py(...): forward
602+
torch/nn/functional.py(...): relu
603+
<built-in function _has_torch_function_unary>
604+
<built-in method relu of type object at 0xXXXXXXXXXXXX>
605+
aten::relu
606+
aten::clamp_min
607+
nn.Module: Linear_0
608+
torch/nn/modules/module.py(...): _call_impl
609+
<built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX>
610+
torch/nn/modules/linear.py(...): forward
611+
torch/nn/modules/module.py(...): __getattr__
612+
torch/nn/modules/module.py(...): __getattr__
613+
<built-in function linear>
614+
aten::linear
615+
aten::t
616+
aten::transpose
617+
aten::as_strided
618+
aten::matmul
619+
aten::unsqueeze
620+
aten::as_strided
621+
aten::mm
622+
aten::resolve_conj
623+
aten::resolve_conj
624+
aten::resolve_conj
625+
aten::squeeze_
626+
aten::as_strided_
627+
aten::add_
628+
nn.Module: ReLU_1
629+
torch/nn/modules/module.py(...): _call_impl
630+
<built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX>
631+
torch/nn/modules/activation.py(...): forward
632+
torch/nn/functional.py(...): relu
633+
<built-in function _has_torch_function_unary>
634+
<built-in method relu of type object at 0xXXXXXXXXXXXX>
635+
aten::relu
636+
aten::clamp_min
626637
torch/profiler/profiler.py(...): __exit__
627638
torch/profiler/profiler.py(...): stop
628639
..."""

torch/_dynamo/eval_frame.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,9 @@ def __call__(self, fn):
246246
filename = None
247247
if (
248248
(filename is None or skipfiles.check(filename))
249-
and (getattr(fn, "__name__", "") != "_call_impl")
249+
and (
250+
getattr(fn, "__name__", "") not in ["_call_impl", "_wrapped_call_impl"]
251+
)
250252
and filename not in DONT_WRAP_FILES
251253
):
252254
# call to a builtin without a frame for us to capture

torch/_dynamo/variables/nn_module.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -318,15 +318,14 @@ def call_function(
318318
# the call_wrapped currently, and maybe other issues too
319319
fn = mod.forward
320320
else:
321-
fn = mod.__call__
321+
fn = mod._call_impl
322322
fn_source = AttrSource(self.source, "__call__")
323323
if istype(fn, types.MethodType):
324324
fn = fn.__func__
325325
fn_source = AttrSource(fn_source, "__func__")
326326
args = [self] + args
327327
else:
328328
assert istype(fn, types.FunctionType)
329-
330329
options["source"] = fn_source
331330
return tx.inline_user_function_return(
332331
variables.UserFunctionVariable(fn, **options),
@@ -374,7 +373,7 @@ def generic_call_method_helper(name):
374373
**options,
375374
)
376375

377-
if name == "_call_impl":
376+
if name in ["_call_impl", "_wrapped_call_impl"]:
378377
# Example: `self.layer.__call__(x)`
379378
# This is used for explicit calling `__call__` in a forward function.
380379
# Dynamo inlines `__call__`, includes hooks.
@@ -683,14 +682,12 @@ def call_function(
683682
) -> "VariableTracker":
684683
options = VariableTracker.propagate(self, args, kwargs.values())
685684
mod = self.value
686-
687685
# see comment on lazy module handling in NNModuleVariable.call_function for context
688686
if is_lazy_module(mod):
689687
if mod.cls_to_become is not None:
690688
self.value_type = mod.cls_to_become
691689
initialize_lazy_module(tx, mod, args, kwargs)
692-
693-
name = "__call__"
690+
name = "_call_impl"
694691
fn = getattr(self.value_type, name)
695692
if self.source:
696693
source = AttrSource(AttrSource(self.source, "__class__"), name)
@@ -711,6 +708,16 @@ def call_method(
711708
from .builder import VariableBuilder
712709

713710
options = VariableTracker.propagate(self, args, kwargs.values())
711+
if name in ["_call_impl", "_wrapped_call_impl"]:
712+
fn = getattr(self.value_type, name)
713+
if self.source:
714+
source = AttrSource(AttrSource(self.source, "__class__"), name)
715+
else:
716+
source = None
717+
718+
return variables.UserFunctionVariable(
719+
fn, source=source, **options
720+
).call_function(tx, [self] + list(args), kwargs)
714721

715722
if name not in getattr(self.value, "__dict__", {}):
716723
try:

torch/jit/_script.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -944,7 +944,7 @@ def fail(self, *args, **kwargs):
944944
return fail
945945

946946
for name, method in _get_methods(torch.nn.Module):
947-
if name.startswith("__"):
947+
if name.startswith("__") or name.endswith("_call_impl"):
948948
continue
949949
if (
950950
name not in RecursiveScriptModule.__dict__

torch/nn/modules/module.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,11 +433,15 @@ def forward(self, x):
433433
_load_state_dict_post_hooks: Dict[int, Callable]
434434
_modules: Dict[str, Optional['Module']]
435435
call_super_init: bool = False
436+
_compiled_call_impl : Optional[Callable] = None
437+
438+
436439

437440
def __init__(self, *args, **kwargs) -> None:
438441
"""
439442
Initializes internal Module state, shared by both nn.Module and ScriptModule.
440443
"""
444+
441445
torch._C._log_api_usage_once("python.nn_module")
442446

443447
# Backward compatibility: no args used to be allowed when call_super_init=False
@@ -1491,6 +1495,12 @@ def _slow_forward(self, *input, **kwargs):
14911495
tracing_state.pop_scope()
14921496
return result
14931497

1498+
def _wrapped_call_impl(self, *args, **kwargs):
1499+
if self._compiled_call_impl is not None:
1500+
return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1501+
else:
1502+
return self._call_impl(*args, **kwargs)
1503+
14941504
def _call_impl(self, *args, **kwargs):
14951505
forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)
14961506
# If we don't have any hooks, we want to skip the rest of the logic in
@@ -1572,10 +1582,16 @@ def _call_impl(self, *args, **kwargs):
15721582

15731583
return result
15741584

1575-
__call__ : Callable[..., Any] = _call_impl
1585+
__call__ : Callable[..., Any] = _wrapped_call_impl
1586+
1587+
def __getstate__(self):
1588+
state = self.__dict__.copy()
1589+
state.pop("_compiled_call_impl", None)
1590+
return state
15761591

15771592
def __setstate__(self, state):
15781593
self.__dict__.update(state)
1594+
15791595
# Support loading old checkpoints that don't have the following attrs:
15801596
if '_forward_pre_hooks' not in self.__dict__:
15811597
self._forward_pre_hooks = OrderedDict()
@@ -2420,3 +2436,14 @@ def _replicate_for_data_parallel(self):
24202436
replica._is_replica = True # type: ignore[assignment]
24212437

24222438
return replica
2439+
2440+
def compile(self, *args, **kwargs):
    """
    Compile this Module's forward pass in place via :func:`torch.compile`.

    All positional and keyword arguments are forwarded unchanged to
    :func:`torch.compile`; the resulting callable is stored on the module
    and used by ``__call__`` from then on.

    See :func:`torch.compile` for details on the accepted arguments.
    """
    compiled_forward = torch.compile(self._call_impl, *args, **kwargs)
    self._compiled_call_impl = compiled_forward

0 commit comments

Comments
 (0)