Skip to content

Commit c672a73

Browse files
jerryzh168facebook-github-bot
authored andcommitted
[quant][graphmode][refactor] getGeneralOpTensorInputIndexes -> getGeneralOpTensorInputs (#35141)
Summary: Pull Request resolved: #35141 This is preparing for the support of prim::If in SwapDeQuant Test Plan: . Imported from OSS Differential Revision: D20655300 fbshipit-source-id: 0c66cab37f3f46dd34217a7b99a4d25a159c8487
1 parent 26b2725 commit c672a73

1 file changed

Lines changed: 17 additions & 18 deletions

File tree

torch/csrc/jit/passes/quantization.cpp

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -132,11 +132,11 @@ bool isFunctionNode(
132132
}
133133

134134
// If the op doesn't require observation, return
135-
// the the list of input indexes that we should check to see
135+
// the the list of input `Value`s that we should check to see
136136
// if they are observed/quantized, if so, we can say the output
137137
// of this op is observed/quantized as well, since for these ops we can derive
138138
// the quantization parameters for output given inputs
139-
std::vector<size_t> getGeneralOpTensorInputIndexes(Node* n) {
139+
std::vector<Value*> getGeneralOpTensorInputs(Node* n) {
140140
std::vector<std::string> single_input_aten_funcs = {
141141
"max_pool2d",
142142
"avg_pool2d",
@@ -170,22 +170,22 @@ std::vector<size_t> getGeneralOpTensorInputIndexes(Node* n) {
170170
// after inline
171171
/* call_funcs = */ _single_input_general_call_funcs,
172172
/* aten_funcs = */ {})) {
173-
return {1};
173+
return {n->input(1)};
174174
} else if (isFunctionNode(
175175
n,
176176
// We don't have call functions
177177
// after inline
178178
/* call_funcs = */ {},
179179
/* aten_funcs = */ single_input_aten_funcs)) {
180-
return {0};
180+
return {n->input(0)};
181181
} else if (n->kind() == prim::ListUnpack) {
182-
return {0};
182+
return {n->input(0)};
183183
} else if (n->kind() == prim::ListConstruct) {
184-
std::vector<size_t> indexes;
185-
for (auto i = 0; i < n->inputs().size(); ++i) {
186-
indexes.push_back(i);
184+
std::vector<Value*> inputs;
185+
for (auto* v : n->inputs()) {
186+
inputs.push_back(v);
187187
}
188-
return indexes;
188+
return inputs;
189189
}
190190
return {};
191191
}
@@ -814,10 +814,10 @@ void InsertObserversHelper::fillPassThroughValueMap(
814814
auto g = getCallFunctionGraph(n);
815815
blocks_to_visit.push(g->block());
816816
}
817-
auto input_indexes = getGeneralOpTensorInputIndexes(n);
818-
for (auto i : input_indexes) {
817+
auto inputs = getGeneralOpTensorInputs(n);
818+
for (auto* input : inputs) {
819819
for (auto* output : n->outputs()) {
820-
pass_through_value_map_[output].push_back(n->input(i));
820+
pass_through_value_map_[output].push_back(input);
821821
}
822822
}
823823
for (Block* subblock : n->blocks()) {
@@ -2224,19 +2224,18 @@ void addBiasForConv2dIfNone(Module& module) {
22242224
void swapDeQuant(Block* block) {
22252225
auto graph = block->owningGraph();
22262226
for (Node* n : block->nodes()) {
2227-
auto input_indexes = getGeneralOpTensorInputIndexes(n);
2228-
if (input_indexes.size() > 0) {
2227+
auto inputs = getGeneralOpTensorInputs(n);
2228+
if (inputs.size() > 0) {
22292229
bool is_dequantized = true;
2230-
for (auto i : input_indexes) {
2231-
is_dequantized &= n->inputs()[i]->node()->kind() == Symbol::aten("dequantize");
2230+
for (auto* input : inputs) {
2231+
is_dequantized &= input->node()->kind() == Symbol::aten("dequantize");
22322232
}
22332233
if (!is_dequantized) {
22342234
continue;
22352235
}
22362236
// Delete dequantize node, we have one dequantize
22372237
// for each use of the value
2238-
for (auto i : input_indexes) {
2239-
auto* dequantized_val = n->inputs()[i];
2238+
for (auto* dequantized_val : inputs) {
22402239
auto* dequantize_node = dequantized_val->node();
22412240
TORCH_INTERNAL_ASSERT(dequantized_val->uses().size() == 1,
22422241
"Expect to have one dequantize node for each use");

0 commit comments

Comments
 (0)