@@ -61,68 +61,6 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_cpu(
6161 return std::make_tuple (std::move (dX), std::move (dgamma), std::move (dbeta));
6262}
6363
64- std::tuple<Tensor, Tensor, Tensor, int64_t , int64_t > _prepare_layer_norm_inputs (
65- const Tensor& input,
66- IntArrayRef normalized_shape,
67- const Tensor& weight /* optional */ ,
68- const Tensor& bias /* optional */ ) {
69-
70- const int normalized_ndim = normalized_shape.size ();
71- TORCH_CHECK (
72- normalized_ndim >= 1 ,
73- " Expected normalized_shape to be at least 1-dimensional, i.e., " ,
74- " containing at least one element, but got normalized_shape = " ,
75- normalized_shape);
76- TORCH_CHECK (
77- !weight.defined () || weight.sizes ().equals (normalized_shape),
78- " Expected weight to be of same shape as normalized_shape, but got " ,
79- " weight of shape " ,
80- weight.sizes (),
81- " and normalized_shape = " ,
82- normalized_shape);
83- TORCH_CHECK (
84- !bias.defined () || bias.sizes ().equals (normalized_shape),
85- " Expected bias to be of same shape as normalized_shape, but got " ,
86- " bias of shape " ,
87- bias.sizes (),
88- " and normalized_shape = " ,
89- normalized_shape);
90-
91- const auto input_shape = input.sizes ();
92- const auto input_ndim = input.dim ();
93-
94- if (input_ndim < normalized_ndim ||
95- !input_shape.slice (input_ndim - normalized_ndim)
96- .equals (normalized_shape)) {
97- std::stringstream ss;
98- ss << " Given normalized_shape=" << normalized_shape
99- << " , expected input with shape [*" ;
100- for (auto size : normalized_shape) {
101- ss << " , " << size;
102- }
103- ss << " ], but got input of size" << input_shape;
104- AT_ERROR (ss.str ());
105- }
106-
107- const int axis = input_ndim - normalized_ndim;
108- const int64_t M = std::accumulate (
109- input_shape.cbegin (),
110- input_shape.cbegin () + axis,
111- 1LL ,
112- std::multiplies<int64_t >());
113- const int64_t N = std::accumulate (
114- input_shape.cbegin () + axis,
115- input_shape.cend (),
116- 1LL ,
117- std::multiplies<int64_t >());
118-
119- const auto & X = input.is_contiguous () ? input : input.contiguous ();
120- const auto & gamma = weight.is_contiguous () ? weight : weight.contiguous ();
121- const auto & beta = bias.is_contiguous () ? bias : bias.contiguous ();
122-
123- return std::make_tuple (X, gamma, beta, M, N);
124- }
125-
12664Tensor layer_norm (
12765 const Tensor& input,
12866 IntArrayRef normalized_shape,
@@ -141,148 +79,8 @@ Tensor layer_norm(
14179 return std::get<0 >(at::native_layer_norm (X, gamma, beta, M, N, eps));
14280}
14381
144- Tensor quantized_layer_norm_impl (
145- const Tensor& input,
146- IntArrayRef normalized_shape,
147- const Tensor& weight /* optional */ ,
148- const Tensor& bias /* optional */ ,
149- double eps,
150- double output_scale,
151- int64_t output_zero_point) {
152-
153- auto inputs = _prepare_layer_norm_inputs (input, normalized_shape, weight, bias);
154- auto X = std::get<0 >(inputs);
155- auto gamma = std::get<1 >(inputs);
156- auto beta = std::get<2 >(inputs);
157- auto M = std::get<3 >(inputs);
158- auto N = std::get<4 >(inputs);
159-
160- Tensor Y = at::_empty_affine_quantized (
161- X.sizes (),
162- X.scalar_type (),
163- output_scale,
164- output_zero_point,
165- X.suggest_memory_format ());
166-
167- if (M > 0 ) {
168- bool affine_per_channel = false ;
169- int num_channels = 1 ; // not relevant for LayerNorm
170- int num_groups = 1 ; // not relevant for LayerNorm
171- quantized_normalize_stub (kCPU , X, gamma, beta, affine_per_channel,
172- num_channels, num_groups, M, N, eps, &Y);
173- }
174- return Y;
175- }
176-
177- Tensor quantized_group_norm_impl (
178- const Tensor& qx,
179- int64_t num_groups,
180- const Tensor& weight, // optional
181- const Tensor& bias, // optional
182- double eps,
183- double output_scale,
184- int64_t output_zero_point) {
185-
186- const auto input_ndim = qx.dim ();
187- TORCH_CHECK (
188- input_ndim >= 3 ,
189- " Expected normalized_shape to be at least 3-dimensional" );
190- TORCH_CHECK (num_groups > 0 , " Expected num_groups to be positive" );
191-
192- const auto input_shape = qx.sizes ();
193- TORCH_CHECK (input_shape[1 ] % num_groups == 0 ,
194- " Expected channels to be divisible by groups" );
195-
196- const int64_t batches = input_shape[0 ];
197- const int64_t num_channels = input_shape[1 ];
198- const int64_t elements_per_batch = std::accumulate (
199- input_shape.cbegin () + 1 ,
200- input_shape.cend (),
201- 1LL ,
202- std::multiplies<int64_t >());
203-
204- const int64_t M = batches * num_groups;
205- const int64_t N = elements_per_batch / num_groups;
206-
207- const auto & qx_contig = qx.is_contiguous () ? qx : qx.contiguous ();
208- const auto & weight_contig = weight.is_contiguous () ? weight : weight.contiguous ();
209- const auto & bias_contig = bias.is_contiguous () ? bias : bias.contiguous ();
210-
211- Tensor Y = at::_empty_affine_quantized (
212- qx.sizes (),
213- qx.scalar_type (),
214- output_scale,
215- output_zero_point,
216- qx.suggest_memory_format ());
217-
218- if (M > 0 ) {
219- bool affine_per_channel = true ;
220- quantized_normalize_stub (kCPU , qx_contig, weight_contig, bias_contig,
221- affine_per_channel, num_channels, num_groups, M, N, eps, &Y);
222- }
223- return Y;
224- }
225-
226- Tensor quantized_instance_norm_impl (
227- const Tensor& qx,
228- const Tensor& weight, // optional
229- const Tensor& bias, // optional
230- double eps,
231- double output_scale,
232- int64_t output_zero_point) {
233-
234- const auto input_ndim = qx.dim ();
235- TORCH_CHECK (
236- input_ndim >= 3 ,
237- " Expected normalized_shape to be at least 3-dimensional" );
238- const auto input_shape = qx.sizes ();
239-
240- // IN is GN with num_groups == num_channels
241- const auto num_channels = input_shape[1 ];
242- TORCH_CHECK (num_channels > 0 , " Expected 2nd dimension to be positive" );
243-
244- return quantized_group_norm_impl (
245- qx, num_channels, weight, bias, eps, output_scale, output_zero_point);
246- }
247-
248- TORCH_LIBRARY_IMPL (quantized, QuantizedCPU, m) {
249- // TODO: this is kind of... blegh
250- m.impl (" layer_norm" , [](
251- Tensor input,
252- std::vector<int64_t > normalized_shape, // because IntArrayRef doesn't work
253- Tensor weight /* optional */ ,
254- Tensor bias /* optional */ ,
255- double eps,
256- double output_scale,
257- int64_t output_zero_point) {
258- return quantized_layer_norm_impl (input, normalized_shape, weight, bias, eps, output_scale, output_zero_point);
259- });
260- m.impl (" group_norm" , [](
261- Tensor qx,
262- int64_t num_groups,
263- Tensor weight,
264- Tensor bias,
265- double eps,
266- double output_scale,
267- int64_t output_zero_point) {
268- return quantized_group_norm_impl (
269- qx, num_groups, weight, bias, eps, output_scale, output_zero_point);
270- });
271- m.impl (" instance_norm" , [](
272- Tensor qx,
273- Tensor weight,
274- Tensor bias,
275- double eps,
276- double output_scale,
277- int64_t output_zero_point) {
278- return quantized_instance_norm_impl (
279- qx, weight, bias, eps, output_scale, output_zero_point);
280- });
281- }
282-
28382DEFINE_DISPATCH (LayerNormKernel);
28483DEFINE_DISPATCH (LayerNormBackwardKernel);
285- DEFINE_DISPATCH (quantized_normalize_stub);
28684
28785} // namespace native
28886} // namespace at
0 commit comments