Skip to content

Commit 51f2779

Browse files
committed
check whether onnx convolution weight exists in initializer
1 parent 65b14c9 commit 51f2779

1 file changed

Lines changed: 13 additions & 6 deletions

File tree

tools/onnx2daq/OnnxConverter.cpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ void OnnxConverter::AddConv(const string &input_name,
241241
dilations[1] * dilations[1] -
242242
input_shape[2];
243243
VLOG(5) << input_shape << ", " << pads << ", " << dilations << ", "
244-
<< new_pads;
244+
<< new_pads;
245245
// Why "AllowShortBlocksOnASingleLine: false" doesn't work on it?
246246
// clang-format off
247247
{
@@ -278,6 +278,9 @@ void OnnxConverter::AddConv(const string &input_name,
278278
return;
279279
}
280280

281+
if (!onnx_tensors_.has(ori_weight_name)) {
282+
throw std::invalid_argument("The weight of convolution must be known");
283+
}
281284
const auto &onnx_weight = onnx_tensors_.at(ori_weight_name);
282285
if (group == 1) {
283286
VLOG(5) << "Vanilla conv";
@@ -921,8 +924,8 @@ OnnxConverter::GetInputOfOnnxModel() {
921924
nnapi_shape = shape;
922925
}
923926
shaper_.AddShape(input.name(), nnapi_shape);
924-
const auto flat_input =
925-
DNN::CreateInputDirect(builder_, &nnapi_shape, input.name().c_str());
927+
const auto flat_input = DNN::CreateInputDirect(builder_, &nnapi_shape,
928+
input.name().c_str());
926929
inputs.push_back(flat_input);
927930
}
928931

@@ -1053,9 +1056,13 @@ std::pair<bool, std::string> OnnxConverter::IsNodeSupported(
10531056
"Both dilations and strides > 1 is not supported for now"};
10541057
}
10551058
const auto weight_name = m(node.input(1));
1056-
const auto &onnx_weight = onnx_tensors_.at(weight_name);
1057-
if (group != 1 && onnx_weight.shape[1] != 1) {
1058-
return {false, "group != 1 is not supported"};
1059+
if (onnx_tensors_.has(weight_name)) {
1060+
const auto &onnx_weight = onnx_tensors_.at(weight_name);
1061+
if (group != 1 && onnx_weight.shape[1] != 1) {
1062+
return {false, "group != 1 is not supported"};
1063+
}
1064+
} else {
1065+
return {false, "The weight of convolution must be known"};
10591066
}
10601067
} else if (op == "AveragePool" || op == "MaxPool") {
10611068
const auto count_include_pad = helper.get("count_include_pad", 0);

0 commit comments

Comments
 (0)