@@ -697,8 +697,14 @@ class HistogramObserver(_ObserverBase):
697697 min_val : torch .Tensor
698698 max_val : torch .Tensor
699699
700- def __init__ (self , bins = 2048 , upsample_rate = 128 , dtype = torch .quint8 ,
701- qscheme = torch .per_tensor_affine , reduce_range = False ):
700+ def __init__ (
701+ self ,
702+ bins : int = 2048 ,
703+ upsample_rate : int = 128 ,
704+ dtype : torch .dtype = torch .quint8 ,
705+ qscheme = torch .per_tensor_affine ,
706+ reduce_range = False
707+ ):
702708 # bins: The number of bins used for histogram calculation.
703709 super (HistogramObserver , self ).__init__ (dtype = dtype ,
704710 qscheme = qscheme ,
@@ -710,83 +716,87 @@ def __init__(self, bins=2048, upsample_rate=128, dtype=torch.quint8,
710716 self .dst_nbins = 2 ** torch .iinfo (self .dtype ).bits
711717 self .upsample_rate = upsample_rate
712718
713- @torch .jit .ignore
714- def _non_linear_param_search (self ):
719+ def _get_norm (
720+ self ,
721+ delta_begin : torch .Tensor ,
722+ delta_end : torch .Tensor ,
723+ density : torch .Tensor
724+ ) -> torch .Tensor :
725+ r"""
726+ Compute the norm of the values uniformaly distributed between
727+ delta_begin and delta_end.
728+ Currently only L2 norm is supported.
729+
730+ norm = density * (integral_{begin, end} x^2)
731+ = density * (end^3 - begin^3) / 3
732+ """
733+ norm = (
734+ delta_end * delta_end * delta_end
735+ - delta_begin * delta_begin * delta_begin
736+ ) / 3
737+ return density * norm
738+
739+ def _compute_quantization_error (
740+ self , next_start_bin : int , next_end_bin : int
741+ ):
742+ r"""
743+ Compute the quantization error if we use start_bin to end_bin as the
744+ min and max to do the quantization.
745+ """
746+ bin_width = (self .max_val .item () - self .min_val .item ()) / self .bins
747+
748+ dst_bin_width = bin_width * (next_end_bin - next_start_bin + 1 ) / self .dst_nbins
749+ if dst_bin_width == 0.0 :
750+ return 0.0
751+
752+ src_bin = torch .arange (self .bins )
753+ # distances from the beginning of first dst_bin to the beginning and
754+ # end of src_bin
755+ src_bin_begin = (src_bin - next_start_bin ) * bin_width
756+ src_bin_end = src_bin_begin + bin_width
757+
758+ # which dst_bins the beginning and end of src_bin belong to?
759+ dst_bin_of_begin = torch .clamp (src_bin_begin // dst_bin_width , 0 , self .dst_nbins - 1 )
760+ dst_bin_of_begin_center = (dst_bin_of_begin + 0.5 ) * dst_bin_width
761+
762+ dst_bin_of_end = torch .clamp (src_bin_end // dst_bin_width , 0 , self .dst_nbins - 1 )
763+ dst_bin_of_end_center = (dst_bin_of_end + 0.5 ) * dst_bin_width
764+
765+ density = self .histogram / bin_width
766+
767+ norm = torch .zeros (self .bins )
768+
769+ delta_begin = src_bin_begin - dst_bin_of_begin_center
770+ delta_end = dst_bin_width / 2
771+ norm += self ._get_norm (delta_begin , torch .ones (self .bins ) * delta_end , density )
772+
773+ norm += (dst_bin_of_end - dst_bin_of_begin - 1 ) * self ._get_norm (
774+ torch .tensor (- dst_bin_width / 2 ), torch .tensor (dst_bin_width / 2 ), density
775+ )
776+
777+ dst_bin_of_end_center = (
778+ dst_bin_of_end * dst_bin_width + dst_bin_width / 2
779+ )
780+
781+ delta_begin = - dst_bin_width / 2
782+ delta_end = src_bin_end - dst_bin_of_end_center
783+ norm += self ._get_norm (torch .tensor (delta_begin ), delta_end , density )
784+
785+ return norm .sum ().item ()
786+
787+ def _non_linear_param_search (self ) -> Tuple [torch .Tensor , torch .Tensor ]:
715788 r"""Non-linear parameter search.
716789
717790 An approximation for L2 error minimization for selecting min/max.
718791 By selecting new min/max, we filter out outliers in input distribution.
719792 This follows the implementation of NormMinimization::NonlinearQuantizationParamsSearch in
720793 caffe2/quantization/server/norm_minimization.cc
721794 """
722- def _get_norm (delta_begin , delta_end , density , norm_type ):
723- r"""
724- Compute the norm of the values uniformaly distributed between
725- delta_begin and delta_end.
726-
727- norm = density * (integral_{begin, end} x^2)
728- = density * (end^3 - begin^3) / 3
729- """
730- assert norm_type == "L2" , "Only L2 norms are currently supported"
731- norm = 0.0
732- if norm_type == "L2" :
733- norm = (
734- delta_end * delta_end * delta_end
735- - delta_begin * delta_begin * delta_begin
736- ) / 3
737- return density * norm
738-
739- def _compute_quantization_error (next_start_bin , next_end_bin , norm_type ):
740- r"""
741- Compute the quantization error if we use start_bin to end_bin as the
742- min and max to do the quantization.
743- """
744- bin_width = (self .max_val .item () - self .min_val .item ()) / self .bins
745-
746- dst_bin_width = bin_width * (next_end_bin - next_start_bin + 1 ) / self .dst_nbins
747- if dst_bin_width == 0.0 :
748- return 0.0
749-
750- src_bin = torch .arange (self .bins )
751- # distances from the beginning of first dst_bin to the beginning and
752- # end of src_bin
753- src_bin_begin = (src_bin - next_start_bin ) * bin_width
754- src_bin_end = src_bin_begin + bin_width
755-
756- # which dst_bins the beginning and end of src_bin belong to?
757- dst_bin_of_begin = torch .clamp (src_bin_begin // dst_bin_width , 0 , self .dst_nbins - 1 )
758- dst_bin_of_begin_center = (dst_bin_of_begin + 0.5 ) * dst_bin_width
759-
760- dst_bin_of_end = torch .clamp (src_bin_end // dst_bin_width , 0 , self .dst_nbins - 1 )
761- dst_bin_of_end_center = (dst_bin_of_end + 0.5 ) * dst_bin_width
762-
763- density = self .histogram / bin_width
764-
765- norm = torch .zeros (self .bins )
766-
767- delta_begin = src_bin_begin - dst_bin_of_begin_center
768- delta_end = dst_bin_width / 2
769- norm += _get_norm (delta_begin , delta_end , density , norm_type )
770-
771- norm += (dst_bin_of_end - dst_bin_of_begin - 1 ) * _get_norm (
772- - dst_bin_width / 2 , dst_bin_width / 2 , density , norm_type
773- )
774-
775- dst_bin_of_end_center = (
776- dst_bin_of_end * dst_bin_width + dst_bin_width / 2
777- )
778-
779- delta_begin = - dst_bin_width / 2
780- delta_end = src_bin_end - dst_bin_of_end_center
781- norm += _get_norm (delta_begin , delta_end , density , norm_type )
782-
783- return norm .sum ()
784-
785795 assert self .histogram .size ()[0 ] == self .bins , "bins mistmatch"
786796 bin_width = (self .max_val - self .min_val ) / self .bins
787797
788798 # cumulative sum
789- total = sum (self .histogram )
799+ total = torch . sum (self .histogram ). item ( )
790800 cSum = torch .cumsum (self .histogram , dim = 0 )
791801
792802 stepsize = 1e-5 # granularity
@@ -825,7 +835,7 @@ def _compute_quantization_error(next_start_bin, next_end_bin, norm_type):
825835 continue
826836
827837 # calculate the quantization error using next_start_bin and next_end_bin
828- norm = _compute_quantization_error (next_start_bin , next_end_bin , "L2" )
838+ norm = self . _compute_quantization_error (next_start_bin , next_end_bin )
829839
830840 if norm > norm_min :
831841 break
@@ -837,27 +847,28 @@ def _compute_quantization_error(next_start_bin, next_end_bin, norm_type):
837847 new_max = self .min_val + bin_width * (end_bin + 1 )
838848 return new_min , new_max
839849
840- @torch .jit .ignore
841- def _adjust_min_max (self ,
842- combined_min : torch .Tensor ,
843- combined_max : torch .Tensor ,
844- upsample_rate : int ) -> Tuple [torch .Tensor , torch .Tensor , int , int ]:
850+ def _adjust_min_max (
851+ self ,
852+ combined_min : torch .Tensor ,
853+ combined_max : torch .Tensor ,
854+ upsample_rate : int
855+ ) -> Tuple [torch .Tensor , torch .Tensor , int , int ]:
845856 # We ensure that:
846857 # (combined_max - combined_min)/(downsample_rate*Nbins) = (max - min)/(upsample_rate*Nbins)
847858 # This allows us to have a common grid of resolution s, where we can align
848859 # the input histogram
849860 # start_idx maps min_val to the histogram bin index.
850861
851862 hist_bin_width = (self .max_val - self .min_val ) / (self .bins * upsample_rate )
852- downsample_rate = int (torch .ceil ((combined_max - combined_min ) / (self .bins * hist_bin_width )).item ())
863+ downsample_rate = int (torch .ceil (
864+ (combined_max - combined_min ) / (self .bins * hist_bin_width )).item ())
853865 e = downsample_rate * (self .bins * hist_bin_width ) - (combined_max - combined_min )
854866 # Relax only the max, not the min, so that for one sided distributions, min stays at zero
855867 combined_max = combined_max + e
856868 combined_min = combined_min
857869 start_idx = int (torch .round ((self .min_val - combined_min ) / hist_bin_width ).item ())
858870 return combined_min , combined_max , downsample_rate , start_idx
859871
860- @torch .jit .ignore
861872 def _combine_histograms (self ,
862873 orig_hist : torch .Tensor ,
863874 new_hist : torch .Tensor ,
@@ -915,7 +926,8 @@ def forward(self, x_orig: torch.Tensor) -> torch.Tensor:
915926 assert combined_min .numel () == 1 and combined_max .numel () == 1 , (
916927 "histogram min/max values must be scalar."
917928 )
918- combined_histogram = torch .histc (x , self .bins , min = int (combined_min ), max = int (combined_max ))
929+ combined_histogram = torch .histc (
930+ x , self .bins , min = int (combined_min ), max = int (combined_max ))
919931 if combined_min == min_val and combined_max == max_val :
920932 combined_histogram += self .histogram
921933 else :
0 commit comments