Skip to content

Commit 15f04e3

Browse files
Mike Ruberry and facebook-github-bot
authored and committed
Revert D27408378: [quant][graphmode][fx][refactor] Factor out insert_observers_for_model to a separate function
Test Plan: revert-hammer Differential Revision: D27408378 (c445f4e) Original commit changeset: 9143f0a6f939 fbshipit-source-id: ae65ea798a6d72f2ec724c4c1b492937edddf721
1 parent 8b8c409 commit 15f04e3

1 file changed

Lines changed: 82 additions & 97 deletions

File tree

torch/quantization/fx/quantize.py

Lines changed: 82 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -325,99 +325,6 @@ def insert_observer_for_input_arg_of_observed_node(
325325
activation_post_process_indexes,
326326
env, observed_graph, load_arg, observed_node_names_set, quants)
327327

328-
def insert_observers_for_model(
        model: torch.nn.Module,
        modules: Dict[str, torch.nn.Module],
        matches: Dict[str, MatchResult],
        quants: Dict[str, List[Tuple[DefaultQuantizeHandler, Callable]]],
        observed_node_names_set: Set[str],
        qconfig_map: Dict[str, QConfigAny],
        activation_post_process_map: Dict[str, List[str]],
        activation_post_process_indexes: Dict[str, int],
        observed_graph: Graph,
        prepare_custom_config_dict: Dict[str, Any],
        input_quantized_idxs: List[int],
        output_quantized_idxs: List[int]) -> Optional[Node]:
    """Copy ``model.graph`` into ``observed_graph``, inserting observers.

    Walks every node of ``model.graph`` in order and, via the
    ``insert_observer*`` helpers, attaches activation-post-process
    (observer) modules where the match results / qconfigs call for them.

    Args:
        model: the traced module whose graph is being observed.
        modules: name -> submodule mapping for ``model``.
        matches: node name -> match result from pattern matching.
        quants: node name -> list of (quantize handler, qconfig getter)
            describing inputs that need observation.
        observed_node_names_set: mutated in place; names of nodes that
            already have (or are exempted from) observation.
        qconfig_map: node name -> qconfig; consulted for hardcoded
            quantized outputs.
        activation_post_process_map / activation_post_process_indexes:
            bookkeeping for inserted observer modules, mutated by the
            ``insert_observer*`` helpers.
        observed_graph: the output Graph being built; mutated in place.
        prepare_custom_config_dict: forwarded to special-module handling.
        input_quantized_idxs: placeholder indexes whose inputs are
            already quantized, so no observer is inserted for them.
        output_quantized_idxs: output indexes hardcoded to be quantized.

    Returns:
        The original graph's output node, or None if the graph has no
        output node.
    """
    # env maps original node names to their copies in observed_graph;
    # load_arg rewrites an argument tree to point at those copies.
    env: Dict[Any, Any] = {}

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    # NOTE(review): graph_inputs and get_new_observer_name are computed
    # but never used in this function body — presumably leftovers from
    # the refactor this revert undoes; confirm before relying on them.
    graph_inputs = [node.name for node in model.graph.nodes if node.op == "placeholder"]
    get_new_observer_name = get_new_attr_name_with_prefix(
        'activation_post_process_')
    # Running counts of placeholder/output nodes seen so far, matched
    # against input_quantized_idxs / output_quantized_idxs by position.
    placeholder_node_seen_cnt = 0
    output_node_seen_cnt = 0
    result_node: Optional[Node] = None
    for node in model.graph.nodes:
        if node.op == 'output':
            # If this output is hardcoded to be quantized, insert an
            # observer on the previous node if it does not already
            # exist.
            cur_output_node_idx = output_node_seen_cnt
            output_node_seen_cnt += 1
            if cur_output_node_idx in output_quantized_idxs:
                prev_node = node.args[0]
                assert isinstance(prev_node, Node), \
                    ('hardcoding list/dict outputs to be quantized is ' +
                     'not supported')
                if prev_node.name not in observed_node_names_set:
                    assert qconfig_map is not None
                    local_qconfig = qconfig_map[prev_node.name]
                    assert local_qconfig is not None, \
                        'qconfig of a node before a quantized output must exist'
                    insert_observer(
                        prev_node, local_qconfig.activation(),
                        model,
                        activation_post_process_map,
                        activation_post_process_indexes,
                        env, observed_graph, load_arg, observed_node_names_set, quants)

            observed_graph.output(load_arg(node.args[0]))
            result_node = node
            # output handled; nothing below applies to output nodes
            continue

        # Already observed (or exempted) earlier in the walk.
        if node.name in observed_node_names_set:
            continue

        root_node, matched_nodes, pattern, obj, qconfig = matches.get(
            node.name, (None, None, None, None, None))
        # Copy the node into the observed graph before deciding on
        # observers, so load_arg can resolve it for downstream nodes.
        env[node.name] = observed_graph.node_copy(node, load_arg)
        # Only the root node of a matched pattern gets output observation.
        if root_node is node:
            # index for input of custom module that needs to be observed in
            # parent
            if qconfig is not None:
                assert obj is not None
                standalone_module_input_idxs = \
                    maybe_insert_observer_for_special_module(
                        obj, modules, prepare_custom_config_dict, qconfig,
                        node)
                insert_observer_for_output_of_the_node(
                    node, obj, qconfig, modules, model, pattern,
                    activation_post_process_map,
                    activation_post_process_indexes,
                    env,
                    observed_graph, load_arg, observed_node_names_set,
                    matched_nodes, standalone_module_input_idxs, quants)

        if node.op == 'placeholder':
            # skip adding observers at the graph input if the input is
            # overridden to be quantized
            cur_placeholder_node_idx = placeholder_node_seen_cnt
            placeholder_node_seen_cnt += 1
            if cur_placeholder_node_idx in input_quantized_idxs:
                observed_node_names_set.add(node.name)
                continue

        insert_observer_for_input_arg_of_observed_node(
            node, observed_node_names_set, quants,
            model, activation_post_process_map,
            activation_post_process_indexes,
            env,
            observed_graph, load_arg)

    return result_node
