Skip to content

Commit 22a4d46

Browse files
Cherry pick #129244, #129251, #129239, 129396 into release/2.4 (#129478)
* Fix allowlisting of builtins for weights_only unpickler ghstack-source-id: de329c7 Pull Request resolved: #129244 (cherry picked from commit cc99c01) * Allow NEWOBJ instruction for items added via torch.serialization.add_safe_globals ghstack-source-id: 34a8fc3 Pull Request resolved: #129251 (cherry picked from commit 50b888d) * Add warning for weights_only ghstack-source-id: ffa772c Pull Request resolved: #129239 (cherry picked from commit b3f9aa3f8f4c03b40fed53423d4a0a9340e3bd09) * Add example for torch.serialization.add_safe_globals ghstack-source-id: 6dc3275 Pull Request resolved: #129396 (cherry picked from commit ed8c36eda0f4dcf7b1d9c5eb2fb1cdccdf3fee6e)
1 parent 5608699 commit 22a4d46

5 files changed

Lines changed: 183 additions & 20 deletions

File tree

test/distributed/_tensor/test_dtensor.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,16 @@ def test_dtensor_save_load(self):
536536
buffer.seek(0)
537537
reloaded_st = torch.load(buffer)
538538
self.assertEqual(sharded_tensor, reloaded_st)
539+
# Test weights_only load
540+
try:
541+
torch.serialization.add_safe_globals(
542+
[DTensor, DeviceMesh, Shard, DTensorSpec, TensorMeta]
543+
)
544+
buffer.seek(0)
545+
reloaded_st = torch.load(buffer, weights_only=True)
546+
self.assertEqual(sharded_tensor, reloaded_st)
547+
finally:
548+
torch.serialization.clear_safe_globals()
539549

540550

541551
class DTensorMeshTest(DTensorTestBase):

test/test_nn.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1801,26 +1801,35 @@ def test_parameterlistdict_pickle(self):
18011801
m = nn.ParameterList(map(nn.Parameter, [torch.rand(2), torch.rand(2)]))
18021802
with warnings.catch_warnings(record=True) as w:
18031803
m = pickle.loads(pickle.dumps(m))
1804-
self.assertTrue(len(w) == 0)
1804+
# warning from torch.load call in _load_from_bytes
1805+
num_warnings = 2 if torch._dynamo.is_compiling() else 1
1806+
self.assertTrue(len(w) == num_warnings)
1807+
self.assertEqual(w[0].category, FutureWarning)
18051808

18061809
# Test whether loading from older checkpoints works without triggering warnings
18071810
m = nn.ParameterList(map(nn.Parameter, [torch.rand(2), torch.rand(2)]))
18081811
del m._forward_pre_hooks, m._state_dict_hooks, m._load_state_dict_pre_hooks, m._non_persistent_buffers_set
18091812
with warnings.catch_warnings(record=True) as w:
18101813
m = pickle.loads(pickle.dumps(m))
1811-
self.assertTrue(len(w) == 0)
1814+
# warning from torch.load call in _load_from_bytes
1815+
self.assertTrue(len(w) == 1)
1816+
self.assertEqual(w[0].category, FutureWarning)
18121817

18131818
m = nn.ParameterDict({"a": nn.Parameter(torch.rand(2)), "b": nn.Parameter(torch.rand(2))})
18141819
with warnings.catch_warnings(record=True) as w:
18151820
m = pickle.loads(pickle.dumps(m))
1816-
self.assertTrue(len(w) == 0)
1821+
# warning from torch.load call in _load_from_bytes
1822+
self.assertTrue(len(w) == 1)
1823+
self.assertEqual(w[0].category, FutureWarning)
18171824

18181825
# Test whether loading from older checkpoints works without triggering warnings
18191826
m = nn.ParameterDict({"a": nn.Parameter(torch.rand(2)), "b": nn.Parameter(torch.rand(2))})
18201827
del m._forward_pre_hooks, m._state_dict_hooks, m._load_state_dict_pre_hooks, m._non_persistent_buffers_set
18211828
with warnings.catch_warnings(record=True) as w:
18221829
m = pickle.loads(pickle.dumps(m))
1823-
self.assertTrue(len(w) == 0)
1830+
# warning from torch.load call in _load_from_bytes
1831+
self.assertTrue(len(w) == 1)
1832+
self.assertEqual(w[0].category, FutureWarning)
18241833

18251834
def test_weight_norm_pickle(self):
18261835
m = torch.nn.utils.weight_norm(nn.Linear(5, 7))

test/test_serialization.py

Lines changed: 72 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import shutil
1616
import pathlib
1717
import platform
18-
from collections import OrderedDict
18+
from collections import namedtuple, OrderedDict
1919
from copy import deepcopy
2020
from itertools import product
2121

