@@ -202,6 +202,8 @@ DEFINE_DISPATCH(silu_stub);
 DEFINE_DISPATCH(silu_backward_stub);
 DEFINE_DISPATCH(mish_stub);
 DEFINE_DISPATCH(mish_backward_stub);
+DEFINE_DISPATCH(prelu_cpu_stub);
+DEFINE_DISPATCH(prelu_backward_cpu_stub);
 
 TORCH_IMPL_FUNC(elu_out) (
   const Tensor& self, const Scalar& alpha, const Scalar& scale, const Scalar& input_scale, const Tensor& result
@@ -595,253 +597,119 @@ TORCH_IMPL_FUNC(threshold_backward_out)(const Tensor& grad, const Tensor& self,
   threshold_stub(device_type(), *this, threshold, 0);
 }
 
-// -----------------------------------
-// prelu forward
-// -----------------------------------
-template <typename scalar_t>
-void inline prelu_cpu_kernel_share_weights(
-  Tensor& result,
-  const Tensor& input,
-  const Tensor& weight) {
-
-  int64_t input_numel = input.numel();
-  auto result_data = result.data_ptr<scalar_t>();
-  auto input_data = input.data_ptr<scalar_t>();
-  auto weight_val = weight.data_ptr<scalar_t>()[0];
-
-  at::parallel_for(0, input_numel, 1000, [&](int64_t start, int64_t end) {
-    for (const auto i : c10::irange(start, end)) {
-      scalar_t input_data_val = input_data[i];
-      // to allow for compiler optimization, here splitting into two lines:
-      scalar_t r = (input_data_val > 0) ? scalar_t(1) : weight_val;
-      result_data[i] = r * input_data_val;
-    }
-  });
-}
-
-template <typename scalar_t>
-void inline prelu_cpu_kernel_multi_weights(
-  Tensor& result,
-  const Tensor& input,
-  const Tensor& weight,
-  int64_t input_dim0_size,
-  int64_t channel_size,
-  int64_t input_stride0,
-  int64_t input_stride1) {
-
-  scalar_t* result_data = result.data_ptr<scalar_t>();
-  scalar_t* input_data = input.data_ptr<scalar_t>();
-  scalar_t* weight_data = weight.data_ptr<scalar_t>();
-
-  auto loop = [&](int64_t start, int64_t end) {
-    for (const auto i : c10::irange(start, end)) {
-      int64_t offset = i * channel_size * input_stride1;
-      scalar_t* n_input_data = input_data + offset;
-      scalar_t* n_result_data = result_data + offset;
-      for (const auto j : c10::irange(channel_size)) {
-        for (const auto k : c10::irange(input_stride1)) {
-          // to allow for compiler optimization, here splitting into two lines:
-          scalar_t w = (n_input_data[k] > 0) ? scalar_t(1) : weight_data[j];
-          n_result_data[k] = w * n_input_data[k];
-        }
-        n_input_data += input_stride1;
-        n_result_data += input_stride1;
-      }
-    }
-  };
-  if (input.numel() > 1000) {
-    at::parallel_for(0, input_dim0_size, 0, loop);
-  } else {
-    loop(0, input_dim0_size);
-  }
-}
-
 Tensor prelu_cpu(const Tensor& self, const Tensor& weight_) {
-  auto input = self.contiguous();
-  auto weight = weight_.contiguous();
-
-  TORCH_CHECK(input.is_contiguous());
-  TORCH_CHECK(weight.is_contiguous());
+  int64_t weight_num = weight_.numel();
+  Tensor result = at::empty_like(self, self.suggest_memory_format());
 
-  int64_t weight_num = weight.numel();
-  Tensor result = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-  auto strides = input.strides();
-
-  // case1: shared weight for all channels
-  if (weight_num == 1) {
-    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "prelu_cpu", [&] {
-      prelu_cpu_kernel_share_weights<scalar_t>(result, input, weight);
-    });
-  }
-  else { // case2: multiple weights, one for each channel
-    int64_t input_ndim = input.dim();
+  if (weight_num != 1) {
+    int64_t input_ndim = self.dim();
     TORCH_CHECK(input_ndim > 0, "Not allow zero-dim input tensor.");
 
     int64_t channel_size = 1; // channel_size default to 1
-    int64_t input_dim0_size = 1, input_stride0 = 1, input_stride1 = 1;
-
     if (input_ndim > 1) {
-      channel_size = input.size(1); // channel is the 2nd dim of input
-      input_dim0_size = input.size(0);
-      input_stride0 = strides[0];
-      input_stride1 = strides[1];
+      channel_size = self.size(1); // channel is the 2nd dim of input
     }
     TORCH_CHECK(channel_size == weight_num,
       "Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num,
       " and channel size = ", channel_size, ".");
-
-    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "prelu_cpu", [&] {
-      prelu_cpu_kernel_multi_weights<scalar_t>(
-        result,
-        input,
-        weight,
-        input_dim0_size,
-        channel_size,
-        input_stride0,
-        input_stride1);
-    });
   }
-  return result;
-}
 
-// -----------------------------------
-// prelu backward
-// -----------------------------------
-template <typename scalar_t>
-void inline prelu_cpu_backward_kernel_share_weights(
-  const Tensor& input,
-  const Tensor& weight,
-  const Tensor& grad_out,
-  Tensor& input_grad,
-  Tensor& weight_grad) {
-
-  int64_t input_numel = input.numel();
-  auto input_data = input.data_ptr<scalar_t>();
-  auto weight_val = weight.data_ptr<scalar_t>()[0];
-  auto grad_out_data = grad_out.data_ptr<scalar_t>();
-  auto input_grad_data = input_grad.data_ptr<scalar_t>();
-  auto weight_grad_data = weight_grad.data_ptr<scalar_t>();
-
-  scalar_t sum = at::parallel_reduce(0, input_numel, 1000, scalar_t(0),
-      [&](int64_t start, int64_t end, scalar_t ident) -> scalar_t {
-    scalar_t partial_sum = ident;
-    for (const auto i : c10::irange(start, end)) {
-      scalar_t input_data_val = input_data[i];
-      scalar_t grad_out_data_val = grad_out_data[i];
-      // to allow for compiler optimization, here splitting into two lines:
-      scalar_t w = (input_data_val > 0) ? scalar_t(1) : weight_val;
-      input_grad_data[i] = w * grad_out_data_val;
-      // to allow for compiler optimization, here splitting into two lines:
-      scalar_t mask = (input_data_val > 0) ? scalar_t(0) : scalar_t(1);
-      partial_sum += mask * input_data_val * grad_out_data_val;
-    }
-    return partial_sum;
-  }, std::plus<scalar_t>());
-  weight_grad_data[0] = sum;
-}
-
-template <typename scalar_t>
-void inline prelu_cpu_backward_kernel_multi_weights(
-  const Tensor& input,
-  const Tensor& weight,
-  const Tensor& grad_out,
-  Tensor& input_grad,
-  Tensor& weight_grad_collector,
-  int64_t input_dim0_size,
-  int64_t channel_size,
-  int64_t input_stride0,
-  int64_t input_stride1) {
-
-  auto input_data = input.data_ptr<scalar_t>();
-  auto weight_data = weight.data_ptr<scalar_t>();
-  auto grad_out_data = grad_out.data_ptr<scalar_t>();
-  auto input_grad_data = input_grad.data_ptr<scalar_t>();
-  auto weight_grad_collector_data = weight_grad_collector.data_ptr<scalar_t>();
-
-  auto loop = [&](int64_t start, int64_t end) {
-    for (const auto i : c10::irange(start, end)) {
-      for (const auto j : c10::irange(channel_size)) {
-        for (const auto k : c10::irange(input_stride1)) {
-          int64_t pos = i * input_stride0 + j * input_stride1 + k;
-          scalar_t weight_data_val = weight_data[j];
-          scalar_t input_data_val = input_data[pos];
-          scalar_t grad_out_data_val = grad_out_data[pos];
-          // to allow for compiler optimization, here splitting into two lines:
-          scalar_t w = (input_data_val > 0) ? scalar_t(1) : weight_data_val;
-          input_grad_data[pos] = w * grad_out_data_val;
-          // to allow for compiler optimization, here splitting into two lines:
-          scalar_t mask = (input_data_val > 0) ? scalar_t(0) : scalar_t(1);
-          weight_grad_collector_data[pos] = mask * input_data_val * grad_out_data_val;
-        }
-      }
+  const int64_t ndim = self.dim();
+  // Helper to convert 1d tensors or scalar tensor to an nd tensor that broadcasts with input
+  // All elements go into the channel dimension
+  DimVector sizes(ndim, 1), strides(ndim, 0);
+  auto as_nd = [&](const Tensor& t) {
+    TORCH_INTERNAL_ASSERT(t.defined() && (t.dim() == 1 || t.dim() == 0));
+    if (ndim >= 2) {
+      sizes[1] = t.dim() == 1 ? t.sizes()[0] : 1;
+      strides[1] = t.dim() == 1 ? t.strides()[0] : 0;
+      return t.as_strided(sizes, strides);
     }
+    return t.as_strided(sizes, strides);
   };
-  if (input.numel() > 1000) {
-    at::parallel_for(0, input_dim0_size, 0, loop);
+  Tensor w;
+  if (self.scalar_type() == ScalarType::BFloat16) {
+    auto w_bf16 = at::empty(weight_.sizes(), weight_.options().dtype(ScalarType::BFloat16));
+    w_bf16.copy_(weight_);
+    w = weight_.defined() ? as_nd(w_bf16) :
+        at::detail::scalar_tensor_static(1, self.scalar_type(), kCPU);
   } else {
-    loop(0, input_dim0_size);
+    w = weight_.defined() ? as_nd(weight_) :
+        at::detail::scalar_tensor_static(1, self.scalar_type(), kCPU);
   }
+
+  auto iter = TensorIteratorConfig()
+    .add_output(result)
+    .add_input(self)
+    .add_input(w)
+    .build();
+  prelu_cpu_stub(iter.device_type(), iter);
+  return result;
 }
 
 std::tuple<Tensor, Tensor> prelu_backward_cpu(const Tensor& grad_out_, const Tensor& self, const Tensor& weight_) {
-  auto input = self.contiguous();
-  auto grad_out = grad_out_.contiguous();
-  auto weight = weight_.contiguous();
-
-  TORCH_CHECK(input.is_contiguous());
-  TORCH_CHECK(grad_out.is_contiguous());
-  TORCH_CHECK(weight.is_contiguous());
+  int64_t weight_num = weight_.numel();
 
-  int64_t weight_num = weight.numel();
-  auto strides = input.strides();
-  auto dims = input.dim();
+  Tensor input_grad = at::empty_like(self, self.suggest_memory_format());
+  Tensor weight_grad = at::empty_like(weight_, at::MemoryFormat::Contiguous);
+  Tensor weight_grad_collector = at::empty_like(self, at::MemoryFormat::Contiguous);
 
-  Tensor input_grad = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-  Tensor weight_grad = at::empty_like(weight, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-  Tensor weight_grad_collector = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-
-  // case1: shared parameter for all channels
-  if (weight_num == 1) {
-    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "prelu_backward_cpu", [&] {
-      prelu_cpu_backward_kernel_share_weights<scalar_t>(input, weight, grad_out, input_grad, weight_grad);
-    });
-  }
-  else { // case2: multiple parameters, one for each channel
-    int64_t input_ndim = input.dim();
+  if (weight_num != 1) {
+    int64_t input_ndim = self.dim();
     TORCH_CHECK(input_ndim > 0, "Not allow zero-dim input tensor.");
 
     int64_t channel_size = 1; // channel_size default to 1
-    int64_t input_dim0_size = 1, input_stride0 = 1, input_stride1 = 1;
-
     if (input_ndim > 1) {
-      channel_size = input.size(1); // channel is the 2nd dim of input
-      input_dim0_size = input.size(0);
-      input_stride0 = strides[0];
-      input_stride1 = strides[1];
+      channel_size = self.size(1); // channel is the 2nd dim of input
     }
     TORCH_CHECK(channel_size == weight_num,
       "Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num,
       " and channel size = ", channel_size, ".");
+  }
 
-    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "prelu_backward_cpu", [&] {
-      prelu_cpu_backward_kernel_multi_weights<scalar_t>(
-        input,
-        weight,
-        grad_out,
-        input_grad,
-        weight_grad_collector,
-        input_dim0_size,
-        channel_size,
-        input_stride0,
-        input_stride1);
-    });
+  const int64_t ndim = self.dim();
+  // Helper to convert 1d tensor or scalar tensor to an nd tensor that broadcasts with input
+  // All elements go into the channel dimension
+  DimVector sizes(ndim, 1), strides(ndim, 0);
+  auto as_nd = [&](const Tensor& t) {
+    TORCH_INTERNAL_ASSERT(t.defined() && (t.dim() == 1 || t.dim() == 0));
+    if (ndim >= 2) {
+      sizes[1] = t.dim() == 1 ? t.sizes()[0] : 1;
+      strides[1] = t.dim() == 1 ? t.strides()[0] : 0;
+      return t.as_strided(sizes, strides);
+    }
+    return t.as_strided(sizes, strides);
+  };
+  Tensor w;
+  if (self.scalar_type() == ScalarType::BFloat16) {
+    auto w_bf16 = at::empty(weight_.sizes(), weight_.options().dtype(ScalarType::BFloat16));
+    w_bf16.copy_(weight_);
+    w = weight_.defined() ? as_nd(w_bf16) :
+        at::detail::scalar_tensor_static(1, self.scalar_type(), kCPU);
+  } else {
+    w = weight_.defined() ? as_nd(weight_) :
+        at::detail::scalar_tensor_static(1, self.scalar_type(), kCPU);
+  }
+
+  auto iter = TensorIteratorConfig()
+    .add_output(input_grad)
+    .add_output(weight_grad_collector)
+    .add_input(self)
+    .add_input(grad_out_)
+    .add_input(w)
+    .build();
+
+  prelu_backward_cpu_stub(iter.device_type(), iter);
+
+  if (weight_num == 1) {
+    weight_grad.fill_(weight_grad_collector.sum());
+  } else {
     // update weight_grad
     std::vector<int64_t> reduce_dims;
+    int64_t input_ndim = self.dim();
     reduce_dims.push_back(0);
-    if (dims > 2) {
-      for (const auto i : c10::irange(2, dims)) {
-        reduce_dims.push_back(i);
-      }
+    if (input_ndim > 2) {
+      for (int64_t i = 2; i < input_ndim; i++) reduce_dims.push_back(i);
     }
     weight_grad = weight_grad_collector.sum(reduce_dims);
   }
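
For context: the new implementation hinges on the `as_nd` helper above, which views the 1-d per-channel weight as an n-d tensor whose only non-trivial size and stride sit in the channel dimension, so TensorIterator can broadcast it against the input without materializing an expanded copy. Below is a minimal standalone sketch of that broadcasting trick; it assumes only the public libtorch API, and `torch::where` stands in for the stub kernels purely for illustration (it is not what `prelu_cpu_stub` dispatches to).

// Minimal sketch, assuming libtorch; illustrative only, not the PR's kernel path.
#include <torch/torch.h>
#include <iostream>

int main() {
  auto input = torch::randn({2, 3, 4, 4});   // NCHW input, 3 channels
  auto weight = torch::rand({3});            // one PReLU slope per channel

  // Mirror as_nd(): view the 1-d weight as [1, 3, 1, 1] with zero strides
  // everywhere except the channel dimension, so it broadcasts with `input`.
  std::vector<int64_t> sizes(input.dim(), 1), strides(input.dim(), 0);
  sizes[1] = weight.size(0);
  strides[1] = weight.stride(0);
  auto w = weight.as_strided(sizes, strides);

  // Elementwise PReLU with the broadcast weight: x > 0 ? x : w * x.
  auto out = torch::where(input > 0, input, w * input);
  std::cout << out.sizes() << std::endl;     // [2, 3, 4, 4]
  return 0;
}

Broadcasting through strides like this lets a single elementwise loop serve both the shared-weight case (weight_num == 1) and the per-channel case, which is why the old share_weights/multi_weights kernel pair could be deleted.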