Skip to content

Commit b33f375

Browse files
authored
Merge pull request #54 from JDAI-CV/getsupportednodes
Get supported nodes
2 parents 2a705eb + 869d560 commit b33f375

File tree

3 files changed

+135
-5
lines changed

3 files changed

+135
-5
lines changed

include/common/StrKeyMap.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class StrKeyMap {
2323
inline V &operator[](const std::string &key) {
2424
return map_[key];
2525
}
26-
inline const V &at(const std::string &key) {
26+
inline const V &at(const std::string &key) const {
2727
try {
2828
return map_.at(key);
2929
} catch (const std::out_of_range &e) {

include/tools/onnx2daq/OnnxConverter.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
#include <common/data_types.h>
21
#include <common/Shaper.h>
32
#include <common/StrKeyMap.h>
43
#include <common/daq_generated.h>
4+
#include <common/data_types.h>
55
#include <glog/logging.h>
66
#include <onnx/onnx_pb.h>
77

@@ -52,7 +52,7 @@ class OnnxConverter {
5252

5353
std::map<std::string, std::string> name_map_;
5454

55-
std::string m(const std::string &str);
55+
std::string m(const std::string &str) const;
5656

5757
ONNX_NAMESPACE::ModelProto model_proto_;
5858
flatbuffers::FlatBufferBuilder builder_;
@@ -83,6 +83,9 @@ class OnnxConverter {
8383
void ReadTableFile(const std::string &table_file);
8484
std::vector<flatbuffers::Offset<DNN::QuantInfo>> ConvertQuantInfosToFbs();
8585

86+
std::pair<bool, std::string> IsNodeSupported(
87+
const ONNX_NAMESPACE::NodeProto &node_proto) const;
88+
8689
void AddConv(const std::string &input_name, const std::vector<int> &strides,
8790
const std::vector<int> &pads,
8891
const std::vector<int> &dilations, int group,
@@ -171,7 +174,11 @@ class OnnxConverter {
171174
*/
172175
Tensor OnnxToNnapiIdentity(const Tensor &src);
173176

177+
void Clear();
178+
174179
public:
180+
std::vector<std::vector<int>> GetSupportedNodes(
181+
const ONNX_NAMESPACE::ModelProto &model);
175182
void Convert(const std::string &model_str, const std::string &filepath,
176183
const std::string &table_file = "");
177184
void Convert(const ONNX_NAMESPACE::ModelProto &model,

tools/onnx2daq/OnnxConverter.cpp

Lines changed: 125 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ using std::vector;
1717
using Shape = Shaper::Shape;
1818

1919
namespace dnn {
20-
std::string OnnxConverter::m(const std::string &str) {
20+
std::string OnnxConverter::m(const std::string &str) const {
2121
if (name_map_.find(str) != name_map_.end()) {
22-
return name_map_[str];
22+
return name_map_.at(str);
2323
}
2424

2525
return str;
@@ -751,6 +751,125 @@ void OnnxConverter::Convert(const std::string &model_str,
751751
Save(filepath);
752752
}
753753

754+
std::pair<bool, std::string> OnnxConverter::IsNodeSupported(
755+
const ONNX_NAMESPACE::NodeProto &node) const {
756+
NodeAttrHelper helper(node);
757+
const auto &op = node.op_type();
758+
const std::vector<std::string> supported_types{
759+
"Conv", "AveragePool",
760+
"MaxPool", "GlobalAveragePool",
761+
"GlobalMaxPool", "Relu",
762+
"PRelu", "Add",
763+
"Mul", "Gemm",
764+
"Softmax", "Concat",
765+
"Dropout", "BatchNormalization",
766+
"Reshape", "LRN",
767+
"Identity"};
768+
if (std::find(supported_types.begin(), supported_types.end(), op) ==
769+
supported_types.end()) {
770+
return {false, "Unsupported operator " + op};
771+
}
772+
if (op == "Conv") {
773+
const auto strides = helper.get("strides", vector<int>{1, 1});
774+
const auto pads = helper.get("pads", vector<int>{0, 0, 0, 0});
775+
const auto dilations = helper.get("dilations", vector<int>{1, 1});
776+
CHECK_EQ(pads.size(), 4ul);
777+
CHECK_EQ(strides.size(), 2ul);
778+
CHECK_EQ(dilations.size(), 2ul);
779+
const auto group = helper.get("group", 1);
780+
if (dilations != vector<int>{1, 1} && strides != vector<int>{1, 1}) {
781+
return {false,
782+
"Both dilations and strides > 1 is not supported for now"};
783+
}
784+
const auto weight_name = m(node.input(1));
785+
const auto &onnx_weight = onnx_tensors_.at(weight_name);
786+
if (group != 1 && onnx_weight.shape[1] != 1) {
787+
return {false, "group != 1 is not supported"};
788+
}
789+
} else if (op == "AveragePool" || op == "MaxPool") {
790+
const auto count_include_pad = helper.get("count_include_pad", 0);
791+
if (count_include_pad == 1) {
792+
return {false, "count_include_pad == 1 is not supported"};
793+
}
794+
const auto storage_order = helper.get("storage_order", 0);
795+
if (storage_order == 1) {
796+
return {false, "storage_order == 1 is not supported"};
797+
}
798+
if (helper.get("auto_pad", "NOTSET") != "NOTSET") {
799+
return {false, "auto_pad is not supported"};
800+
}
801+
} else if (op == "PRelu") {
802+
const auto slope_name = m(node.input(1));
803+
if (onnx_tensors_.at(slope_name).shape != Shape{1}) {
804+
// TODO: support it
805+
return {false, "Only support one element slope."};
806+
}
807+
} else if (op == "Gemm") {
808+
const auto transA = helper.get("transA", 0);
809+
const auto transB = helper.get("transB", 0);
810+
const auto alpha = helper.get("alpha", 1.0f);
811+
const auto beta = helper.get("beta", 1.0f);
812+
if (!(transA == 0 && transB == 1 && alpha == 1.f && beta == 1.f)) {
813+
return {false,
814+
"Only transA == 0, transB == 1, alpha == 1.0 and beta == "
815+
"1.0 is supported."};
816+
}
817+
} else if (op == "BatchNormalization") {
818+
if (node.output_size() != 1) {
819+
return {false,
820+
"Your onnx model may be in training mode, please export "
821+
"it in test mode."};
822+
}
823+
} else if (op == "LRN") {
824+
const auto size = helper.get("size", 1);
825+
if (size % 2 == 0) {
826+
return {false, "NNAPI only support odd size for LRN"};
827+
}
828+
} else if (op == "Reshape") {
829+
const auto output_name = node.output(0);
830+
for (const auto another_node : model_proto_.graph().node()) {
831+
for (const auto input_name : another_node.input()) {
832+
if (input_name == output_name &&
833+
another_node.op_type() != "Gemm") {
834+
return {false,
835+
"Reshape can only be the last layer or precede a "
836+
"gemm layer for now"};
837+
}
838+
}
839+
}
840+
}
841+
return {true, ""};
842+
}
843+
844+
std::vector<std::vector<int>> OnnxConverter::GetSupportedNodes(
845+
const ONNX_NAMESPACE::ModelProto &model_proto) {
846+
GOOGLE_PROTOBUF_VERIFY_VERSION;
847+
model_proto_ = model_proto;
848+
HandleInitializer();
849+
850+
std::vector<std::vector<int>> supported_node_vecs;
851+
std::vector<int> supported_node_vec;
852+
for (size_t i = 0; i < model_proto.graph().node_size(); i++) {
853+
bool supported;
854+
std::string error_msg;
855+
std::tie(supported, error_msg) =
856+
IsNodeSupported(model_proto_.graph().node(i));
857+
if (supported) {
858+
supported_node_vec.push_back(i);
859+
} else {
860+
if (!supported_node_vec.empty()) {
861+
supported_node_vecs.push_back(supported_node_vec);
862+
supported_node_vec.clear();
863+
}
864+
}
865+
}
866+
if (!supported_node_vec.empty()) {
867+
supported_node_vecs.push_back(supported_node_vec);
868+
}
869+
Clear();
870+
return supported_node_vecs;
871+
}
872+
754873
void OnnxConverter::Convert(const ONNX_NAMESPACE::ModelProto &model_proto,
755874
const css &table_file) {
756875
GOOGLE_PROTOBUF_VERIFY_VERSION;
@@ -1032,6 +1151,10 @@ void OnnxConverter::Convert(const ONNX_NAMESPACE::ModelProto &model_proto,
10321151
LOG(INFO) << "Shapes: ";
10331152
LOG(INFO) << shaper_;
10341153

1154+
Clear();
1155+
}
1156+
1157+
void OnnxConverter::Clear() {
10351158
skipped_act_.clear();
10361159
layers_.clear();
10371160
operands_.clear();

0 commit comments

Comments
 (0)