@@ -98,14 +98,14 @@ def infer_cfg(cfg, target: Target):
9898 op ['input' ] = []
9999 if 'base_input_num' not in op or op ['base_input_num' ] == 1 :
100100 op ['input' ].insert (0 ,
101- {'name' : 'input' , 'nnapi_type' : 'tensor' , 'cpp_type' : 'str' , 'needed_by_shaper' : True })
101+ {'name' : 'input' , 'nnapi_type' : 'tensor' , 'cpp_type' : 'str' , 'input' : True , ' needed_by_shaper' : True })
102102 elif op ['base_input_num' ] == 2 :
103- op ['input' ] = [{'name' : 'input1' , 'nnapi_type' : 'tensor' , 'cpp_type' : 'str' , 'needed_by_shaper' : True },
104- {'name' : 'input2' , 'nnapi_type' : 'tensor' , 'cpp_type' : 'str' , 'needed_by_shaper' : True }] \
103+ op ['input' ] = [{'name' : 'input1' , 'nnapi_type' : 'tensor' , 'cpp_type' : 'str' , 'input' : True , ' needed_by_shaper' : True },
104+ {'name' : 'input2' , 'nnapi_type' : 'tensor' , 'cpp_type' : 'str' , 'input' : True , ' needed_by_shaper' : True }] \
105105 + op ['input' ]
106106 elif op ['base_input_num' ] == 'n' :
107107 op ['input' ].insert (0 ,
108- {'name' : 'inputs' , 'nnapi_type' : 'tensor' , 'cpp_type' : 'str_list' ,
108+ {'name' : 'inputs' , 'nnapi_type' : 'tensor' , 'cpp_type' : 'str_list' , 'input' : True ,
109109 'needed_by_shaper' : True })
110110 elif op ['base_input_num' ] == 0 :
111111 pass
@@ -145,11 +145,12 @@ def infer_cfg(cfg, target: Target):
145145 ipt ['name' ] = 'bias'
146146 ipt ['nnapi_type' ] = 'tensor'
147147 ipt ['cpp_type' ] = 'optional_str'
148- ipt ['learnable' ] = True
149- if 'learnable' not in ipt :
150- ipt ['learnable' ] = False
151- if ipt ['learnable' ] and 'convert_func' not in ipt :
148+ ipt ['input' ] = True
152149 ipt ['convert_func' ] = 'OnnxToNnapiIdentity'
150+ if 'input' not in ipt :
151+ ipt ['input' ] = False
152+ if 'convert_func' not in ipt :
153+ ipt ['convert_func' ] = 'OnnxToNnapiAxes0231'
153154 if 'needed_by_shaper' not in ipt :
154155 ipt ['needed_by_shaper' ] = False
155156
@@ -186,20 +187,28 @@ def generate_onnx_converter():
186187 if op ['fused' ]:
187188 cogoutl (f"const auto activation = FindActivation(model_proto_, output);" )
188189 for x in op ['input' ]:
189- if x ['learnable' ]:
190- assert x ['cpp_type' ] in ['str' , 'optional_str' ]
190+ if x ['input' ]:
191191 if x ['cpp_type' ] == 'str' :
192- cogoutl (f"""{{
193- const auto name = { x ['name' ]} ;""" )
192+ cogoutl (f"""
193+ {{
194+ const auto name = { x ['name' ]} ;""" )
194195 elif x ['cpp_type' ] == 'optional_str' :
195- cogoutl (f"""if ({ x ['name' ]} .has_value()) {{
196- const auto name = { x ['name' ]} .value();""" )
197- cogoutl (f"""const auto &onnx_tensor = onnx_tensors_.at(name);
198- const auto new_tensor = { x ['convert_func' ]} (onnx_tensor);
199- shaper_.AddShape(name, new_tensor.shape);
200- nnapi_tensors_[name] = new_tensor;
201- CreateTensorFb(name, new_tensor);""" )
202- cogoutl ("}" )
196+ cogoutl (f"""
197+ if ({ x ['name' ]} .has_value()) {{
198+ const auto name = { x ['name' ]} .value();""" )
199+ elif x ['cpp_type' ] == 'str_list' :
200+ cogoutl (f"""
201+ for (const auto &name : { x ['name' ]} ) {{""" )
202+ cogoutl (f"""
203+ if (onnx_tensors_.has(name)) {{
204+ const auto &onnx_tensor = onnx_tensors_.at(name);
205+ const auto new_tensor = { x ['convert_func' ]} (onnx_tensor);
206+ shaper_.AddShape(name, new_tensor.shape);
207+ nnapi_tensors_[name] = new_tensor;
208+ CreateTensorFb(name, new_tensor);
209+ }}
210+ }}
211+ """ )
203212 if x ['cpp_type' ] == 'str_list' :
204213 cogoutl (f"const auto { x ['name' ]} _fb = FbStrVector({ x ['name' ]} );" )
205214
0 commit comments