Skip to content

Commit e90c32f

Browse files
jerryzh168 authored and facebook-github-bot committed
[quant][graphmode][refactor] Support filter function in quant fusion patterns (#35333)
Summary: Pull Request resolved: #35333
Test Plan: regression tests in: python test/test_jit.py
Imported from OSS
Differential Revision: D20655312
fbshipit-source-id: 50b937bc56aff93f20fe9a0079bf3aec50f6d25d
1 parent 5557ceb commit e90c32f

2 files changed

Lines changed: 28 additions & 17 deletions

File tree

torch/csrc/jit/passes/quantization.cpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2473,10 +2473,10 @@ void SwapDeQuant(std::shared_ptr<Graph>& graph) {
24732473
}
24742474

24752475
void QuantFusion(std::shared_ptr<Graph>& graph) {
2476-
for (const auto& item : quant_fusion_pattern_and_replacements()) {
2476+
for (const auto& info : quant_fusion_pattern_and_replacements()) {
24772477
SubgraphRewriter rewriter;
2478-
rewriter.RegisterRewritePattern(item.first, item.second);
2479-
rewriter.runOnGraph(graph);
2478+
rewriter.RegisterRewritePattern(info.pattern, info.replacement);
2479+
rewriter.runOnGraph(graph, info.filter);
24802480
}
24812481
}
24822482

torch/csrc/jit/passes/quantization_patterns.h

Lines changed: 25 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -2,12 +2,23 @@
22

33
#include <string>
44
#include <unordered_map>
5+
#include <torch/csrc/jit/ir/ir.h>
6+
#include <torch/csrc/jit/ir/subgraph_matcher.h>
7+
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
58

69
namespace torch {
710
namespace jit {
811

9-
std::unordered_map<std::string, std::string>
10-
quant_fusion_pattern_and_replacements() {
12+
struct QuantFusionInfo {
13+
std::string quantized_op_name;
14+
std::string pattern;
15+
std::string replacement;
16+
std::function<bool(const Match&, const std::unordered_map<std::string, Value*>&)> filter = [](const Match&, const std::unordered_map<std::string, Value*>&) {
17+
return true;
18+
};
19+
};
20+
21+
std::vector<QuantFusionInfo> quant_fusion_pattern_and_replacements() {
1122
std::string conv2d = R"(
1223
graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
1324
%a_dequant = aten::dequantize(%a_quant)
@@ -165,18 +176,18 @@ graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
165176
// We don't have quantized inplace add right now
166177

167178
return {
168-
{conv2d, quantized_conv2d},
169-
{conv2d_relu, quantized_conv2d_relu},
170-
{conv2d_inplace_relu, quantized_conv2d_relu},
171-
{addmm, quantized_linear},
172-
{matmul_with_bias, quantized_linear},
173-
{matmul_no_bias, quantized_linear_no_bias},
174-
{aten_linear, quantized_aten_linear},
175-
{add_relu, quantized_add_relu},
176-
{add_inplace_relu, quantized_add_relu},
177-
{add, quantized_add},
178-
{inplace_add, quantized_add},
179-
{cat, quantized_cat},
179+
{"quantized::conv2d", conv2d, quantized_conv2d},
180+
{"quantized::conv2d_relu", conv2d_relu, quantized_conv2d_relu},
181+
{"quantized::conv2d_relu", conv2d_inplace_relu, quantized_conv2d_relu},
182+
{"quantized::linear", addmm, quantized_linear},
183+
{"quantized::linear", matmul_with_bias, quantized_linear},
184+
{"quantized::linear", matmul_no_bias, quantized_linear_no_bias},
185+
{"quantized::linear", aten_linear, quantized_aten_linear},
186+
{"quantized::add_relu", add_relu, quantized_add_relu},
187+
{"quantized::add_relu", add_inplace_relu, quantized_add_relu},
188+
{"quantized::add", add, quantized_add},
189+
{"quantized::add", inplace_add, quantized_add},
190+
{"quantized::cat", cat, quantized_cat},
180191
};
181192
}
182193

0 commit comments

Comments
 (0)