Skip to content

Commit b3c0939

Browse files
jerryzh168 authored and facebook-github-bot committed
[quant][graphmode][refactor] Move the whitelists to a centralized place (#35721)
Summary: Pull Request resolved: #35721 Test Plan: . Imported from OSS Differential Revision: D20771829 fbshipit-source-id: f6ec3afe2d8034acbdbd81e5a6fbd4a2a76aa7ac
1 parent e372f42 commit b3c0939

1 file changed

Lines changed: 47 additions & 41 deletions

File tree

torch/csrc/jit/passes/quantization.cpp

Lines changed: 47 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,22 @@ struct PatternsAndModules {
6565
Module packed_params_module;
6666
};
6767

68+
std::vector<std::string> _quantizable_call_funcs = {
69+
"conv2d",
70+
"linear",
71+
};
72+
73+
std::vector<std::string> _quantizable_aten_funcs = {
74+
"conv2d",
75+
"conv3d",
76+
"linear",
77+
"addmm",
78+
"matmul",
79+
"add_",
80+
"add",
81+
"cat",
82+
};
83+
6884
// These are the prim::CallFunctions that doesn't require observation and
6985
// have a single input Tensor
7086
// example: `prim::CallFunction(%dropout, %input_tensor, ...)
@@ -81,9 +97,35 @@ std::vector<std::string> _single_input_general_call_funcs = {
8197
"relu",
8298
};
8399

84-
std::vector<std::string> _quantizable_call_funcs = {
85-
"conv2d",
86-
"linear",
100+
// Similar to prim::CallFunctions, there are aten ops that doesn't
101+
// require observation and have a single input Tensor
102+
// e.g. `aten::max_pool2d(%input_tensor, ...)`
103+
std::vector<std::string> _single_input_general_aten_funcs = {
104+
"max_pool2d",
105+
"avg_pool2d",
106+
"flatten",
107+
"max",
108+
"min",
109+
"mean",
110+
"upsample_nearest1d",
111+
"upsample_nearest2d",
112+
"upsample_nearest3d",
113+
"adaptive_avg_pool1d",
114+
"adaptive_avg_pool2d",
115+
"adaptive_avg_pool3d",
116+
"upsample_linear1d",
117+
"upsample_bilinear2d",
118+
"upsample_trilinear3d",
119+
"upsample_bicubic2d",
120+
"dropout",
121+
"reshape",
122+
"chunk",
123+
"view",
124+
"transpose",
125+
"contiguous",
126+
"permute",
127+
"repeat_interleave",
128+
"relu",
87129
};
88130

89131
void fillQConfigMap(
@@ -167,33 +209,6 @@ bool isAddScalar(Node* n) {
167209
// the quantization parameters for `v` given the list of values
168210
std::vector<Value*> getPassThroughInputs(Value* v) {
169211
Node* n = v->node();
170-
std::vector<std::string> single_input_aten_funcs = {
171-
"max_pool2d",
172-
"avg_pool2d",
173-
"flatten",
174-
"max",
175-
"min",
176-
"mean",
177-
"upsample_nearest1d",
178-
"upsample_nearest2d",
179-
"upsample_nearest3d",
180-
"adaptive_avg_pool1d",
181-
"adaptive_avg_pool2d",
182-
"adaptive_avg_pool3d",
183-
"upsample_linear1d",
184-
"upsample_bilinear2d",
185-
"upsample_trilinear3d",
186-
"upsample_bicubic2d",
187-
"dropout",
188-
"reshape",
189-
"chunk",
190-
"view",
191-
"transpose",
192-
"contiguous",
193-
"permute",
194-
"repeat_interleave",
195-
"relu",
196-
};
197212
if (isFunctionNode(
198213
n,
199214
// We don't have call functions
@@ -207,7 +222,7 @@ std::vector<Value*> getPassThroughInputs(Value* v) {
207222
// We don't have call functions
208223
// after inline
209224
/* call_funcs = */ {},
210-
/* aten_funcs = */ single_input_aten_funcs) ||
225+
/* aten_funcs = */ _single_input_general_aten_funcs) ||
211226
(n->kind() == Symbol::aten("sort") && v->offset() == 0)) {
212227
return {n->input(0)};
213228
} else if (n->kind() == prim::If && n->outputs().size() == 1) {
@@ -242,16 +257,7 @@ bool nodeQuantizable(Node* n) {
242257
/* call_funcs = */
243258
_quantizable_call_funcs,
244259
/* aten_funcs = */
245-
{
246-
"conv2d",
247-
"conv3d",
248-
"linear",
249-
"addmm",
250-
"matmul",
251-
"add_",
252-
"add",
253-
"cat",
254-
});
260+
_quantizable_aten_funcs);
255261
}
256262

257263
// We don't want to analyze the graph for some `builtin` CallFunctions

0 commit comments

Comments (0)