@@ -325,99 +325,6 @@ def insert_observer_for_input_arg_of_observed_node(
325325 activation_post_process_indexes ,
326326 env , observed_graph , load_arg , observed_node_names_set , quants )
327327
328- def insert_observers_for_model (
329- model : torch .nn .Module ,
330- modules : Dict [str , torch .nn .Module ],
331- matches : Dict [str , MatchResult ],
332- quants : Dict [str , List [Tuple [DefaultQuantizeHandler , Callable ]]],
333- observed_node_names_set : Set [str ],
334- qconfig_map : Dict [str , QConfigAny ],
335- activation_post_process_map : Dict [str , List [str ]],
336- activation_post_process_indexes : Dict [str , int ],
337- observed_graph : Graph ,
338- prepare_custom_config_dict : Dict [str , Any ],
339- input_quantized_idxs : List [int ],
340- output_quantized_idxs : List [int ]) -> Optional [Node ]:
341- env : Dict [Any , Any ] = {}
342-
343- def load_arg (a ):
344- return map_arg (a , lambda node : env [node .name ])
345-
346- graph_inputs = [node .name for node in model .graph .nodes if node .op == "placeholder" ]
347- get_new_observer_name = get_new_attr_name_with_prefix (
348- 'activation_post_process_' )
349- placeholder_node_seen_cnt = 0
350- output_node_seen_cnt = 0
351- result_node : Optional [Node ] = None
352- for node in model .graph .nodes :
353- if node .op == 'output' :
354- # If this output is hardcoded to be quantized, insert an
355- # observer on the previous node if it does not already
356- # exist.
357- cur_output_node_idx = output_node_seen_cnt
358- output_node_seen_cnt += 1
359- if cur_output_node_idx in output_quantized_idxs :
360- prev_node = node .args [0 ]
361- assert isinstance (prev_node , Node ), \
362- ('hardcoding list/dict outputs to be quantized is ' +
363- 'not supported' )
364- if prev_node .name not in observed_node_names_set :
365- assert qconfig_map is not None
366- local_qconfig = qconfig_map [prev_node .name ]
367- assert local_qconfig is not None , \
368- 'qconfig of a node before a quantized output must exist'
369- insert_observer (
370- prev_node , local_qconfig .activation (),
371- model ,
372- activation_post_process_map ,
373- activation_post_process_indexes ,
374- env , observed_graph , load_arg , observed_node_names_set , quants )
375-
376- observed_graph .output (load_arg (node .args [0 ]))
377- result_node = node
378- continue
379-
380- if node .name in observed_node_names_set :
381- continue
382-
383- root_node , matched_nodes , pattern , obj , qconfig = matches .get (
384- node .name , (None , None , None , None , None ))
385- env [node .name ] = observed_graph .node_copy (node , load_arg )
386- if root_node is node :
387- # index for input of custom module that needs to be observed in
388- # parent
389- if qconfig is not None :
390- assert obj is not None
391- standalone_module_input_idxs = \
392- maybe_insert_observer_for_special_module (
393- obj , modules , prepare_custom_config_dict , qconfig ,
394- node )
395- insert_observer_for_output_of_the_node (
396- node , obj , qconfig , modules , model , pattern ,
397- activation_post_process_map ,
398- activation_post_process_indexes ,
399- env ,
400- observed_graph , load_arg , observed_node_names_set ,
401- matched_nodes , standalone_module_input_idxs , quants )
402-
403- if node .op == 'placeholder' :
404- # skip adding observers at the graph input if the input is
405- # overriden to be quantized
406- cur_placeholder_node_idx = placeholder_node_seen_cnt
407- placeholder_node_seen_cnt += 1
408- if cur_placeholder_node_idx in input_quantized_idxs :
409- observed_node_names_set .add (node .name )
410- continue
411-
412- insert_observer_for_input_arg_of_observed_node (
413- node , observed_node_names_set , quants ,
414- model , activation_post_process_map ,
415- activation_post_process_indexes ,
416- env ,
417- observed_graph , load_arg )
418-
419- return result_node
420-
421328def handle_copy_nodes (
422329 observed_graph : Graph , matches : Dict [str , MatchResult ],
423330 quants : Dict [str , List [Tuple [DefaultQuantizeHandler , Callable ]]],
@@ -692,20 +599,98 @@ def _prepare(
692599 self ._find_quants (model .graph , self .modules , matches )
693600
694601 self .activation_post_process_map = defaultdict (list )
602+ env : Dict [Any , Any ] = {}
695603 observed_graph = Graph ()
696604 observed_node_names_set : Set [str ] = set ()
697605
606+ def load_arg (a ):
607+ return map_arg (a , lambda node : env [node .name ])
608+
609+ graph_inputs = []
610+ for node in model .graph .nodes :
611+ if node .op == 'placeholder' :
612+ graph_inputs .append (node .name )
613+
614+ get_new_observer_name = get_new_attr_name_with_prefix (
615+ 'activation_post_process_' )
616+
617+ placeholder_node_seen_cnt = 0
618+ output_node_seen_cnt = 0
698619 input_quantized_idxs : List [int ] = self .prepare_custom_config_dict .get (
699620 "input_quantized_idxs" , [])
700621 output_quantized_idxs : List [int ] = self .prepare_custom_config_dict .get (
701622 "output_quantized_idxs" , [])
702623
703- result_node = insert_observers_for_model (
704- model , self .modules , matches , quants , observed_node_names_set ,
705- self .qconfig_map , self .activation_post_process_map , self .activation_post_process_indexes ,
706- observed_graph , prepare_custom_config_dict , input_quantized_idxs , output_quantized_idxs )
624+ result_node : Optional [Node ] = None
625+ for node in model .graph .nodes :
626+ if node .op == 'output' :
627+ # If this output is hardcoded to be quantized, insert an
628+ # observer on the previous node if it does not already
629+ # exist.
630+ cur_output_node_idx = output_node_seen_cnt
631+ output_node_seen_cnt += 1
632+ if cur_output_node_idx in output_quantized_idxs :
633+ prev_node = node .args [0 ]
634+ assert isinstance (prev_node , Node ), \
635+ ('hardcoding list/dict outputs to be quantized is ' +
636+ 'not supported' )
637+ if prev_node .name not in observed_node_names_set :
638+ assert self .qconfig_map is not None
639+ local_qconfig = self .qconfig_map [prev_node .name ]
640+ assert local_qconfig is not None , \
641+ 'qconfig of a node before a quantized output must exist'
642+ insert_observer (
643+ prev_node , local_qconfig .activation (),
644+ model ,
645+ self .activation_post_process_map ,
646+ self .activation_post_process_indexes ,
647+ env , observed_graph , load_arg , observed_node_names_set , quants )
648+
649+ observed_graph .output (load_arg (node .args [0 ]))
650+ result_node = node
651+ continue
652+
653+ if node .name in observed_node_names_set :
654+ continue
655+
656+ root_node , matched_nodes , pattern , obj , qconfig = matches .get (
657+ node .name , (None , None , None , None , None ))
658+ env [node .name ] = observed_graph .node_copy (node , load_arg )
659+ if root_node is node :
660+ # index for input of custom module that needs to be observed in
661+ # parent
662+ if qconfig is not None :
663+ assert obj is not None
664+ standalone_module_input_idxs = \
665+ maybe_insert_observer_for_special_module (
666+ obj , self .modules , prepare_custom_config_dict , qconfig ,
667+ node )
668+ insert_observer_for_output_of_the_node (
669+ node , obj , qconfig , self .modules , model , pattern ,
670+ self .activation_post_process_map ,
671+ self .activation_post_process_indexes ,
672+ env ,
673+ observed_graph , load_arg , observed_node_names_set ,
674+ matched_nodes , standalone_module_input_idxs , quants )
675+
676+ if node .op == 'placeholder' :
677+ # skip adding observers at the graph input if the input is
678+ # overriden to be quantized
679+ cur_placeholder_node_idx = placeholder_node_seen_cnt
680+ placeholder_node_seen_cnt += 1
681+ if cur_placeholder_node_idx in input_quantized_idxs :
682+ observed_node_names_set .add (node .name )
683+ continue
684+
685+ insert_observer_for_input_arg_of_observed_node (
686+ node , observed_node_names_set , quants ,
687+ model , self .activation_post_process_map ,
688+ self .activation_post_process_indexes ,
689+ env ,
690+ observed_graph , load_arg )
707691
708692 self .modules = dict (model .named_modules ())
693+
709694 # TODO: refactor this to a separate function
710695 matches = self ._find_matches (
711696 observed_graph , self .modules , self .patterns , standalone_module_names ,
0 commit comments