Skip to content

Commit 53424ae

Browse files
committed
[quant][graph] Add useQuantizable function
Summary: Enables to selectively insert observers at the inputs of aten/call functionc Test Plan: test_quantize_script.py Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
1 parent 2cf5312 commit 53424ae

1 file changed

Lines changed: 52 additions & 29 deletions

File tree

torch/csrc/jit/passes/quantization.cpp

Lines changed: 52 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ std::vector<std::string> _quantizable_aten_funcs = {
6161
"add_",
6262
"add",
6363
"cat",
64+
"lstm",
6465
};
6566

6667
// These are the prim::CallFunctions that doesn't require observation and
@@ -121,6 +122,20 @@ std::vector<std::string> _single_input_general_aten_funcs = {
121122
"relu",
122123
};
123124

125+
struct FuncArg {
126+
std::string func_name;
127+
int arg_index;
128+
};
129+
130+
using AtenFuncArgs = std::vector<FuncArg>;
131+
using CallFuncArgs = std::vector<FuncArg>;
132+
133+
// Special checks for ops that do not require observers for all input tensors.
134+
// For each operator in this list observers are inserted for the input based
135+
// on the index specified.
136+
AtenFuncArgs _observe_inputs_aten_func = {};
137+
CallFuncArgs _observe_inputs_call_func = {};
138+
124139
void fillQConfigMap(
125140
const Module& module,
126141
const QConfigDict& qconfig_dict,
@@ -686,33 +701,28 @@ graph(%self, %a, %b, %inplace):
686701
};
687702

688703
// Check if `use` is an aten function of name `func_name` and if value
689-
// `v` is the nth argument of the function
690-
bool isAtenFuncNthArg(
691-
Value* v,
692-
Node* use,
704+
// `v` is the nth argument (if provided) of the function.
705+
bool matchAtenFuncToUse(
706+
const Use& use,
693707
const std::string& func_name,
694-
int n) {
695-
return use->kind() == Symbol::aten(func_name) && v == use->inputs().at(n);
708+
c10::optional<int> n) {
709+
Node* node = use.user;
710+
return node->kind() == Symbol::aten(func_name) &&
711+
(!n.has_value() || n.value() == use.offset);
696712
}
697713

698714
// Check if `use` is a CallFunction of name `func_name` and if value
699-
// `v` is the nth argument of the function
700-
bool isCallFunctionNthArg(
701-
Value* v,
702-
Node* use,
715+
// `v` is the nth argument (if provided) of the function
716+
bool matchCallFuncToUse(
717+
const Use& use,
703718
const std::string& func_name,
704-
int n) {
705-
return use->kind() == prim::CallFunction &&
706-
getFuncName(use->inputs()[0]) == func_name && v == use->inputs().at(n);
719+
c10::optional<int> n) {
720+
Node* node = use.user;
721+
return node->kind() == prim::CallFunction &&
722+
getFuncName(node->inputs()[0]) == func_name &&
723+
(!n.has_value() || n.value() == use.offset);
707724
}
708725

709-
struct FuncArg {
710-
std::string func_name;
711-
int arg_index;
712-
};
713-
using AtenFuncArgs = std::vector<FuncArg>;
714-
using CallFuncArgs = std::vector<FuncArg>;
715-
716726
// Check any use of `v` matches the aten function call
717727
// or CallFunction patterns
718728
bool matchArgPattern(
@@ -721,14 +731,13 @@ bool matchArgPattern(
721731
const CallFuncArgs& call_func_args) {
722732
for (const Use& u : v->uses()) {
723733
for (const auto& func_arg : aten_func_args) {
724-
if (isAtenFuncNthArg(v, u.user, func_arg.func_name, func_arg.arg_index)) {
734+
if (matchAtenFuncToUse(u, func_arg.func_name, func_arg.arg_index)) {
725735
return true;
726736
}
727737
}
728738

729739
for (const auto& func_arg : call_func_args) {
730-
if (isCallFunctionNthArg(
731-
v, u.user, func_arg.func_name, func_arg.arg_index)) {
740+
if (matchCallFuncToUse(u, func_arg.func_name, func_arg.arg_index)) {
732741
return true;
733742
}
734743
}
@@ -997,10 +1006,24 @@ void InsertObserversHelper::preprocess(
9971006
}
9981007
}
9991008

1000-
// Returns true if the value is the weight to LSTM operator.
1001-
bool isDynamicLSTMWeight(Value* v, Use use, bool is_dynamic) {
1002-
return is_dynamic && use.user->kind() == Symbol::aten("lstm") &&
1003-
(use.offset == 2);
1009+
bool useQuantizable(const Use& use, bool is_dynamic) {
1010+
for (const auto& func_input : _observe_inputs_aten_func) {
1011+
if (matchAtenFuncToUse(use, func_input.func_name, c10::nullopt)) {
1012+
return use.offset == func_input.arg_index;
1013+
}
1014+
}
1015+
1016+
for (const auto& func_input : _observe_inputs_call_func) {
1017+
if (matchCallFuncToUse(use, func_input.func_name, c10::nullopt)) {
1018+
return use.offset == func_input.arg_index;
1019+
}
1020+
}
1021+
// Dynamic quantized ops that require special handling for inputs.
1022+
if (is_dynamic && matchAtenFuncToUse(use, "lstm", c10::nullopt)) {
1023+
return use.offset == 2;
1024+
}
1025+
1026+
return nodeQuantizable(use.user);
10041027
}
10051028

10061029
// TODO: remove this as a class method
@@ -1018,9 +1041,9 @@ bool InsertObserversHelper::valueNeedsToBeQuantized(Value* v) {
10181041
return true;
10191042
}
10201043
}
1021-
// Check whether user is quantizable
1044+
// Check whether node input value is quantizable
10221045
for (const auto& use : v->uses()) {
1023-
if (nodeQuantizable(use.user) || isDynamicLSTMWeight(v, use, is_dynamic)) {
1046+
if (useQuantizable(use, is_dynamic)) {
10241047
return true;
10251048
}
10261049
}

0 commit comments

Comments
 (0)