1515import shutil
1616import pathlib
1717import platform
18- from collections import OrderedDict
18+ from collections import namedtuple , OrderedDict
1919from copy import deepcopy
2020from itertools import product
2121
@@ -804,6 +804,17 @@ def wrapper(*args, **kwargs):
804804 def __exit__ (self , * args , ** kwargs ):
805805 torch .save = self .torch_save
806806
807+ Point = namedtuple ('Point' , ['x' , 'y' ])
808+
809+ class ClassThatUsesBuildInstruction :
810+ def __init__ (self , num ):
811+ self .num = num
812+
813+ def __reduce_ex__ (self , proto ):
814+ # Third item, state here will cause pickle to push a BUILD instruction
815+ return ClassThatUsesBuildInstruction , (self .num ,), {'foo' : 'bar' }
816+
817+
807818@unittest .skipIf (IS_WINDOWS , "NamedTemporaryFile on windows" )
808819class TestBothSerialization (TestCase ):
809820 @parametrize ("weights_only" , (True , False ))
@@ -826,7 +837,6 @@ def test(f_new, f_old):
826837 test (f_new , f_old )
827838 self .assertTrue (len (w ) == 0 , msg = f"Expected no warnings but got { [str (x ) for x in w ]} " )
828839
829-
830840class TestOldSerialization (TestCase , SerializationMixin ):
831841 # unique_key is necessary because on Python 2.7, if a warning passed to
832842 # the warning module is the same, it is not raised again.
@@ -854,7 +864,8 @@ def import_module(name, filename):
854864 loaded = torch .load (checkpoint )
855865 self .assertTrue (isinstance (loaded , module .Net ))
856866 if can_retrieve_source :
857- self .assertEqual (len (w ), 0 )
867+ self .assertEqual (len (w ), 1 )
868+ self .assertEqual (w [0 ].category , FutureWarning )
858869
859870 # Replace the module with different source
860871 fname = get_file_path_2 (os .path .dirname (os .path .dirname (torch .__file__ )), 'torch' , 'testing' ,
@@ -865,8 +876,8 @@ def import_module(name, filename):
865876 loaded = torch .load (checkpoint )
866877 self .assertTrue (isinstance (loaded , module .Net ))
867878 if can_retrieve_source :
868- self .assertEqual (len (w ), 1 )
869- self .assertTrue (w [0 ].category , 'SourceChangeWarning' )
879+ self .assertEqual (len (w ), 2 )
880+ self .assertTrue (w [1 ].category , 'SourceChangeWarning' )
870881
871882 def test_serialization_container (self ):
872883 self ._test_serialization_container ('file' , tempfile .NamedTemporaryFile )
@@ -1040,8 +1051,63 @@ def __reduce__(self):
10401051 self .assertIsNone (torch .load (f , weights_only = False ))
10411052 f .seek (0 )
10421053 # Safe load should assert
1043- with self .assertRaisesRegex (pickle .UnpicklingError , "Unsupported global: GLOBAL __builtin__.print" ):
1054+ with self .assertRaisesRegex (pickle .UnpicklingError , "Unsupported global: GLOBAL builtins.print" ):
1055+ torch .load (f , weights_only = True )
1056+ try :
1057+ torch .serialization .add_safe_globals ([print ])
1058+ f .seek (0 )
1059+ torch .load (f , weights_only = True )
1060+ finally :
1061+ torch .serialization .clear_safe_globals ()
1062+
1063+ def test_weights_only_safe_globals_newobj (self ):
1064+ # This will use NEWOBJ
1065+ p = Point (x = 1 , y = 2 )
1066+ with BytesIOContext () as f :
1067+ torch .save (p , f )
1068+ f .seek (0 )
1069+ with self .assertRaisesRegex (pickle .UnpicklingError ,
1070+ "GLOBAL __main__.Point was not an allowed global by default" ):
10441071 torch .load (f , weights_only = True )
1072+ f .seek (0 )
1073+ try :
1074+ torch .serialization .add_safe_globals ([Point ])
1075+ loaded_p = torch .load (f , weights_only = True )
1076+ self .assertEqual (loaded_p , p )
1077+ finally :
1078+ torch .serialization .clear_safe_globals ()
1079+
1080+ def test_weights_only_safe_globals_build (self ):
1081+ counter = 0
1082+
1083+ def fake_set_state (obj , * args ):
1084+ nonlocal counter
1085+ counter += 1
1086+
1087+ c = ClassThatUsesBuildInstruction (2 )
1088+ with BytesIOContext () as f :
1089+ torch .save (c , f )
1090+ f .seek (0 )
1091+ with self .assertRaisesRegex (pickle .UnpicklingError ,
1092+ "GLOBAL __main__.ClassThatUsesBuildInstruction was not an allowed global by default" ):
1093+ torch .load (f , weights_only = True )
1094+ try :
1095+ torch .serialization .add_safe_globals ([ClassThatUsesBuildInstruction ])
1096+ # Test dict update path
1097+ f .seek (0 )
1098+ loaded_c = torch .load (f , weights_only = True )
1099+ self .assertEqual (loaded_c .num , 2 )
1100+ self .assertEqual (loaded_c .foo , 'bar' )
1101+ # Test setstate path
1102+ ClassThatUsesBuildInstruction .__setstate__ = fake_set_state
1103+ f .seek (0 )
1104+ loaded_c = torch .load (f , weights_only = True )
1105+ self .assertEqual (loaded_c .num , 2 )
1106+ self .assertEqual (counter , 1 )
1107+ self .assertFalse (hasattr (loaded_c , 'foo' ))
1108+ finally :
1109+ torch .serialization .clear_safe_globals ()
1110+ ClassThatUsesBuildInstruction .__setstate__ = None
10451111
10461112 @parametrize ('weights_only' , (False , True ))
10471113 def test_serialization_math_bits (self , weights_only ):
0 commit comments