@@ -132,11 +132,11 @@ bool isFunctionNode(
 }

 // If the op doesn't require observation, return
-// the the list of input indexes that we should check to see
+// the list of input `Value`s that we should check to see
 // if they are observed/quantized, if so, we can say the output
 // of this op is observed/quantized as well, since for these ops we can derive
 // the quantization parameters for output given inputs
-std::vector<size_t> getGeneralOpTensorInputIndexes(Node* n) {
+std::vector<Value*> getGeneralOpTensorInputs(Node* n) {
   std::vector<std::string> single_input_aten_funcs = {
       "max_pool2d",
       "avg_pool2d",
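
Note: the hunk above changes the helper's contract from index-based (callers resolve each size_t against n->inputs()) to returning the Value pointers directly. A minimal standalone sketch of that refactor pattern, using hypothetical Node/Value stand-ins rather than the real JIT IR types:

    #include <cstddef>
    #include <iostream>
    #include <vector>

    struct Value {
      const char* name;
    };

    struct Node {
      std::vector<Value*> ins;
      const std::vector<Value*>& inputs() const { return ins; }
      Value* input(size_t i) const { return ins.at(i); }
    };

    // Old contract: callers get indexes and must re-dereference n->inputs()[i].
    std::vector<size_t> tensorInputIndexes(const Node& /*n*/) {
      return {0};
    }

    // New contract: callers get the Values themselves.
    std::vector<Value*> tensorInputs(const Node& n) {
      return {n.input(0)};
    }

    int main() {
      Value a{"a"}, b{"b"};
      Node n{{&a, &b}};
      for (auto* v : tensorInputs(n)) {
        std::cout << v->name << "\n"; // prints "a"
      }
      return 0;
    }

Returning Value* also means later passes no longer need the originating Node in hand just to look the inputs up again.
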
@@ -170,22 +170,22 @@ std::vector<size_t> getGeneralOpTensorInputIndexes(Node* n) {
           // after inline
           /* call_funcs = */ _single_input_general_call_funcs,
           /* aten_funcs = */ {})) {
-    return {1};
+    return {n->input(1)};
   } else if (isFunctionNode(
                  n,
                  // We don't have call functions
                  // after inline
                  /* call_funcs = */ {},
                  /* aten_funcs = */ single_input_aten_funcs)) {
-    return {0};
+    return {n->input(0)};
   } else if (n->kind() == prim::ListUnpack) {
-    return {0};
+    return {n->input(0)};
   } else if (n->kind() == prim::ListConstruct) {
-    std::vector<size_t> indexes;
-    for (auto i = 0; i < n->inputs().size(); ++i) {
-      indexes.push_back(i);
+    std::vector<Value*> inputs;
+    for (auto* v : n->inputs()) {
+      inputs.push_back(v);
     }
-    return inputs;
+    return inputs;
   }
   return {};
 }
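
Note: for prim::ListConstruct the new branch copies every input pointer with a range-for. A hedged sketch showing the same copy written with the iterator-range constructor (plain std::vector stand-ins here; the real n->inputs() is an ArrayRef-like type, so this illustrates the idiom, not the patch's code):

    #include <cassert>
    #include <vector>

    struct Value {};

    // Equivalent to the push_back loop in the hunk above: copy a range of
    // pointers into a fresh vector via the iterator-range constructor.
    std::vector<Value*> copyInputs(const std::vector<Value*>& inputs) {
      return std::vector<Value*>(inputs.begin(), inputs.end());
    }

    int main() {
      Value a, b;
      std::vector<Value*> ins{&a, &b};
      assert(copyInputs(ins).size() == 2);
      return 0;
    }
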
@@ -814,10 +814,10 @@ void InsertObserversHelper::fillPassThroughValueMap(
     auto g = getCallFunctionGraph(n);
     blocks_to_visit.push(g->block());
   }
-  auto input_indexes = getGeneralOpTensorInputIndexes(n);
-  for (auto i : input_indexes) {
+  auto inputs = getGeneralOpTensorInputs(n);
+  for (auto* input : inputs) {
     for (auto* output : n->outputs()) {
-      pass_through_value_map_[output].push_back(n->input(i));
+      pass_through_value_map_[output].push_back(input);
     }
   }
   for (Block* subblock : n->blocks()) {
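
Note: in fillPassThroughValueMap, each output of a general op now maps to the full list of pass-through inputs, so later passes can derive the output's quantization parameters from those inputs. A minimal sketch of that nested-loop bookkeeping with stand-in types (PassThroughMap and recordPassThrough are hypothetical names):

    #include <cassert>
    #include <unordered_map>
    #include <vector>

    struct Value {};

    using PassThroughMap = std::unordered_map<Value*, std::vector<Value*>>;

    // Mirrors the nested loop in the hunk above: every output maps to the
    // full list of pass-through inputs it can derive quantization from.
    void recordPassThrough(
        PassThroughMap& map,
        const std::vector<Value*>& inputs,
        const std::vector<Value*>& outputs) {
      for (auto* input : inputs) {
        for (auto* output : outputs) {
          map[output].push_back(input);
        }
      }
    }

    int main() {
      Value in1, in2, out;
      PassThroughMap map;
      recordPassThrough(map, {&in1, &in2}, {&out});
      assert(map[&out].size() == 2);
      return 0;
    }
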
@@ -2224,19 +2224,18 @@ void addBiasForConv2dIfNone(Module& module) {
 void swapDeQuant(Block* block) {
   auto graph = block->owningGraph();
   for (Node* n : block->nodes()) {
-    auto input_indexes = getGeneralOpTensorInputIndexes(n);
-    if (input_indexes.size() > 0) {
+    auto inputs = getGeneralOpTensorInputs(n);
+    if (inputs.size() > 0) {
       bool is_dequantized = true;
-      for (auto i : input_indexes) {
-        is_dequantized &= n->inputs()[i]->node()->kind() == Symbol::aten("dequantize");
+      for (auto* input : inputs) {
+        is_dequantized &= input->node()->kind() == Symbol::aten("dequantize");
       }
       if (!is_dequantized) {
         continue;
       }
       // Delete dequantize node, we have one dequantize
       // for each use of the value
-      for (auto i : input_indexes) {
-        auto* dequantized_val = n->inputs()[i];
+      for (auto* dequantized_val : inputs) {
         auto* dequantize_node = dequantized_val->node();
         TORCH_INTERNAL_ASSERT(dequantized_val->uses().size() == 1,
             "Expect to have one dequantize node for each use");
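
Note: the is_dequantized loop above is an "all inputs satisfy a predicate" check; the &= accumulation ends up true only if every input is produced by a dequantize node. A sketch of an equivalent std::all_of formulation, with stand-in Node/Value types and the node kind modeled as a plain string:

    #include <algorithm>
    #include <cassert>
    #include <string>
    #include <vector>

    struct Node {
      std::string kind;
    };

    struct Value {
      Node* producer;
      Node* node() const { return producer; }
    };

    // Equivalent to the `is_dequantized &= ...` loop: true iff every input
    // Value is produced by a dequantize node.
    bool allInputsDequantized(const std::vector<Value*>& inputs) {
      return std::all_of(inputs.begin(), inputs.end(), [](Value* v) {
        return v->node()->kind == "aten::dequantize";
      });
    }

    int main() {
      Node dq{"aten::dequantize"}, other{"aten::relu"};
      Value v1{&dq}, v2{&other};
      assert(allInputsDequantized({&v1}));
      assert(!allInputsDequantized({&v1, &v2}));
      return 0;
    }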