|
1 | 1 | #!/usr/bin/python3 |
2 | 2 | # mypy: allow-untyped-defs |
3 | | -import atexit |
4 | 3 | import importlib |
5 | | -import logging |
6 | | -import os |
| 4 | +import importlib.util |
7 | 5 | import sys |
8 | | -import tempfile |
9 | 6 | from typing import Optional |
10 | 7 |
|
11 | 8 | import torch |
|
14 | 11 | ) |
15 | 12 |
|
16 | 13 |
|
17 | | -logger = logging.getLogger(__name__) |
18 | | - |
19 | | - |
20 | 14 | _FILE_PREFIX = "_remote_module_" |
21 | | -_TEMP_DIR = tempfile.TemporaryDirectory() |
22 | | -INSTANTIATED_TEMPLATE_DIR_PATH = _TEMP_DIR.name |
23 | | -atexit.register(_TEMP_DIR.cleanup) |
24 | | -logger.info("Created a temporary directory at %s", INSTANTIATED_TEMPLATE_DIR_PATH) |
25 | | -sys.path.append(INSTANTIATED_TEMPLATE_DIR_PATH) |
26 | 15 |
|
27 | 16 |
|
28 | 17 | def get_arg_return_types_from_interface(module_interface): |
@@ -63,40 +52,35 @@ def get_arg_return_types_from_interface(module_interface): |
63 | 52 | return args_str, arg_types_str, return_type_str |
64 | 53 |
|
65 | 54 |
|
66 | | -def _write(out_path, text): |
67 | | - old_text: Optional[str] |
68 | | - try: |
69 | | - with open(out_path) as f: |
70 | | - old_text = f.read() |
71 | | - except OSError: |
72 | | - old_text = None |
73 | | - if old_text != text: |
74 | | - with open(out_path, "w") as f: |
75 | | - logger.info("Writing %s", out_path) |
76 | | - f.write(text) |
77 | | - else: |
78 | | - logger.info("Skipped writing %s", out_path) |
| 55 | +class StringLoader(importlib.abc.SourceLoader): |
| 56 | + def __init__(self, data): |
| 57 | + self.data = data |
| 58 | + |
| 59 | + def get_source(self, fullname): |
| 60 | + return self.data |
| 61 | + |
| 62 | + def get_data(self, path): |
| 63 | + return self.data.encode("utf-8") |
| 64 | + |
| 65 | + def get_filename(self, fullname): |
| 66 | + return fullname |
79 | 67 |
|
80 | 68 |
|
81 | 69 | def _do_instantiate_remote_module_template( |
82 | 70 | generated_module_name, str_dict, enable_moving_cpu_tensors_to_cuda |
83 | 71 | ): |
84 | | - generated_code_text = get_remote_module_template( |
85 | | - enable_moving_cpu_tensors_to_cuda |
86 | | - ).format(**str_dict) |
87 | | - out_path = os.path.join( |
88 | | - INSTANTIATED_TEMPLATE_DIR_PATH, f"{generated_module_name}.py" |
| 72 | + if generated_module_name in sys.modules: |
| 73 | + return sys.modules[generated_module_name] |
| 74 | + |
| 75 | + spec = importlib.util.spec_from_loader( |
| 76 | + generated_module_name, |
| 77 | + StringLoader(get_remote_module_template(enable_moving_cpu_tensors_to_cuda).format(**str_dict)), |
| 78 | + origin='torch-jit', |
89 | 79 | ) |
90 | | - _write(out_path, generated_code_text) |
91 | | - |
92 | | - # From importlib doc, |
93 | | - # > If you are dynamically importing a module that was created since |
94 | | - # the interpreter began execution (e.g., created a Python source file), |
95 | | - # you may need to call invalidate_caches() in order for the new module |
96 | | - # to be noticed by the import system. |
97 | | - importlib.invalidate_caches() |
98 | | - generated_module = importlib.import_module(f"{generated_module_name}") |
99 | | - return generated_module |
| 80 | + module = importlib.util.module_from_spec(spec) |
| 81 | + sys.modules[generated_module_name] = module |
| 82 | + spec.loader.exec_module(module) |
| 83 | + return module |
100 | 84 |
|
101 | 85 |
|
102 | 86 | def instantiate_scriptable_remote_module_template( |
|
0 commit comments