Skip to content

Commit a633a85

Browse files
apmortonpytorchmergebot
authored andcommitted
Avoid writing temporary modules to disk
1 parent a2b6afe commit a633a85

1 file changed

Lines changed: 24 additions & 40 deletions

File tree

torch/distributed/nn/jit/instantiator.py

Lines changed: 24 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
#!/usr/bin/python3
22
# mypy: allow-untyped-defs
3-
import atexit
43
import importlib
5-
import logging
6-
import os
4+
import importlib.util
75
import sys
8-
import tempfile
96
from typing import Optional
107

118
import torch
@@ -14,15 +11,7 @@
1411
)
1512

1613

17-
logger = logging.getLogger(__name__)
18-
19-
2014
_FILE_PREFIX = "_remote_module_"
21-
_TEMP_DIR = tempfile.TemporaryDirectory()
22-
INSTANTIATED_TEMPLATE_DIR_PATH = _TEMP_DIR.name
23-
atexit.register(_TEMP_DIR.cleanup)
24-
logger.info("Created a temporary directory at %s", INSTANTIATED_TEMPLATE_DIR_PATH)
25-
sys.path.append(INSTANTIATED_TEMPLATE_DIR_PATH)
2615

2716

2817
def get_arg_return_types_from_interface(module_interface):
@@ -63,40 +52,35 @@ def get_arg_return_types_from_interface(module_interface):
6352
return args_str, arg_types_str, return_type_str
6453

6554

66-
def _write(out_path, text):
67-
old_text: Optional[str]
68-
try:
69-
with open(out_path) as f:
70-
old_text = f.read()
71-
except OSError:
72-
old_text = None
73-
if old_text != text:
74-
with open(out_path, "w") as f:
75-
logger.info("Writing %s", out_path)
76-
f.write(text)
77-
else:
78-
logger.info("Skipped writing %s", out_path)
55+
class StringLoader(importlib.abc.SourceLoader):
56+
def __init__(self, data):
57+
self.data = data
58+
59+
def get_source(self, fullname):
60+
return self.data
61+
62+
def get_data(self, path):
63+
return self.data.encode("utf-8")
64+
65+
def get_filename(self, fullname):
66+
return fullname
7967

8068

8169
def _do_instantiate_remote_module_template(
8270
generated_module_name, str_dict, enable_moving_cpu_tensors_to_cuda
8371
):
84-
generated_code_text = get_remote_module_template(
85-
enable_moving_cpu_tensors_to_cuda
86-
).format(**str_dict)
87-
out_path = os.path.join(
88-
INSTANTIATED_TEMPLATE_DIR_PATH, f"{generated_module_name}.py"
72+
if generated_module_name in sys.modules:
73+
return sys.modules[generated_module_name]
74+
75+
spec = importlib.util.spec_from_loader(
76+
generated_module_name,
77+
StringLoader(get_remote_module_template(enable_moving_cpu_tensors_to_cuda).format(**str_dict)),
78+
origin='torch-jit',
8979
)
90-
_write(out_path, generated_code_text)
91-
92-
# From importlib doc,
93-
# > If you are dynamically importing a module that was created since
94-
# the interpreter began execution (e.g., created a Python source file),
95-
# you may need to call invalidate_caches() in order for the new module
96-
# to be noticed by the import system.
97-
importlib.invalidate_caches()
98-
generated_module = importlib.import_module(f"{generated_module_name}")
99-
return generated_module
80+
module = importlib.util.module_from_spec(spec)
81+
sys.modules[generated_module_name] = module
82+
spec.loader.exec_module(module)
83+
return module
10084

10185

10286
def instantiate_scriptable_remote_module_template(

0 commit comments

Comments
 (0)