@@ -21,12 +21,13 @@ def finish(model):
2121def run (input_arrs , daq , dnn_retrieve_result , quant_input = False , quant_output = False ):
2222 input_txts = []
2323 for i , input_arr in enumerate (input_arrs ):
24- nchw_shape = input_arr .shape
25- nhwc_shape = (nchw_shape [0 ], nchw_shape [2 ], nchw_shape [3 ], nchw_shape [1 ])
26- nhwc_input = np .moveaxis (input_arr , 1 , - 1 )
27- assert nhwc_input .shape == nhwc_shape
24+ if len (input_arr .shape ) == 4 :
25+ nchw_shape = input_arr .shape
26+ nhwc_shape = (nchw_shape [0 ], nchw_shape [2 ], nchw_shape [3 ], nchw_shape [1 ])
27+ input_arr = np .moveaxis (input_arr , 1 , - 1 )
28+ assert input_arr .shape == nhwc_shape
2829 input_txt = 'input{}.txt' .format (i )
29- np .savetxt (input_txt , nhwc_input .flatten (), delimiter = '\n ' )
30+ np .savetxt (input_txt , input_arr .flatten (), delimiter = '\n ' )
3031 input_txts .append (input_txt )
3132 input_txts_arg = " " .join (input_txts )
3233 input_txts_in_android_arg = " " .join (map (lambda x : "/data/local/tmp/" + x , input_txts ))
0 commit comments