@@ -65,6 +65,22 @@ struct PatternsAndModules {
6565 Module packed_params_module;
6666};
6767
68+ std::vector<std::string> _quantizable_call_funcs = { // function names that nodeQuantizable passes as its `call_funcs` list
69+ " conv2d" ,
70+ " linear" ,
71+ };
72+
73+ std::vector<std::string> _quantizable_aten_funcs = { // aten op names that nodeQuantizable passes as its `aten_funcs` list
74+ " conv2d" ,
75+ " conv3d" ,
76+ " linear" ,
77+ " addmm" ,
78+ " matmul" ,
79+ " add_" ,
80+ " add" ,
81+ " cat" ,
82+ };
83+
6884// These are the prim::CallFunctions that don't require observation and
6985// have a single input Tensor
7086// example: `prim::CallFunction(%dropout, %input_tensor, ...)`
@@ -81,9 +97,35 @@ std::vector<std::string> _single_input_general_call_funcs = {
8197 " relu" ,
8298};
8399
84- std::vector<std::string> _quantizable_call_funcs = {
85- " conv2d" ,
86- " linear" ,
100+ // Similar to prim::CallFunctions, there are aten ops that don't
101+ // require observation and have a single input Tensor
102+ // e.g. `aten::max_pool2d(%input_tensor, ...)`
103+ std::vector<std::string> _single_input_general_aten_funcs = { // passed as `aten_funcs` in getPassThroughInputs
104+ " max_pool2d" ,
105+ " avg_pool2d" ,
106+ " flatten" ,
107+ " max" ,
108+ " min" ,
109+ " mean" ,
110+ " upsample_nearest1d" ,
111+ " upsample_nearest2d" ,
112+ " upsample_nearest3d" ,
113+ " adaptive_avg_pool1d" ,
114+ " adaptive_avg_pool2d" ,
115+ " adaptive_avg_pool3d" ,
116+ " upsample_linear1d" ,
117+ " upsample_bilinear2d" ,
118+ " upsample_trilinear3d" ,
119+ " upsample_bicubic2d" ,
120+ " dropout" ,
121+ " reshape" ,
122+ " chunk" ,
123+ " view" ,
124+ " transpose" ,
125+ " contiguous" ,
126+ " permute" ,
127+ " repeat_interleave" ,
128+ " relu" ,
87129};
88130
89131void fillQConfigMap (
@@ -167,33 +209,6 @@ bool isAddScalar(Node* n) {
167209// the quantization parameters for `v` given the list of values
168210std::vector<Value*> getPassThroughInputs (Value* v) {
169211 Node* n = v->node ();
170- std::vector<std::string> single_input_aten_funcs = {
171- " max_pool2d" ,
172- " avg_pool2d" ,
173- " flatten" ,
174- " max" ,
175- " min" ,
176- " mean" ,
177- " upsample_nearest1d" ,
178- " upsample_nearest2d" ,
179- " upsample_nearest3d" ,
180- " adaptive_avg_pool1d" ,
181- " adaptive_avg_pool2d" ,
182- " adaptive_avg_pool3d" ,
183- " upsample_linear1d" ,
184- " upsample_bilinear2d" ,
185- " upsample_trilinear3d" ,
186- " upsample_bicubic2d" ,
187- " dropout" ,
188- " reshape" ,
189- " chunk" ,
190- " view" ,
191- " transpose" ,
192- " contiguous" ,
193- " permute" ,
194- " repeat_interleave" ,
195- " relu" ,
196- };
197212 if (isFunctionNode (
198213 n,
199214 // We don't have call functions
@@ -207,7 +222,7 @@ std::vector<Value*> getPassThroughInputs(Value* v) {
207222 // We don't have call functions
208223 // after inline
209224 /* call_funcs = */ {},
210- /* aten_funcs = */ single_input_aten_funcs ) ||
225+ /* aten_funcs = */ _single_input_general_aten_funcs ) ||
211226 (n->kind () == Symbol::aten (" sort" ) && v->offset () == 0 )) {
212227 return {n->input (0 )};
213228 } else if (n->kind () == prim::If && n->outputs ().size () == 1 ) {
@@ -242,16 +257,7 @@ bool nodeQuantizable(Node* n) {
242257 /* call_funcs = */
243258 _quantizable_call_funcs,
244259 /* aten_funcs = */
245- {
246- " conv2d" ,
247- " conv3d" ,
248- " linear" ,
249- " addmm" ,
250- " matmul" ,
251- " add_" ,
252- " add" ,
253- " cat" ,
254- });
260+ _quantizable_aten_funcs);
255261}
256262
257263// We don't want to analyze the graph for some `builtin` CallFunctions
0 commit comments