Skip to content

Commit f2bae8e

Browse files
jerryzh168facebook-github-bot
authored andcommitted
[quant][fix] at::print for per channel affine quantized tensors (#36280)
Summary: Pull Request resolved: #36280 Test Plan: . Imported from OSS Differential Revision: D20948352 fbshipit-source-id: 92188806b9c129458ebb2cdc47599427e3b6e216
1 parent 51456dc commit f2bae8e

1 file changed

Lines changed: 12 additions & 2 deletions

File tree

aten/src/ATen/core/Formatting.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -278,8 +278,18 @@ std::ostream& print(std::ostream& stream, const Tensor & tensor_, int64_t linesi
278278
}
279279
if (tensor_.is_quantized()) {
280280
stream << ", qscheme: " << toString(tensor_.qscheme());
281-
stream << ", scale: " << tensor_.q_scale();
282-
stream << ", zero_point: " << tensor_.q_zero_point();
281+
if (tensor_.qscheme() == c10::kPerTensorAffine) {
282+
stream << ", scale: " << tensor_.q_scale();
283+
stream << ", zero_point: " << tensor_.q_zero_point();
284+
} else if (tensor_.qscheme() == c10::kPerChannelAffine) {
285+
stream << ", scales: ";
286+
Tensor scales = tensor_.q_per_channel_scales();
287+
print(stream, scales, linesize);
288+
stream << ", zero_points: ";
289+
Tensor zero_points = tensor_.q_per_channel_zero_points();
290+
print(stream, zero_points, linesize);
291+
stream << ", axis: " << tensor_.q_per_channel_axis();
292+
}
283293
}
284294
stream << " ]";
285295
}

0 commit comments

Comments
 (0)