@@ -613,6 +613,15 @@ Tensor sum_to_size(const Tensor& self, IntArrayRef size) {
613613 return sum_to (self, size);
614614}
615615
616+ // We currently do not support per-channel quant for unfold, diagonal, expand, permute.
617+ // TODO: Make this an aten function and replace as_strided_qtensorimpl once that is done.
618+ Tensor make_qtensor (const Tensor& self, IntArrayRef size, IntArrayRef stride, QuantizerPtr quantizer) {
619+ auto result = detail::make_tensor<QTensorImpl>(
620+ Storage (self.storage ()), self.key_set (), quantizer);
621+ setStrided (result, size, stride, self.storage_offset ());
622+ return result;
623+ }
624+
616625Tensor as_strided_tensorimpl (const Tensor& self, IntArrayRef size, IntArrayRef stride, optional<int64_t > storage_offset_) {
617626 auto storage_offset = storage_offset_.value_or (self.storage_offset ());
618627 auto result = detail::make_tensor<TensorImpl>(Storage (self.storage ()), self.key_set ());
@@ -1232,9 +1241,66 @@ inferUnsqueezeGeometry(const Tensor& tensor, int64_t dim) {
12321241 return std::make_tuple (sizes, strides);
12331242}
12341243
1244+ Tensor squeeze_qtensor (const Tensor& self) {
1245+ auto quantizer = get_qtensorimpl (self)->quantizer ();
1246+ std::vector<int64_t > sizes;
1247+ std::vector<int64_t > strides;
1248+ std::tie (sizes, strides) = inferSqueezeGeometry (self);
1249+ if (quantizer->qscheme () == QScheme::PER_CHANNEL_AFFINE) {
1250+ const auto * per_channel_quantizer = static_cast <at::PerChannelAffineQuantizer*>(quantizer.get ());
1251+ auto axis = per_channel_quantizer->axis ();
1252+ int64_t shift = 0 ;
1253+ for (int64_t d = 0 ; d < self.dim (); ++d) {
1254+ if (self.sizes ()[d] == 1 ) {
1255+ TORCH_CHECK (axis != d, " Squeeze is only possible on non-axis dimension for Per-Channel Quantized Tensors." );
1256+ if (d < axis) {
1257+ shift += 1 ;
1258+ }
1259+ }
1260+ }
1261+ axis = axis - shift;
1262+ quantizer = make_per_channel_affine_quantizer (per_channel_quantizer->scales (),
1263+ per_channel_quantizer->zero_points (),
1264+ axis,
1265+ quantizer->scalar_type ());
1266+ }
1267+ return make_qtensor (self, sizes, strides, quantizer);
1268+ }
1269+
1270+ Tensor squeeze_qtensor (const Tensor& self, int64_t dim) {
1271+ auto quantizer = get_qtensorimpl (self)->quantizer ();
1272+ std::vector<int64_t > sizes;
1273+ std::vector<int64_t > strides;
1274+ std::tie (sizes, strides) = inferSqueezeGeometry (self, dim);
1275+ if (quantizer->qscheme () == QScheme::PER_CHANNEL_AFFINE) {
1276+ const auto * per_channel_quantizer = static_cast <at::PerChannelAffineQuantizer*>(quantizer.get ());
1277+ auto axis = per_channel_quantizer->axis ();
1278+ TORCH_CHECK (axis != dim, " Squeeze is only possible on non-axis dimension for Per-Channel Quantized Tensors." );
1279+ if (axis >= dim) {
1280+ axis -= 1 ;
1281+ }
1282+ quantizer = make_per_channel_affine_quantizer (per_channel_quantizer->scales (),
1283+ per_channel_quantizer->zero_points (),
1284+ axis,
1285+ quantizer->scalar_type ());
1286+ }
1287+ if (self.dim () == 0 || self.sizes ()[dim] != 1 ) {
1288+ sizes = self.sizes ().vec ();
1289+ strides = self.strides ().vec ();
1290+ }
1291+ auto result = make_qtensor (self, sizes, strides, quantizer);
1292+ namedinference::propagate_names_except (result, self, {dim});
1293+ return result;
1294+ }
1295+
12351296Tensor squeeze (const Tensor& self) {
12361297 auto g = inferSqueezeGeometry (self);
1237- auto result = self.as_strided (std::get<0 >(g), std::get<1 >(g));
1298+ at::Tensor result;
1299+ if (self.is_quantized ()) {
1300+ result = squeeze_qtensor (self);
1301+ } else {
1302+ result = self.as_strided (std::get<0 >(g), std::get<1 >(g));
1303+ }
12381304 auto maybe_outnames = namedinference::compute_squeeze_outnames (self);
12391305 namedinference::propagate_names_if_nonempty (result, maybe_outnames);
12401306 return result;
@@ -1244,6 +1310,9 @@ Tensor squeeze(const Tensor& self, int64_t dim) {
12441310 int64_t dims = self.dim ();
12451311 dim = maybe_wrap_dim (dim, dims);
12461312
1313+ if (self.is_quantized ()) {
1314+ return squeeze_qtensor (self, dim);
1315+ }
12471316 if (dims == 0 || self.sizes ()[dim] != 1 ) {
12481317 return self.as_strided (self.sizes (), self.strides ());
12491318 }
@@ -1303,11 +1372,31 @@ static Tensor unsqueeze_sparse(Tensor const &self, int64_t dim /* should already
13031372 }
13041373}
13051374
1375+ Tensor unsqueeze_qtensor (const Tensor& self, int64_t dim) {
1376+ dim = maybe_wrap_dim (dim, self.dim () + 1 );
1377+ auto g = inferUnsqueezeGeometry (self, dim);
1378+ auto quantizer = get_qtensorimpl (self)->quantizer ();
1379+ if (quantizer->qscheme () == QScheme::PER_CHANNEL_AFFINE) {
1380+ const auto * per_channel_quantizer = static_cast <at::PerChannelAffineQuantizer*>(quantizer.get ());
1381+ auto axis = per_channel_quantizer->axis ();
1382+ if (axis >= dim) {
1383+ axis += 1 ;
1384+ }
1385+ quantizer = make_per_channel_affine_quantizer (per_channel_quantizer->scales (),
1386+ per_channel_quantizer->zero_points (),
1387+ axis,
1388+ quantizer->scalar_type ());
1389+ }
1390+ return make_qtensor (self, std::get<0 >(g), std::get<1 >(g), quantizer);
1391+ }
1392+
13061393Tensor unsqueeze (const Tensor& self, int64_t dim) {
13071394 dim = maybe_wrap_dim (dim, self.dim () + 1 );
13081395
13091396 if (self.is_sparse ()) {
13101397 return unsqueeze_sparse (self, dim);
1398+ } else if (self.is_quantized ()) {
1399+ return unsqueeze_qtensor (self, dim);
13111400 } else {
13121401 auto g = inferUnsqueezeGeometry (self, dim);
13131402 return self.as_strided (std::get<0 >(g), std::get<1 >(g));
0 commit comments