@@ -17,9 +17,9 @@ using std::vector;
1717using Shape = Shaper::Shape;
1818
1919namespace dnn {
20- std::string OnnxConverter::m (const std::string &str) {
20+ std::string OnnxConverter::m (const std::string &str) const {
2121 if (name_map_.find (str) != name_map_.end ()) {
22- return name_map_[ str] ;
22+ return name_map_. at ( str) ;
2323 }
2424
2525 return str;
@@ -751,6 +751,125 @@ void OnnxConverter::Convert(const std::string &model_str,
751751 Save (filepath);
752752}
753753
754+ std::pair<bool , std::string> OnnxConverter::IsNodeSupported (
755+ const ONNX_NAMESPACE::NodeProto &node) const {
756+ NodeAttrHelper helper (node);
757+ const auto &op = node.op_type ();
758+ const std::vector<std::string> supported_types{
759+ " Conv" , " AveragePool" ,
760+ " MaxPool" , " GlobalAveragePool" ,
761+ " GlobalMaxPool" , " Relu" ,
762+ " PRelu" , " Add" ,
763+ " Mul" , " Gemm" ,
764+ " Softmax" , " Concat" ,
765+ " Dropout" , " BatchNormalization" ,
766+ " Reshape" , " LRN" ,
767+ " Identity" };
768+ if (std::find (supported_types.begin (), supported_types.end (), op) ==
769+ supported_types.end ()) {
770+ return {false , " Unsupported operator " + op};
771+ }
772+ if (op == " Conv" ) {
773+ const auto strides = helper.get (" strides" , vector<int >{1 , 1 });
774+ const auto pads = helper.get (" pads" , vector<int >{0 , 0 , 0 , 0 });
775+ const auto dilations = helper.get (" dilations" , vector<int >{1 , 1 });
776+ CHECK_EQ (pads.size (), 4ul );
777+ CHECK_EQ (strides.size (), 2ul );
778+ CHECK_EQ (dilations.size (), 2ul );
779+ const auto group = helper.get (" group" , 1 );
780+ if (dilations != vector<int >{1 , 1 } && strides != vector<int >{1 , 1 }) {
781+ return {false ,
782+ " Both dilations and strides > 1 is not supported for now" };
783+ }
784+ const auto weight_name = m (node.input (1 ));
785+ const auto &onnx_weight = onnx_tensors_.at (weight_name);
786+ if (group != 1 && onnx_weight.shape [1 ] != 1 ) {
787+ return {false , " group != 1 is not supported" };
788+ }
789+ } else if (op == " AveragePool" || op == " MaxPool" ) {
790+ const auto count_include_pad = helper.get (" count_include_pad" , 0 );
791+ if (count_include_pad == 1 ) {
792+ return {false , " count_include_pad == 1 is not supported" };
793+ }
794+ const auto storage_order = helper.get (" storage_order" , 0 );
795+ if (storage_order == 1 ) {
796+ return {false , " storage_order == 1 is not supported" };
797+ }
798+ if (helper.get (" auto_pad" , " NOTSET" ) != " NOTSET" ) {
799+ return {false , " auto_pad is not supported" };
800+ }
801+ } else if (op == " PRelu" ) {
802+ const auto slope_name = m (node.input (1 ));
803+ if (onnx_tensors_.at (slope_name).shape != Shape{1 }) {
804+ // TODO: support it
805+ return {false , " Only support one element slope." };
806+ }
807+ } else if (op == " Gemm" ) {
808+ const auto transA = helper.get (" transA" , 0 );
809+ const auto transB = helper.get (" transB" , 0 );
810+ const auto alpha = helper.get (" alpha" , 1 .0f );
811+ const auto beta = helper.get (" beta" , 1 .0f );
812+ if (!(transA == 0 && transB == 1 && alpha == 1 .f && beta == 1 .f )) {
813+ return {false ,
814+ " Only transA == 0, transB == 1, alpha == 1.0 and beta == "
815+ " 1.0 is supported." };
816+ }
817+ } else if (op == " BatchNormalization" ) {
818+ if (node.output_size () != 1 ) {
819+ return {false ,
820+ " Your onnx model may be in training mode, please export "
821+ " it in test mode." };
822+ }
823+ } else if (op == " LRN" ) {
824+ const auto size = helper.get (" size" , 1 );
825+ if (size % 2 == 0 ) {
826+ return {false , " NNAPI only support odd size for LRN" };
827+ }
828+ } else if (op == " Reshape" ) {
829+ const auto output_name = node.output (0 );
830+ for (const auto another_node : model_proto_.graph ().node ()) {
831+ for (const auto input_name : another_node.input ()) {
832+ if (input_name == output_name &&
833+ another_node.op_type () != " Gemm" ) {
834+ return {false ,
835+ " Reshape can only be the last layer or precede a "
836+ " gemm layer for now" };
837+ }
838+ }
839+ }
840+ }
841+ return {true , " " };
842+ }
843+
844+ std::vector<std::vector<int >> OnnxConverter::GetSupportedNodes (
845+ const ONNX_NAMESPACE::ModelProto &model_proto) {
846+ GOOGLE_PROTOBUF_VERIFY_VERSION;
847+ model_proto_ = model_proto;
848+ HandleInitializer ();
849+
850+ std::vector<std::vector<int >> supported_node_vecs;
851+ std::vector<int > supported_node_vec;
852+ for (size_t i = 0 ; i < model_proto.graph ().node_size (); i++) {
853+ bool supported;
854+ std::string error_msg;
855+ std::tie (supported, error_msg) =
856+ IsNodeSupported (model_proto_.graph ().node (i));
857+ if (supported) {
858+ supported_node_vec.push_back (i);
859+ } else {
860+ if (!supported_node_vec.empty ()) {
861+ supported_node_vecs.push_back (supported_node_vec);
862+ supported_node_vec.clear ();
863+ }
864+ }
865+ }
866+ if (!supported_node_vec.empty ()) {
867+ supported_node_vecs.push_back (supported_node_vec);
868+ }
869+ Clear ();
870+ return supported_node_vecs;
871+ }
872+
754873void OnnxConverter::Convert (const ONNX_NAMESPACE::ModelProto &model_proto,
755874 const css &table_file) {
756875 GOOGLE_PROTOBUF_VERIFY_VERSION;
@@ -1032,6 +1151,10 @@ void OnnxConverter::Convert(const ONNX_NAMESPACE::ModelProto &model_proto,
10321151 LOG (INFO) << " Shapes: " ;
10331152 LOG (INFO) << shaper_;
10341153
1154+ Clear ();
1155+ }
1156+
1157+ void OnnxConverter::Clear () {
10351158 skipped_act_.clear ();
10361159 layers_.clear ();
10371160 operands_.clear ();
0 commit comments