@@ -163,8 +163,8 @@ def tests(self): # type: () -> Type[unittest.TestCase]
163163 setattr (tests , name , item .func )
164164 return tests
165165
166- @staticmethod
167- def _assert_similar_outputs ( ref_outputs , outputs , rtol , atol ): # type: (Sequence[Any], Sequence[Any], float, float) -> None
166+ @classmethod
167+ def assert_similar_outputs ( cls , ref_outputs , outputs , rtol , atol ): # type: (Sequence[Any], Sequence[Any], float, float) -> None
168168 np .testing .assert_equal (len (ref_outputs ), len (outputs ))
169169 for i in range (len (outputs )):
170170 np .testing .assert_equal (ref_outputs [i ].dtype , outputs [i ].dtype )
@@ -174,9 +174,9 @@ def _assert_similar_outputs(ref_outputs, outputs, rtol, atol): # type: (Sequenc
174174 rtol = rtol ,
175175 atol = atol )
176176
177- @staticmethod
177+ @classmethod
178178 @retry_excute (3 )
179- def _download_model ( model_test , model_dir , models_dir ): # type: (TestCase, Text, Text) -> None
179+ def download_model ( cls , model_test , model_dir , models_dir ): # type: (TestCase, Text, Text) -> None
180180 # On Windows, NamedTemporaryFile can not be opened for a
181181 # second time
182182 download_file = tempfile .NamedTemporaryFile (delete = False )
@@ -196,8 +196,8 @@ def _download_model(model_test, model_dir, models_dir): # type: (TestCase, Text
196196 finally :
197197 os .remove (download_file .name )
198198
199- @staticmethod
200- def _prepare_model_data ( model_test ): # type: (TestCase) -> Text
199+ @classmethod
200+ def prepare_model_data ( cls , model_test ): # type: (TestCase) -> Text
201201 onnx_home = os .path .expanduser (os .getenv ('ONNX_HOME' , os .path .join ('~' , '.onnx' )))
202202 models_dir = os .getenv ('ONNX_MODELS' ,
203203 os .path .join (onnx_home , 'models' ))
@@ -214,7 +214,7 @@ def _prepare_model_data(model_test): # type: (TestCase) -> Text
214214 break
215215 os .makedirs (model_dir )
216216
217- Runner . _download_model (model_test = model_test , model_dir = model_dir , models_dir = models_dir )
217+ cls . download_model (model_test = model_test , model_dir = model_dir , models_dir = models_dir )
218218 return model_dir
219219
220220 def _add_test (self ,
@@ -262,7 +262,7 @@ def _add_model_test(self, model_test, kind): # type: (TestCase, Text) -> None
262262
263263 def run (test_self , device ): # type: (Any, Text) -> None
264264 if model_test .model_dir is None :
265- model_dir = Runner . _prepare_model_data (model_test )
265+ model_dir = self . prepare_model_data (model_test )
266266 else :
267267 model_dir = model_test .model_dir
268268 model_pb_path = os .path .join (model_dir , 'model.onnx' )
@@ -282,9 +282,9 @@ def run(test_self, device): # type: (Any, Text) -> None
282282 inputs = list (test_data ['inputs' ])
283283 outputs = list (prepared_model .run (inputs ))
284284 ref_outputs = test_data ['outputs' ]
285- self ._assert_similar_outputs (ref_outputs , outputs ,
286- rtol = model_test .rtol ,
287- atol = model_test .atol )
285+ self .assert_similar_outputs (ref_outputs , outputs ,
286+ rtol = model_test .rtol ,
287+ atol = model_test .atol )
288288
289289 for test_data_dir in glob .glob (
290290 os .path .join (model_dir , "test_data_set*" )):
@@ -305,8 +305,8 @@ def run(test_self, device): # type: (Any, Text) -> None
305305 tensor .ParseFromString (f .read ())
306306 ref_outputs .append (numpy_helper .to_array (tensor ))
307307 outputs = list (prepared_model .run (inputs ))
308- self ._assert_similar_outputs (ref_outputs , outputs ,
309- rtol = model_test .rtol ,
310- atol = model_test .atol )
308+ self .assert_similar_outputs (ref_outputs , outputs ,
309+ rtol = model_test .rtol ,
310+ atol = model_test .atol )
311311
312312 self ._add_test (kind + 'Model' , model_test .name , run , model_marker )
0 commit comments