Skip to content

Commit 4a96911

Browse files
jerryzh168facebook-github-bot
authored andcommitted
[quant][graphmode] quantization support for aten::chunk (#34806)
Summary: Pull Request resolved: #34806 Test Plan: python test/test_jit.py Imported from OSS Differential Revision: D20524454 fbshipit-source-id: 92ac9bc251581e963258cb90dc3de73f8508c822
1 parent 9c8f09d commit 4a96911

2 files changed

Lines changed: 6 additions & 0 deletions

File tree

test/test_jit.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2155,6 +2155,7 @@ def forward(self, x):
21552155
x = torch.min(x)
21562156
x = torch.mean(x)
21572157
x = x.reshape([-1])
2158+
x, y = torch.chunk(x, 2)
21582159
x = F.dropout(x)
21592160
x = self.dropout(x)
21602161
# TODO: uncomment when sort is supported
@@ -2177,6 +2178,8 @@ def forward(self, x):
21772178
.check("aten::min") \
21782179
.check("aten::mean") \
21792180
.check("aten::reshape") \
2181+
.check("aten::chunk") \
2182+
.check("prim::ListUnpack") \
21802183
.check("aten::dropout") \
21812184
.check("aten::dropout") \
21822185
.run(m.graph)
@@ -2189,6 +2192,8 @@ def forward(self, x):
21892192
.check("aten::min") \
21902193
.check("aten::mean") \
21912194
.check("aten::reshape") \
2195+
.check("aten::chunk") \
2196+
.check("prim::ListUnpack") \
21922197
.check("aten::dropout") \
21932198
.check("aten::dropout") \
21942199
.check("dequantize") \

torch/csrc/jit/passes/quantization.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ std::vector<size_t> getGeneralOpTensorInputIndexes(Node* n) {
136136
"upsample_bicubic2d",
137137
"dropout",
138138
"reshape",
139+
"chunk",
139140
// TODO: sort returns a tuple of Tensors, we have
140141
// to extend the API to support that
141142
// "sort",

0 commit comments

Comments
 (0)