420-
421328
def handle_copy_nodes(
422329
observed_graph: Graph, matches: Dict[str, MatchResult],
423330
quants: Dict[str, List[Tuple[DefaultQuantizeHandler, Callable]]],
@@ -692,20 +599,98 @@ def _prepare(
692599
self._find_quants(model.graph, self.modules, matches)
693600

694601
self.activation_post_process_map = defaultdict(list)
602+
env: Dict[Any, Any] = {}
695603
observed_graph = Graph()
696604
observed_node_names_set: Set[str] = set()
697605

606+
def load_arg(a):
607+
return map_arg(a, lambda node: env[node.name])
608+
609+
graph_inputs = []
610+
for node in model.graph.nodes:
611+
if node.op == 'placeholder':
612+
graph_inputs.append(node.name)
613+
614+
get_new_observer_name = get_new_attr_name_with_prefix(
615+
'activation_post_process_')
616+
617+
placeholder_node_seen_cnt = 0
618+
output_node_seen_cnt = 0
698619
input_quantized_idxs: List[int] = self.prepare_custom_config_dict.get(
699620
"input_quantized_idxs", [])
700621
output_quantized_idxs: List[int] = self.prepare_custom_config_dict.get(
701622
"output_quantized_idxs", [])
702623

703-
result_node = insert_observers_for_model(
704-
model, self.modules, matches, quants, observed_node_names_set,
705-
self.qconfig_map, self.activation_post_process_map, self.activation_post_process_indexes,
706-
observed_graph, prepare_custom_config_dict, input_quantized_idxs, output_quantized_idxs)
624+
result_node : Optional[Node] = None
625+
for node in model.graph.nodes:
626+
if node.op == 'output':
627+
# If this output is hardcoded to be quantized, insert an
628+
# observer on the previous node if it does not already
629+
# exist.
630+
cur_output_node_idx = output_node_seen_cnt
631+
output_node_seen_cnt += 1
632+
if cur_output_node_idx in output_quantized_idxs:
633+
prev_node = node.args[0]
634+
assert isinstance(prev_node, Node), \
635+
('hardcoding list/dict outputs to be quantized is ' +
636+
'not supported')
637+
if prev_node.name not in observed_node_names_set:
638+
assert self.qconfig_map is not None
639+
local_qconfig = self.qconfig_map[prev_node.name]
640+
assert local_qconfig is not None, \
641+
'qconfig of a node before a quantized output must exist'
642+
insert_observer(
643+
prev_node, local_qconfig.activation(),
644+
model,
645+
self.activation_post_process_map,
646+
self.activation_post_process_indexes,
647+
env, observed_graph, load_arg, observed_node_names_set, quants)
648+
649+
observed_graph.output(load_arg(node.args[0]))
650+
result_node = node
651+
continue
652+
653+
if node.name in observed_node_names_set:
654+
continue
655+
656+
root_node, matched_nodes, pattern, obj, qconfig = matches.get(
657+
node.name, (None, None, None, None, None))
658+
env[node.name] = observed_graph.node_copy(node, load_arg)
659+
if root_node is node:
660+
# index for input of custom module that needs to be observed in
661+
# parent
662+
if qconfig is not None:
663+
assert obj is not None
664+
standalone_module_input_idxs = \
665+
maybe_insert_observer_for_special_module(
666+
obj, self.modules, prepare_custom_config_dict, qconfig,
667+
node)
668+
insert_observer_for_output_of_the_node(
669+
node, obj, qconfig, self.modules, model, pattern,
670+
self.activation_post_process_map,
671+
self.activation_post_process_indexes,
672+
env,
673+
observed_graph, load_arg, observed_node_names_set,
674+
matched_nodes, standalone_module_input_idxs, quants)
675+
676+
if node.op == 'placeholder':
677+
# skip adding observers at the graph input if the input is
678+
# overriden to be quantized
679+
cur_placeholder_node_idx = placeholder_node_seen_cnt
680+
placeholder_node_seen_cnt += 1
681+
if cur_placeholder_node_idx in input_quantized_idxs:
682+
observed_node_names_set.add(node.name)
683+
continue
684+
685+
insert_observer_for_input_arg_of_observed_node(
686+
node, observed_node_names_set, quants,
687+
model, self.activation_post_process_map,
688+
self.activation_post_process_indexes,
689+
env,
690+
observed_graph, load_arg)
707691

708692
self.modules = dict(model.named_modules())
693+
709694
# TODO: refactor this to a separate function
710695
matches = self._find_matches(
711696
observed_graph, self.modules, self.patterns, standalone_module_names,

0 commit comments

Comments
 (0)