88import io
99import os
1010import re
11+ import glob
12+ import subprocess
1113from textwrap import dedent
1214
1315autogen_header = """\
@@ -106,6 +108,13 @@ def convert_to_proto3(lines): # type: (Iterable[Text]) -> Iterable[Text]
106108 yield line
107109
108110
111+ def gen_proto3_code (protoc_path , proto3_path , include_path , cpp_out , python_out ): # type: (Text, Text, Text, Text, Text) -> None
112+ print ("Generate pb3 code using {}" .format (protoc_path ))
113+ build_args = [protoc_path , proto3_path , '-I' , include_path ]
114+ build_args .extend (['--cpp_out' , cpp_out , '--python_out' , python_out ])
115+ subprocess .check_call (build_args )
116+
117+
109118def translate (source , proto , onnx_ml , package_name ): # type: (Text, int, bool, Text) -> Text
110119 lines = source .splitlines () # type: Iterable[Text]
111120 lines = process_ifs (lines , onnx_ml = onnx_ml )
@@ -121,7 +130,7 @@ def qualify(f, pardir=os.path.realpath(os.path.dirname(__file__))): # type: (Te
121130 return os .path .join (pardir , f )
122131
123132
124- def convert (stem , package_name , output , do_onnx_ml = False , lite = False ): # type: (Text, Text, Text, bool, bool) -> None
133+ def convert (stem , package_name , output , do_onnx_ml = False , lite = False , protoc_path = '' ): # type: (Text, Text, Text, bool, bool, Text ) -> None
125134 proto_in = qualify ("{}.in.proto" .format (stem ))
126135 need_rename = (package_name != DEFAULT_PACKAGE_NAME )
127136 if do_onnx_ml :
@@ -146,6 +155,15 @@ def convert(stem, package_name, output, do_onnx_ml=False, lite=False): # type:
146155 fout .write (translate (source , proto = 3 , onnx_ml = do_onnx_ml , package_name = package_name ))
147156 if lite :
148157 fout .write (LITE_OPTION )
158+ if protoc_path :
159+ porto3_dir = os .path .dirname (proto3 )
160+ base_dir = os .path .dirname (porto3_dir )
161+ gen_proto3_code (protoc_path , proto3 , base_dir , base_dir , base_dir )
162+ pb3_files = glob .glob (os .path .join (porto3_dir , '*.proto3.*' ))
163+ for pb3_file in pb3_files :
164+ print ("Removing {}" .format (pb3_file ))
165+ os .remove (pb3_file )
166+
149167 if need_rename :
150168 if do_onnx_ml :
151169 proto_header = qualify ("{}-ml.pb.h" .format (stem ), pardir = output )
@@ -193,6 +211,9 @@ def main(): # type: () -> None
193211 parser .add_argument ('-o' , '--output' ,
194212 default = os .path .realpath (os .path .dirname (__file__ )),
195213 help = 'output directory (default: %(default)s)' )
214+ parser .add_argument ('--protoc_path' ,
215+ default = '' ,
216+ help = 'path to protoc for proto3 file validation' )
196217 parser .add_argument ('stems' , nargs = '*' , default = ['onnx' , 'onnx-operators' ],
197218 help = 'list of .in.proto file stems '
198219 '(default: %(default)s)' )
@@ -206,7 +227,8 @@ def main(): # type: () -> None
206227 package_name = args .package ,
207228 output = args .output ,
208229 do_onnx_ml = args .ml ,
209- lite = args .lite )
230+ lite = args .lite ,
231+ protoc_path = args .protoc_path )
210232
211233
212234if __name__ == '__main__' :
0 commit comments