@@ -61,6 +61,7 @@ std::vector<std::string> _quantizable_aten_funcs = {
6161 " add_" ,
6262 " add" ,
6363 " cat" ,
64+ " lstm" ,
6465};
6566
6667// These are the prim::CallFunctions that doesn't require observation and
@@ -121,6 +122,20 @@ std::vector<std::string> _single_input_general_aten_funcs = {
121122 " relu" ,
122123};
123124
125+ struct FuncArg {
126+ std::string func_name;
127+ int arg_index;
128+ };
129+
130+ using AtenFuncArgs = std::vector<FuncArg>;
131+ using CallFuncArgs = std::vector<FuncArg>;
132+
133+ // Special checks for ops that do not require observers for all input tensors.
134+ // For each operator in this list observers are inserted for the input based
135+ // on the index specified.
136+ AtenFuncArgs _observe_inputs_aten_func = {};
137+ CallFuncArgs _observe_inputs_call_func = {};
138+
124139void fillQConfigMap (
125140 const Module& module ,
126141 const QConfigDict& qconfig_dict,
@@ -686,33 +701,28 @@ graph(%self, %a, %b, %inplace):
686701};
687702
688703// Check if `use` is an aten function of name `func_name` and if value
689- // `v` is the nth argument of the function
690- bool isAtenFuncNthArg (
691- Value* v,
692- Node* use,
704+ // `v` is the nth argument (if provided) of the function.
705+ bool matchAtenFuncToUse (
706+ const Use& use,
693707 const std::string& func_name,
694- int n) {
695- return use->kind () == Symbol::aten (func_name) && v == use->inputs ().at (n);
708+ c10::optional<int > n) {
709+ Node* node = use.user ;
710+ return node->kind () == Symbol::aten (func_name) &&
711+ (!n.has_value () || n.value () == use.offset );
696712}
697713
698714// Check if `use` is a CallFunction of name `func_name` and if value
699- // `v` is the nth argument of the function
700- bool isCallFunctionNthArg (
701- Value* v,
702- Node* use,
715+ // `v` is the nth argument (if provided) of the function
716+ bool matchCallFuncToUse (
717+ const Use& use,
703718 const std::string& func_name,
704- int n) {
705- return use->kind () == prim::CallFunction &&
706- getFuncName (use->inputs ()[0 ]) == func_name && v == use->inputs ().at (n);
719+ c10::optional<int > n) {
720+ Node* node = use.user ;
721+ return node->kind () == prim::CallFunction &&
722+ getFuncName (node->inputs ()[0 ]) == func_name &&
723+ (!n.has_value () || n.value () == use.offset );
707724}
708725
709- struct FuncArg {
710- std::string func_name;
711- int arg_index;
712- };
713- using AtenFuncArgs = std::vector<FuncArg>;
714- using CallFuncArgs = std::vector<FuncArg>;
715-
716726// Check any use of `v` matches the aten function call
717727// or CallFunction patterns
718728bool matchArgPattern (
@@ -721,14 +731,13 @@ bool matchArgPattern(
721731 const CallFuncArgs& call_func_args) {
722732 for (const Use& u : v->uses ()) {
723733 for (const auto & func_arg : aten_func_args) {
724- if (isAtenFuncNthArg (v, u. user , func_arg.func_name , func_arg.arg_index )) {
734+ if (matchAtenFuncToUse (u , func_arg.func_name , func_arg.arg_index )) {
725735 return true ;
726736 }
727737 }
728738
729739 for (const auto & func_arg : call_func_args) {
730- if (isCallFunctionNthArg (
731- v, u.user , func_arg.func_name , func_arg.arg_index )) {
740+ if (matchCallFuncToUse (u, func_arg.func_name , func_arg.arg_index )) {
732741 return true ;
733742 }
734743 }
@@ -997,10 +1006,24 @@ void InsertObserversHelper::preprocess(
9971006 }
9981007}
9991008
1000- // Returns true if the value is the weight to LSTM operator.
1001- bool isDynamicLSTMWeight (Value* v, Use use, bool is_dynamic) {
1002- return is_dynamic && use.user ->kind () == Symbol::aten (" lstm" ) &&
1003- (use.offset == 2 );
1009+ bool useQuantizable (const Use& use, bool is_dynamic) {
1010+ for (const auto & func_input : _observe_inputs_aten_func) {
1011+ if (matchAtenFuncToUse (use, func_input.func_name , c10::nullopt )) {
1012+ return use.offset == func_input.arg_index ;
1013+ }
1014+ }
1015+
1016+ for (const auto & func_input : _observe_inputs_call_func) {
1017+ if (matchCallFuncToUse (use, func_input.func_name , c10::nullopt )) {
1018+ return use.offset == func_input.arg_index ;
1019+ }
1020+ }
1021+ // Dynamic quantized ops that require special handling for inputs.
1022+ if (is_dynamic && matchAtenFuncToUse (use, " lstm" , c10::nullopt )) {
1023+ return use.offset == 2 ;
1024+ }
1025+
1026+ return nodeQuantizable (use.user );
10041027}
10051028
10061029// TODO: remove this as a class method
@@ -1018,9 +1041,9 @@ bool InsertObserversHelper::valueNeedsToBeQuantized(Value* v) {
10181041 return true ;
10191042 }
10201043 }
1021- // Check whether user is quantizable
1044+ // Check whether node input value is quantizable
10221045 for (const auto & use : v->uses ()) {
1023- if (nodeQuantizable (use. user ) || isDynamicLSTMWeight (v, use, is_dynamic)) {
1046+ if (useQuantizable ( use, is_dynamic)) {
10241047 return true ;
10251048 }
10261049 }
0 commit comments