Skip to content

Commit 8f05911

Browse files
committed
Fix bug in validate_onnx.py -- support non 4d input
1 parent f2fff56 commit 8f05911

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

validate_onnx.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,13 @@ def finish(model):
2121
def run(input_arrs, daq, dnn_retrieve_result, quant_input=False, quant_output=False):
2222
input_txts = []
2323
for i, input_arr in enumerate(input_arrs):
24-
nchw_shape = input_arr.shape
25-
nhwc_shape = (nchw_shape[0], nchw_shape[2], nchw_shape[3], nchw_shape[1])
26-
nhwc_input = np.moveaxis(input_arr, 1, -1)
27-
assert nhwc_input.shape == nhwc_shape
24+
if len(input_arr.shape) == 4:
25+
nchw_shape = input_arr.shape
26+
nhwc_shape = (nchw_shape[0], nchw_shape[2], nchw_shape[3], nchw_shape[1])
27+
input_arr = np.moveaxis(input_arr, 1, -1)
28+
assert input_arr.shape == nhwc_shape
2829
input_txt = 'input{}.txt'.format(i)
29-
np.savetxt(input_txt, nhwc_input.flatten(), delimiter='\n')
30+
np.savetxt(input_txt, input_arr.flatten(), delimiter='\n')
3031
input_txts.append(input_txt)
3132
input_txts_arg = " ".join(input_txts)
3233
input_txts_in_android_arg = " ".join(map(lambda x: "/data/local/tmp/" + x, input_txts))

0 commit comments

Comments
 (0)