@@ -804,6 +804,17 @@ def wrapper(*args, **kwargs):
804804
def __exit__(self, *args, **kwargs):
805805
torch.save = self.torch_save
806806

807+
Point = namedtuple('Point', ['x', 'y'])
808+
809+
class ClassThatUsesBuildInstruction:
810+
def __init__(self, num):
811+
self.num = num
812+
813+
def __reduce_ex__(self, proto):
814+
# Third item, state here will cause pickle to push a BUILD instruction
815+
return ClassThatUsesBuildInstruction, (self.num,), {'foo': 'bar'}
816+
817+
807818
@unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows")
808819
class TestBothSerialization(TestCase):
809820
@parametrize("weights_only", (True, False))
@@ -826,7 +837,6 @@ def test(f_new, f_old):
826837
test(f_new, f_old)
827838
self.assertTrue(len(w) == 0, msg=f"Expected no warnings but got {[str(x) for x in w]}")
828839

829-
830840
class TestOldSerialization(TestCase, SerializationMixin):
831841
# unique_key is necessary because on Python 2.7, if a warning passed to
832842
# the warning module is the same, it is not raised again.
@@ -854,7 +864,8 @@ def import_module(name, filename):
854864
loaded = torch.load(checkpoint)
855865
self.assertTrue(isinstance(loaded, module.Net))
856866
if can_retrieve_source:
857-
self.assertEqual(len(w), 0)
867+
self.assertEqual(len(w), 1)
868+
self.assertEqual(w[0].category, FutureWarning)
858869

