|
2 | 2 |
|
3 | 3 | #include <string> |
4 | 4 | #include <unordered_map> |
| 5 | +#include <torch/csrc/jit/ir/ir.h> |
| 6 | +#include <torch/csrc/jit/ir/subgraph_matcher.h> |
| 7 | +#include <torch/csrc/jit/passes/subgraph_rewrite.h> |
5 | 8 |
|
6 | 9 | namespace torch { |
7 | 10 | namespace jit { |
8 | 11 |
|
9 | | -std::unordered_map<std::string, std::string> |
10 | | -quant_fusion_pattern_and_replacements() { |
| 12 | +struct QuantFusionInfo { |
| 13 | + std::string quantized_op_name; |
| 14 | + std::string pattern; |
| 15 | + std::string replacement; |
| 16 | + std::function<bool(const Match&, const std::unordered_map<std::string, Value*>&)> filter = [](const Match&, const std::unordered_map<std::string, Value*>&) { |
| 17 | + return true; |
| 18 | + }; |
| 19 | +}; |
| 20 | + |
// Builds the table of fusion rules (dequant->op->quant IR patterns and their
// quantized-op replacements) consumed by the quantization fusion pass.
// NOTE(review): this is a rendered diff view; the "@@" hunk marker below
// elides the middle of the function, i.e. the definitions of the pattern and
// replacement strings (quantized_conv2d, addmm, cat, ...) referenced in the
// return list — they are not visible here.
| 21 | +std::vector<QuantFusionInfo> quant_fusion_pattern_and_replacements() {
11 | 22 | std::string conv2d = R"(
12 | 23 | graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
13 | 24 | %a_dequant = aten::dequantize(%a_quant)
@@ -165,18 +176,18 @@ graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
165 | 176 | // We don't have quantized inplace add right now
166 | 177 |

// The diff upgrades plain {pattern, replacement} pairs to QuantFusionInfo
// entries that additionally carry the produced quantized op's name. Note
// several source patterns map to one quantized op (relu and inplace relu
// both fuse to "quantized::conv2d_relu" / "quantized::add_relu"; add and
// inplace add both fuse to "quantized::add").
167 | 178 | return {
168 | | - {conv2d, quantized_conv2d},
169 | | - {conv2d_relu, quantized_conv2d_relu},
170 | | - {conv2d_inplace_relu, quantized_conv2d_relu},
171 | | - {addmm, quantized_linear},
172 | | - {matmul_with_bias, quantized_linear},
173 | | - {matmul_no_bias, quantized_linear_no_bias},
174 | | - {aten_linear, quantized_aten_linear},
175 | | - {add_relu, quantized_add_relu},
176 | | - {add_inplace_relu, quantized_add_relu},
177 | | - {add, quantized_add},
178 | | - {inplace_add, quantized_add},
179 | | - {cat, quantized_cat},
| 179 | + {"quantized::conv2d", conv2d, quantized_conv2d},
| 180 | + {"quantized::conv2d_relu", conv2d_relu, quantized_conv2d_relu},
| 181 | + {"quantized::conv2d_relu", conv2d_inplace_relu, quantized_conv2d_relu},
| 182 | + {"quantized::linear", addmm, quantized_linear},
| 183 | + {"quantized::linear", matmul_with_bias, quantized_linear},
| 184 | + {"quantized::linear", matmul_no_bias, quantized_linear_no_bias},
| 185 | + {"quantized::linear", aten_linear, quantized_aten_linear},
| 186 | + {"quantized::add_relu", add_relu, quantized_add_relu},
| 187 | + {"quantized::add_relu", add_inplace_relu, quantized_add_relu},
| 188 | + {"quantized::add", add, quantized_add},
| 189 | + {"quantized::add", inplace_add, quantized_add},
| 190 | + {"quantized::cat", cat, quantized_cat},
180 | 191 | };
181 | 192 | }
182 | 193 |
|
|
0 commit comments