Skip to content

Commit 38dd3bd

Browse files
committed
Support onnx input (which can be read from both output and initializer)
1 parent 8cab557 commit 38dd3bd

5 files changed

Lines changed: 333 additions & 63 deletions

File tree

generate_code.py

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -98,14 +98,14 @@ def infer_cfg(cfg, target: Target):
9898
op['input'] = []
9999
if 'base_input_num' not in op or op['base_input_num'] == 1:
100100
op['input'].insert(0,
101-
{'name': 'input', 'nnapi_type': 'tensor', 'cpp_type': 'str', 'needed_by_shaper': True})
101+
{'name': 'input', 'nnapi_type': 'tensor', 'cpp_type': 'str', 'input': True, 'needed_by_shaper': True})
102102
elif op['base_input_num'] == 2:
103-
op['input'] = [{'name': 'input1', 'nnapi_type': 'tensor', 'cpp_type': 'str', 'needed_by_shaper': True},
104-
{'name': 'input2', 'nnapi_type': 'tensor', 'cpp_type': 'str', 'needed_by_shaper': True}] \
103+
op['input'] = [{'name': 'input1', 'nnapi_type': 'tensor', 'cpp_type': 'str', 'input': True, 'needed_by_shaper': True},
104+
{'name': 'input2', 'nnapi_type': 'tensor', 'cpp_type': 'str', 'input': True, 'needed_by_shaper': True}] \
105105
+ op['input']
106106
elif op['base_input_num'] == 'n':
107107
op['input'].insert(0,
108-
{'name': 'inputs', 'nnapi_type': 'tensor', 'cpp_type': 'str_list',
108+
{'name': 'inputs', 'nnapi_type': 'tensor', 'cpp_type': 'str_list', 'input': True,
109109
'needed_by_shaper': True})
110110
elif op['base_input_num'] == 0:
111111
pass
@@ -145,11 +145,12 @@ def infer_cfg(cfg, target: Target):
145145
ipt['name'] = 'bias'
146146
ipt['nnapi_type'] = 'tensor'
147147
ipt['cpp_type'] = 'optional_str'
148-
ipt['learnable'] = True
149-
if 'learnable' not in ipt:
150-
ipt['learnable'] = False
151-
if ipt['learnable'] and 'convert_func' not in ipt:
148+
ipt['input'] = True
152149
ipt['convert_func'] = 'OnnxToNnapiIdentity'
150+
if 'input' not in ipt:
151+
ipt['input'] = False
152+
if 'convert_func' not in ipt:
153+
ipt['convert_func'] = 'OnnxToNnapiAxes0231'
153154
if 'needed_by_shaper' not in ipt:
154155
ipt['needed_by_shaper'] = False
155156

@@ -186,20 +187,28 @@ def generate_onnx_converter():
186187
if op['fused']:
187188
cogoutl(f"const auto activation = FindActivation(model_proto_, output);")
188189
for x in op['input']:
189-
if x['learnable']:
190-
assert x['cpp_type'] in ['str', 'optional_str']
190+
if x['input']:
191191
if x['cpp_type'] == 'str':
192-
cogoutl(f"""{{
193-
const auto name = {x['name']};""")
192+
cogoutl(f"""
193+
{{
194+
const auto name = {x['name']};""")
194195
elif x['cpp_type'] == 'optional_str':
195-
cogoutl(f"""if ({x['name']}.has_value()) {{
196-
const auto name = {x['name']}.value();""")
197-
cogoutl(f"""const auto &onnx_tensor = onnx_tensors_.at(name);
198-
const auto new_tensor = {x['convert_func']}(onnx_tensor);
199-
shaper_.AddShape(name, new_tensor.shape);
200-
nnapi_tensors_[name] = new_tensor;
201-
CreateTensorFb(name, new_tensor);""")
202-
cogoutl("}")
196+
cogoutl(f"""
197+
if ({x['name']}.has_value()) {{
198+
const auto name = {x['name']}.value();""")
199+
elif x['cpp_type'] == 'str_list':
200+
cogoutl(f"""
201+
for (const auto &name : {x['name']}) {{""")
202+
cogoutl(f"""
203+
if (onnx_tensors_.has(name)) {{
204+
const auto &onnx_tensor = onnx_tensors_.at(name);
205+
const auto new_tensor = {x['convert_func']}(onnx_tensor);
206+
shaper_.AddShape(name, new_tensor.shape);
207+
nnapi_tensors_[name] = new_tensor;
208+
CreateTensorFb(name, new_tensor);
209+
}}
210+
}}
211+
""")
203212
if x['cpp_type'] == 'str_list':
204213
cogoutl(f"const auto {x['name']}_fb = FbStrVector({x['name']});")
205214

include/dnnlibrary/ModelBuilder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
#include <string>
1414
#include <vector>
1515

16-
#include <common/data_types.h>
1716
#include <common/Shaper.h>
1817
#include <common/StrKeyMap.h>
18+
#include <common/data_types.h>
1919
#include <dnnlibrary/Model.h>
2020
#include <dnnlibrary/NeuralNetworksWrapper.h>
2121

include/tools/onnx2daq/OnnxConverter.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,16 +158,21 @@ class OnnxConverter {
158158
// OnnxConverter auto generated methods end
159159

160160
/**
161+
* transpose axes to [1, 2, 3, 0]
162+
* for onnx dw conv weight to nnapi dw conv weight
161163
* onnx: [filter_out_channel, filter_in_channel / group, height, width]
162164
* nnapi: [1, height, width, depth_out]
163165
*/
164-
Tensor OnnxToNnapiDwConvWeight(const Tensor &src);
166+
Tensor OnnxToNnapiAxes1230(const Tensor &src);
165167

166168
/**
169+
* transpose axes to [0, 2, 3, 1]
170+
* for nchw (onnx) -> nhwc (nnapi)
171+
* or onnx conv weight to nnapi conv (not dw conv) weight:
167172
* onnx: [filter_out_channel, filter_in_channel, height, width]
168173
* nnapi: [depth_out, height, width, depth_in]
169174
*/
170-
Tensor OnnxToNnapiVanillaConvWeight(const Tensor &src);
175+
Tensor OnnxToNnapiAxes0231(const Tensor &src);
171176

172177
/**
173178
* Just return the same tensor

ops.yml

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,8 @@
77
-
88
name: weight
99
nnapi_type: tensor
10-
# "learnable" stands for the tensor is read from serialized model file
11-
learnable: true
12-
convert_func: OnnxToNnapiVanillaConvWeight
10+
# "input" stands for onnx input instead of attribute (TODO: a new name)
11+
input: true
1312
cpp_type: str
1413
needed_by_shaper: true
1514
-
@@ -202,7 +201,7 @@
202201
-
203202
name: weight
204203
nnapi_type: tensor
205-
learnable: true
204+
input: true
206205
cpp_type: str
207206
needed_by_shaper: true
208207
-
@@ -238,8 +237,8 @@
238237
-
239238
name: weight
240239
nnapi_type: tensor
241-
learnable: true
242-
convert_func: OnnxToNnapiDwConvWeight
240+
input: true
241+
convert_func: OnnxToNnapiAxes1230
243242
cpp_type: str
244243
needed_by_shaper: true
245244
-

0 commit comments

Comments
 (0)