859870
# Replace the module with different source
860871
fname = get_file_path_2(os.path.dirname(os.path.dirname(torch.__file__)), 'torch', 'testing',
@@ -865,8 +876,8 @@ def import_module(name, filename):
865876
loaded = torch.load(checkpoint)
866877
self.assertTrue(isinstance(loaded, module.Net))
867878
if can_retrieve_source:
868-
self.assertEqual(len(w), 1)
869-
self.assertTrue(w[0].category, 'SourceChangeWarning')
879+
self.assertEqual(len(w), 2)
880+
self.assertTrue(w[1].category, 'SourceChangeWarning')
870881

871882
def test_serialization_container(self):
872883
self._test_serialization_container('file', tempfile.NamedTemporaryFile)
@@ -1040,8 +1051,63 @@ def __reduce__(self):
10401051
self.assertIsNone(torch.load(f, weights_only=False))
10411052
f.seek(0)
10421053
# Safe load should assert
1043-
with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported global: GLOBAL __builtin__.print"):
1054+
with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported global: GLOBAL builtins.print"):
1055+
torch.load(f, weights_only=True)
1056+
try:
1057+
torch.serialization.add_safe_globals([print])
1058+
f.seek(0)
1059+
torch.load(f, weights_only=True)
1060+
finally:
1061+
torch.serialization.clear_safe_globals()
1062+
1063+
def test_weights_only_safe_globals_newobj(self):
1064+
# This will use NEWOBJ
1065+
p = Point(x=1, y=2)
1066+
with BytesIOContext() as f:
1067+
torch.save(p, f)
1068+
f.seek(0)
1069+
with self.assertRaisesRegex(pickle.UnpicklingError,
1070+
"GLOBAL __main__.Point was not an allowed global by default"):
10441071
torch.load(f, weights_only=True)
1072+
f.seek(0)
1073+
try:
1074+
torch.serialization.add_safe_globals([Point])
1075+
loaded_p = torch.load(f, weights_only=True)
1076+
self.assertEqual(loaded_p, p)
1077+
finally:
1078+
torch.serialization.clear_safe_globals()
1079+
1080+
def test_weights_only_safe_globals_build(self):
1081+
counter = 0
1082+
1083+
def fake_set_state(obj, *args):
1084+
nonlocal counter
1085+
counter += 1
1086+
1087+
c = ClassThatUsesBuildInstruction(2)
1088+
with BytesIOContext() as f:
1089+
torch.save(c, f)
1090+
f.seek(0)
1091+
with self.assertRaisesRegex(pickle.UnpicklingError,
1092+
"GLOBAL __main__.ClassThatUsesBuildInstruction was not an allowed global by default"):
1093+
torch.load(f, weights_only=True)
1094+
try:
1095+
torch.serialization.add_safe_globals([ClassThatUsesBuildInstruction])
1096+
# Test dict update path
1097+
f.seek(0)
1098+
loaded_c = torch.load(f, weights_only=True)
1099+
self.assertEqual(loaded_c.num, 2)
1100+
self.assertEqual(loaded_c.foo, 'bar')
1101+
# Test setstate path
1102+
ClassThatUsesBuildInstruction.__setstate__ = fake_set_state
1103+
f.seek(0)
1104+
loaded_c = torch.load(f, weights_only=True)
1105+
self.assertEqual(loaded_c.num, 2)
1106+
self.assertEqual(counter, 1)
1107+
self.assertFalse(hasattr(loaded_c, 'foo'))
1108+
finally:
1109+
torch.serialization.clear_safe_globals()
1110+
ClassThatUsesBuildInstruction.__setstate__ = None
10451111

10461112
@parametrize('weights_only', (False, True))
10471113
def test_serialization_math_bits(self, weights_only):

torch/_weights_only_unpickler.py

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
# weights = torch.load(buf, weights_only = True)
2424

2525
import functools as _functools
26+
import warnings
2627
from collections import Counter, OrderedDict
2728
from pickle import (
2829
APPEND,
@@ -67,6 +68,16 @@
6768
from sys import maxsize
6869
from typing import Any, Dict, List
6970

71+
try:
72+
# We rely on this module in private cPython which provides dicts of
73+
# modules/functions that had their names changed from Python 2 to 3
74+
has_compat_pickle = True
75+
from _compat_pickle import IMPORT_MAPPING, NAME_MAPPING
76+
except ImportError:
77+
# To prevent warning on import torch, we warn in the Unpickler.load below
78+
has_compat_pickle = False
79+
IMPORT_MAPPING, NAME_MAPPING = dict(), dict()
80+
7081
import torch
7182

7283
_marked_safe_globals_list: List[Any] = []
@@ -97,7 +108,8 @@ def _clear_safe_globals():
97108
def _get_user_allowed_globals():
98109
rc: Dict[str, Any] = {}
99110
for f in _marked_safe_globals_list:
100-
rc[f"{f.__module__}.{f.__name__}"] = f
111+
module, name = f.__module__, f.__name__
112+
rc[f"{module}.{name}"] = f
101113
return rc
102114

103115

@@ -170,12 +182,20 @@ def __init__(self, file, *, encoding: str = "bytes"):
170182
self.readline = file.readline
171183
self.read = file.read
172184
self.memo: Dict[int, Any] = {}
185+
self.proto: int = -1
173186

174187
def load(self):
175188
"""Read a pickled object representation from the open file.
176189
177190
Return the reconstituted object hierarchy specified in the file.
178191
"""
192+
if not has_compat_pickle:
193+
warnings.warn(
194+
"Could not import IMPORT_MAPPING and NAME_MAPPING from _compat_pickle. "
195+
"If the default `pickle_protocol` was used at `torch.save` time, any functions or "
196+
"classes that are in these maps might not behave correctly if allowlisted via "
197+
"`torch.serialization.add_safe_globals()`."
198+
)
179199
self.metastack = []
180200
self.stack: List[Any] = []
181201
self.append = self.stack.append
@@ -190,6 +210,13 @@ def load(self):
190210
if key[0] == GLOBAL[0]:
191211
module = readline()[:-1].decode("utf-8")
192212
name = readline()[:-1].decode("utf-8")
213+
# Patch since torch.save default protocol is 2
214+
# users will be running this code in python > 3
215+
if self.proto == 2 and has_compat_pickle:
216+
if (module, name) in NAME_MAPPING:
217+
module, name = NAME_MAPPING[(module, name)]
218+
elif module in IMPORT_MAPPING:
219+
module = IMPORT_MAPPING[module]
193220
full_path = f"{module}.{name}"
194221
if full_path in _get_allowed_globals():
195222
self.append(_get_allowed_globals()[full_path])
@@ -204,9 +231,12 @@ def load(self):
204231
elif key[0] == NEWOBJ[0]:
205232
args = self.stack.pop()
206233
cls = self.stack.pop()
207-
if cls is not torch.nn.Parameter:
234+
if cls is torch.nn.Parameter:
235+
self.append(torch.nn.Parameter(*args))
236+
elif cls in _get_user_allowed_globals().values():
237+
self.append(cls.__new__(cls, *args))
238+
else:
208239
raise RuntimeError(f"Trying to instantiate unsupported class {cls}")
209-
self.append(torch.nn.Parameter(*args))
210240
elif key[0] == REDUCE[0]:
211241
args = self.stack.pop()
212242
func = self.stack[-1]
@@ -228,9 +258,14 @@ def load(self):
228258
inst.__setstate__(state)
229259
elif type(inst) is OrderedDict:
230260
inst.__dict__.update(state)
261+
elif type(inst) in _get_user_allowed_globals().values():
262+
if hasattr(inst, "__setstate__"):
263+
inst.__setstate__(state)
264+
else:
265+
inst.__dict__.update(state)
231266
else:
232267
raise RuntimeError(
233-
f"Can only build Tensor, parameter or dict objects, but got {type(inst)}"
268+
f"Can only build Tensor, parameter or OrderedDict objects, but got {type(inst)}"
234269
)
235270
# Stack manipulation
236271
elif key[0] == APPEND[0]:
@@ -334,8 +369,14 @@ def load(self):
334369
self.append(decode_long(data))
335370
# First and last deserializer ops
336371
elif key[0] == PROTO[0]:
337-
# Read and ignore proto version
338-
read(1)[0]
372+
self.proto = read(1)[0]
373+
if self.proto != 2:
374+
warnings.warn(
375+
f"Detected pickle protocol {self.proto} in the checkpoint, which was "
376+
"not the default pickle protocol used by `torch.load` (2). The weights_only "
377+
"Unpickler might not support all instructions implemented by this protocol, "
378+
"please file an issue for adding support if you encounter this."
379+
)
339380
elif key[0] == STOP[0]:
340381
rc = self.stack.pop()
341382
return rc

torch/serialization.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,12 +165,30 @@ def get_safe_globals() -> List[Any]:
165165
return _weights_only_unpickler._get_safe_globals()
166166

167167
def add_safe_globals(safe_globals: List[Any]) -> None:
168-
'''
169-
Marks the given globals as safe for ``weights_only`` load.
168+
"""
169+
Marks the given globals as safe for ``weights_only`` load. For example, functions
170+
added to this list can be called during unpickling, classes could be instantiated
171+
and have state set.
170172
171173
Args:
172174
safe_globals (List[Any]): list of globals to mark as safe
173-
'''
175+
176+
Example:
177+
>>> # xdoctest: +SKIP("Can't torch.save(t, ...) as doctest thinks MyTensor is defined on torch.serialization")
178+
>>> import tempfile
179+
>>> class MyTensor(torch.Tensor):
180+
... pass
181+
>>> t = MyTensor(torch.randn(2, 3))
182+
>>> with tempfile.NamedTemporaryFile() as f:
183+
... torch.save(t, f.name)
184+
# Running `torch.load(f.name, weights_only=True)` will fail with
185+
# Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
186+
# Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
187+
... torch.serialization.add_safe_globals([MyTensor])
188+
... torch.load(f.name, weights_only=True)
189+
# MyTensor([[-0.5024, -1.8152, -0.5455],
190+
# [-0.8234, 2.0500, -0.3657]])
191+
"""
174192
_weights_only_unpickler._add_safe_globals(safe_globals)
175193

176194
def _is_zipfile(f) -> bool:
@@ -872,7 +890,7 @@ def load(
872890
map_location: MAP_LOCATION = None,
873891
pickle_module: Any = None,
874892
*,
875-
weights_only: bool = False,
893+
weights_only: Optional[bool] = None,
876894
mmap: Optional[bool] = None,
877895
**pickle_load_args: Any
878896
) -> Any:
@@ -982,6 +1000,11 @@ def load(
9821000
" with `weights_only` please check the recommended steps in the following error message."
9831001
" WeightsUnpickler error: "
9841002
)
1003+
if weights_only is None:
1004+
weights_only, warn_weights_only = False, True
1005+
else:
1006+
warn_weights_only = False
1007+
9851008
# Add ability to force safe only weight loads via environment variable
9861009
if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in ['1', 'y', 'yes', 'true']:
9871010
weights_only = True
@@ -991,6 +1014,20 @@ def load(
9911014
raise RuntimeError("Can not safely load weights when explicit pickle_module is specified")
9921015
else:
9931016
if pickle_module is None:
1017+
if warn_weights_only:
1018+
warnings.warn(
1019+
"You are using `torch.load` with `weights_only=False` (the current default value), which uses "
1020+
"the default pickle module implicitly. It is possible to construct malicious pickle data "
1021+
"which will execute arbitrary code during unpickling (See "
1022+
"https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). "
1023+
"In a future release, the default value for `weights_only` will be flipped to `True`. This "
1024+
"limits the functions that could be executed during unpickling. Arbitrary objects will no "
1025+
"longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the "
1026+
"user via `torch.serialization.add_safe_globals`. We recommend you start setting "
1027+
"`weights_only=True` for any use case where you don't have full control of the loaded file. "
1028+
"Please open an issue on GitHub for any issues related to this experimental feature.",
1029+
FutureWarning,
1030+
)
9941031
pickle_module = pickle
9951032

9961033
# make flipping default BC-compatible

0 commit comments

Comments
 (0)