Skip to content

Commit 873ddbb

Browse files
postrationalhouseroad
authored andcommitted
More extendable Runner (#1809)
* Make Runner class more flexible * Change @staticmethods to @classmethods in test Runner * Rename assert_similar_outputs and download_model functions
1 parent e18bb41 commit 873ddbb

File tree

2 files changed

+16
-16
lines changed

2 files changed

+16
-16
lines changed

onnx/backend/test/runner/__init__.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,8 @@ def tests(self): # type: () -> Type[unittest.TestCase]
163163
setattr(tests, name, item.func)
164164
return tests
165165

166-
@staticmethod
167-
def _assert_similar_outputs(ref_outputs, outputs, rtol, atol): # type: (Sequence[Any], Sequence[Any], float, float) -> None
166+
@classmethod
167+
def assert_similar_outputs(cls, ref_outputs, outputs, rtol, atol): # type: (Sequence[Any], Sequence[Any], float, float) -> None
168168
np.testing.assert_equal(len(ref_outputs), len(outputs))
169169
for i in range(len(outputs)):
170170
np.testing.assert_equal(ref_outputs[i].dtype, outputs[i].dtype)
@@ -174,9 +174,9 @@ def _assert_similar_outputs(ref_outputs, outputs, rtol, atol): # type: (Sequenc
174174
rtol=rtol,
175175
atol=atol)
176176

177-
@staticmethod
177+
@classmethod
178178
@retry_excute(3)
179-
def _download_model(model_test, model_dir, models_dir): # type: (TestCase, Text, Text) -> None
179+
def download_model(cls, model_test, model_dir, models_dir): # type: (TestCase, Text, Text) -> None
180180
# On Windows, NamedTemporaryFile can not be opened for a
181181
# second time
182182
download_file = tempfile.NamedTemporaryFile(delete=False)
@@ -196,8 +196,8 @@ def _download_model(model_test, model_dir, models_dir): # type: (TestCase, Text
196196
finally:
197197
os.remove(download_file.name)
198198

199-
@staticmethod
200-
def _prepare_model_data(model_test): # type: (TestCase) -> Text
199+
@classmethod
200+
def prepare_model_data(cls, model_test): # type: (TestCase) -> Text
201201
onnx_home = os.path.expanduser(os.getenv('ONNX_HOME', os.path.join('~', '.onnx')))
202202
models_dir = os.getenv('ONNX_MODELS',
203203
os.path.join(onnx_home, 'models'))
@@ -214,7 +214,7 @@ def _prepare_model_data(model_test): # type: (TestCase) -> Text
214214
break
215215
os.makedirs(model_dir)
216216

217-
Runner._download_model(model_test=model_test, model_dir=model_dir, models_dir=models_dir)
217+
cls.download_model(model_test=model_test, model_dir=model_dir, models_dir=models_dir)
218218
return model_dir
219219

220220
def _add_test(self,
@@ -262,7 +262,7 @@ def _add_model_test(self, model_test, kind): # type: (TestCase, Text) -> None
262262

263263
def run(test_self, device): # type: (Any, Text) -> None
264264
if model_test.model_dir is None:
265-
model_dir = Runner._prepare_model_data(model_test)
265+
model_dir = self.prepare_model_data(model_test)
266266
else:
267267
model_dir = model_test.model_dir
268268
model_pb_path = os.path.join(model_dir, 'model.onnx')
@@ -282,9 +282,9 @@ def run(test_self, device): # type: (Any, Text) -> None
282282
inputs = list(test_data['inputs'])
283283
outputs = list(prepared_model.run(inputs))
284284
ref_outputs = test_data['outputs']
285-
self._assert_similar_outputs(ref_outputs, outputs,
286-
rtol=model_test.rtol,
287-
atol=model_test.atol)
285+
self.assert_similar_outputs(ref_outputs, outputs,
286+
rtol=model_test.rtol,
287+
atol=model_test.atol)
288288

289289
for test_data_dir in glob.glob(
290290
os.path.join(model_dir, "test_data_set*")):
@@ -305,8 +305,8 @@ def run(test_self, device): # type: (Any, Text) -> None
305305
tensor.ParseFromString(f.read())
306306
ref_outputs.append(numpy_helper.to_array(tensor))
307307
outputs = list(prepared_model.run(inputs))
308-
self._assert_similar_outputs(ref_outputs, outputs,
309-
rtol=model_test.rtol,
310-
atol=model_test.atol)
308+
self.assert_similar_outputs(ref_outputs, outputs,
309+
rtol=model_test.rtol,
310+
atol=model_test.atol)
311311

312312
self._add_test(kind + 'Model', model_test.name, run, model_marker)

onnx/backend/test/stat_coverage.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,12 +129,12 @@ def gen_model_test_coverage(schemas, f, ml):
129129
schema_dict = dict()
130130
for schema in schemas:
131131
schema_dict[schema.name] = schema
132-
# Load models from each model test using Runner._prepare_model_data
132+
# Load models from each model test using Runner.prepare_model_data
133133
# Need to grab associated nodes
134134
attrs = dict() # type: Dict[Text, Dict[Text, List[Any]]]
135135
model_paths = [] # type: List[Any]
136136
for rt in load_model_tests(kind='real'):
137-
model_dir = Runner._prepare_model_data(rt)
137+
model_dir = Runner.prepare_model_data(rt)
138138
model_paths.append(os.path.join(model_dir, 'model.onnx'))
139139
model_paths.sort()
140140
model_written = False

0 commit comments

Comments
 (0)