1515import shutil
1616import pathlib
1717import platform
18- from collections import namedtuple , OrderedDict
18+ from collections import OrderedDict
1919from copy import deepcopy
2020from itertools import product
2121
@@ -804,17 +804,6 @@ def wrapper(*args, **kwargs):
804804 def __exit__ (self , * args , ** kwargs ):
805805 torch .save = self .torch_save
806806
807- Point = namedtuple ('Point' , ['x' , 'y' ])
808-
809- class ClassThatUsesBuildInstruction :
810- def __init__ (self , num ):
811- self .num = num
812-
813- def __reduce_ex__ (self , proto ):
814- # Third item, state here will cause pickle to push a BUILD instruction
815- return ClassThatUsesBuildInstruction , (self .num ,), {'foo' : 'bar' }
816-
817-
818807@unittest .skipIf (IS_WINDOWS , "NamedTemporaryFile on windows" )
819808class TestBothSerialization (TestCase ):
820809 @parametrize ("weights_only" , (True , False ))
@@ -837,6 +826,7 @@ def test(f_new, f_old):
837826 test (f_new , f_old )
838827 self .assertTrue (len (w ) == 0 , msg = f"Expected no warnings but got { [str (x ) for x in w ]} " )
839828
829+
840830class TestOldSerialization (TestCase , SerializationMixin ):
841831 # unique_key is necessary because on Python 2.7, if a warning passed to
842832 # the warning module is the same, it is not raised again.
@@ -864,8 +854,7 @@ def import_module(name, filename):
864854 loaded = torch .load (checkpoint )
865855 self .assertTrue (isinstance (loaded , module .Net ))
866856 if can_retrieve_source :
867- self .assertEqual (len (w ), 1 )
868- self .assertEqual (w [0 ].category , FutureWarning )
857+ self .assertEqual (len (w ), 0 )
869858
870859 # Replace the module with different source
871860 fname = get_file_path_2 (os .path .dirname (os .path .dirname (torch .__file__ )), 'torch' , 'testing' ,
@@ -876,8 +865,8 @@ def import_module(name, filename):
876865 loaded = torch .load (checkpoint )
877866 self .assertTrue (isinstance (loaded , module .Net ))
878867 if can_retrieve_source :
879- self .assertEqual (len (w ), 2 )
880- self .assertTrue (w [1 ].category , 'SourceChangeWarning' )
868+ self .assertEqual (len (w ), 1 )
869+ self .assertTrue (w [0 ].category , 'SourceChangeWarning' )
881870
882871 def test_serialization_container (self ):
883872 self ._test_serialization_container ('file' , tempfile .NamedTemporaryFile )
@@ -1051,63 +1040,8 @@ def __reduce__(self):
10511040 self .assertIsNone (torch .load (f , weights_only = False ))
10521041 f .seek (0 )
10531042 # Safe load should assert
1054- with self .assertRaisesRegex (pickle .UnpicklingError , "Unsupported global: GLOBAL builtins.print" ):
1055- torch .load (f , weights_only = True )
1056- try :
1057- torch .serialization .add_safe_globals ([print ])
1058- f .seek (0 )
1059- torch .load (f , weights_only = True )
1060- finally :
1061- torch .serialization .clear_safe_globals ()
1062-
1063- def test_weights_only_safe_globals_newobj (self ):
1064- # This will use NEWOBJ
1065- p = Point (x = 1 , y = 2 )
1066- with BytesIOContext () as f :
1067- torch .save (p , f )
1068- f .seek (0 )
1069- with self .assertRaisesRegex (pickle .UnpicklingError ,
1070- "GLOBAL __main__.Point was not an allowed global by default" ):
1043+ with self .assertRaisesRegex (pickle .UnpicklingError , "Unsupported global: GLOBAL __builtin__.print" ):
10711044 torch .load (f , weights_only = True )
1072- f .seek (0 )
1073- try :
1074- torch .serialization .add_safe_globals ([Point ])
1075- loaded_p = torch .load (f , weights_only = True )
1076- self .assertEqual (loaded_p , p )
1077- finally :
1078- torch .serialization .clear_safe_globals ()
1079-
1080- def test_weights_only_safe_globals_build (self ):
1081- counter = 0
1082-
1083- def fake_set_state (obj , * args ):
1084- nonlocal counter
1085- counter += 1
1086-
1087- c = ClassThatUsesBuildInstruction (2 )
1088- with BytesIOContext () as f :
1089- torch .save (c , f )
1090- f .seek (0 )
1091- with self .assertRaisesRegex (pickle .UnpicklingError ,
1092- "GLOBAL __main__.ClassThatUsesBuildInstruction was not an allowed global by default" ):
1093- torch .load (f , weights_only = True )
1094- try :
1095- torch .serialization .add_safe_globals ([ClassThatUsesBuildInstruction ])
1096- # Test dict update path
1097- f .seek (0 )
1098- loaded_c = torch .load (f , weights_only = True )
1099- self .assertEqual (loaded_c .num , 2 )
1100- self .assertEqual (loaded_c .foo , 'bar' )
1101- # Test setstate path
1102- ClassThatUsesBuildInstruction .__setstate__ = fake_set_state
1103- f .seek (0 )
1104- loaded_c = torch .load (f , weights_only = True )
1105- self .assertEqual (loaded_c .num , 2 )
1106- self .assertEqual (counter , 1 )
1107- self .assertFalse (hasattr (loaded_c , 'foo' ))
1108- finally :
1109- torch .serialization .clear_safe_globals ()
1110- ClassThatUsesBuildInstruction .__setstate__ = None
11111045
11121046 @parametrize ('weights_only' , (False , True ))
11131047 def test_serialization_math_bits (self , weights_only ):
0 commit comments