Commit 03f493a

yanli924 authored and facebook-github-bot committed
update HistogramObserver to be scriptable (#51001)
Summary: Pull Request resolved: #51001. Fixes tests in TestQuantizeJitOps.

Test Plan: Imported from OSS

    python test/test_quantization.py

Reviewed By: raghuramank100

Differential Revision: D26038759

Pulled By: lyoka

fbshipit-source-id: c38bbe81273654795451366fa72d76453954fad9
1 parent 5adbace commit 03f493a

3 files changed: 114 additions & 77 deletions
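For orientation before the diff: the point of the change is that a HistogramObserver now survives a TorchScript round trip. A minimal sketch of the workflow the new test below exercises (not verbatim from the commit; assumes a PyTorch build containing this patch):

    import io
    import torch
    from torch.quantization.observer import HistogramObserver

    obs = HistogramObserver()            # defaults: 2048 bins, torch.quint8
    obs(torch.rand(3, 4))                # eager observer records a histogram

    scripted = torch.jit.script(obs)     # scripting the observer now works
    scripted(torch.rand(3, 4))           # scripted observer records one too

    buf = io.BytesIO()                   # serialize and reload the scripted module
    torch.jit.save(scripted, buf)
    buf.seek(0)
    loaded = torch.jit.load(buf)
    assert torch.equal(scripted.histogram, loaded.histogram)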

test/quantization/test_workflow_module.py

Lines changed: 24 additions & 0 deletions
@@ -13,6 +13,7 @@
     FixedQParamsFakeQuantize,
     default_debug_qconfig,
     default_observer,
+    default_histogram_observer,
     default_per_channel_weight_observer,
     default_affine_fixed_qparams_fake_quant,
     get_observer_dict,
@@ -696,6 +697,29 @@ def test_observer_scriptable(self, qdtype, qscheme):
         loaded = torch.jit.load(buf)
         self.assertTrue(torch.equal(obs.get_tensor_value()[0], loaded.get_tensor_value()[0]))
 
+class TestHistogramObserver(QuantizationTestCase):
+    @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
+           qscheme=st.sampled_from(
+               (torch.per_tensor_affine, torch.per_tensor_symmetric))
+           )
+    def test_observer_scriptable(self, qdtype, qscheme):
+        ob_list = [
+            HistogramObserver(dtype=qdtype, qscheme=qscheme),
+            default_histogram_observer()
+        ]
+        for obs in ob_list:
+            scripted = torch.jit.script(obs)
+
+            x = torch.rand(3, 4)
+            obs(x)
+            scripted(x)
+            self.assertTrue(torch.equal(obs.histogram, scripted.histogram))
+            buf = io.BytesIO()
+            torch.jit.save(scripted, buf)
+            buf.seek(0)
+            loaded = torch.jit.load(buf)
+            self.assertTrue(torch.equal(obs.histogram, loaded.histogram))
+
    @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
           qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric)),
           reduce_range=st.booleans())

test/test_quantization.py

Lines changed: 1 addition & 0 deletions
@@ -37,6 +37,7 @@
 # TODO: merge with TestObserver
 # TODO: some tests belong to test_quantize.py, e.g. test_record_observer
 from quantization.test_workflow_module import TestRecordHistogramObserver  # noqa: F401
+from quantization.test_workflow_module import TestHistogramObserver  # noqa: F401
 from quantization.test_workflow_module import TestDistributed  # noqa: F401
 
 # Workflow

torch/quantization/observer.py

Lines changed: 89 additions & 77 deletions
@@ -697,8 +697,14 @@ class HistogramObserver(_ObserverBase):
     min_val: torch.Tensor
     max_val: torch.Tensor
 
-    def __init__(self, bins=2048, upsample_rate=128, dtype=torch.quint8,
-                 qscheme=torch.per_tensor_affine, reduce_range=False):
+    def __init__(
+        self,
+        bins: int = 2048,
+        upsample_rate: int = 128,
+        dtype: torch.dtype = torch.quint8,
+        qscheme=torch.per_tensor_affine,
+        reduce_range=False
+    ):
         # bins: The number of bins used for histogram calculation.
         super(HistogramObserver, self).__init__(dtype=dtype,
                                                 qscheme=qscheme,
@@ -710,83 +716,87 @@ def __init__(self, bins=2048, upsample_rate=128, dtype=torch.quint8,
         self.dst_nbins = 2 ** torch.iinfo(self.dtype).bits
         self.upsample_rate = upsample_rate
 
-    @torch.jit.ignore
-    def _non_linear_param_search(self):
+    def _get_norm(
+        self,
+        delta_begin: torch.Tensor,
+        delta_end: torch.Tensor,
+        density: torch.Tensor
+    ) -> torch.Tensor:
+        r"""
+        Compute the norm of the values uniformly distributed between
+        delta_begin and delta_end.
+        Currently only L2 norm is supported.
+
+        norm = density * (integral_{begin, end} x^2)
+             = density * (end^3 - begin^3) / 3
+        """
+        norm = (
+            delta_end * delta_end * delta_end
+            - delta_begin * delta_begin * delta_begin
+        ) / 3
+        return density * norm
+
+    def _compute_quantization_error(
+        self, next_start_bin: int, next_end_bin: int
+    ):
+        r"""
+        Compute the quantization error if we use start_bin to end_bin as the
+        min and max to do the quantization.
+        """
+        bin_width = (self.max_val.item() - self.min_val.item()) / self.bins
+
+        dst_bin_width = bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins
+        if dst_bin_width == 0.0:
+            return 0.0
+
+        src_bin = torch.arange(self.bins)
+        # distances from the beginning of first dst_bin to the beginning and
+        # end of src_bin
+        src_bin_begin = (src_bin - next_start_bin) * bin_width
+        src_bin_end = src_bin_begin + bin_width
+
+        # which dst_bins the beginning and end of src_bin belong to?
+        dst_bin_of_begin = torch.clamp(src_bin_begin // dst_bin_width, 0, self.dst_nbins - 1)
+        dst_bin_of_begin_center = (dst_bin_of_begin + 0.5) * dst_bin_width
+
+        dst_bin_of_end = torch.clamp(src_bin_end // dst_bin_width, 0, self.dst_nbins - 1)
+        dst_bin_of_end_center = (dst_bin_of_end + 0.5) * dst_bin_width
+
+        density = self.histogram / bin_width
+
+        norm = torch.zeros(self.bins)
+
+        delta_begin = src_bin_begin - dst_bin_of_begin_center
+        delta_end = dst_bin_width / 2
+        norm += self._get_norm(delta_begin, torch.ones(self.bins) * delta_end, density)
+
+        norm += (dst_bin_of_end - dst_bin_of_begin - 1) * self._get_norm(
+            torch.tensor(-dst_bin_width / 2), torch.tensor(dst_bin_width / 2), density
+        )
+
+        dst_bin_of_end_center = (
+            dst_bin_of_end * dst_bin_width + dst_bin_width / 2
+        )
+
+        delta_begin = -dst_bin_width / 2
+        delta_end = src_bin_end - dst_bin_of_end_center
+        norm += self._get_norm(torch.tensor(delta_begin), delta_end, density)
+
+        return norm.sum().item()
+
+    def _non_linear_param_search(self) -> Tuple[torch.Tensor, torch.Tensor]:
         r"""Non-linear parameter search.
 
         An approximation for L2 error minimization for selecting min/max.
         By selecting new min/max, we filter out outliers in input distribution.
         This follows the implementation of NormMinimization::NonlinearQuantizationParamsSearch in
         caffe2/quantization/server/norm_minimization.cc
         """
-        def _get_norm(delta_begin, delta_end, density, norm_type):
-            r"""
-            Compute the norm of the values uniformaly distributed between
-            delta_begin and delta_end.
-
-            norm = density * (integral_{begin, end} x^2)
-                 = density * (end^3 - begin^3) / 3
-            """
-            assert norm_type == "L2", "Only L2 norms are currently supported"
-            norm = 0.0
-            if norm_type == "L2":
-                norm = (
-                    delta_end * delta_end * delta_end
-                    - delta_begin * delta_begin * delta_begin
-                ) / 3
-            return density * norm
-
-        def _compute_quantization_error(next_start_bin, next_end_bin, norm_type):
-            r"""
-            Compute the quantization error if we use start_bin to end_bin as the
-            min and max to do the quantization.
-            """
-            bin_width = (self.max_val.item() - self.min_val.item()) / self.bins
-
-            dst_bin_width = bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins
-            if dst_bin_width == 0.0:
-                return 0.0
-
-            src_bin = torch.arange(self.bins)
-            # distances from the beginning of first dst_bin to the beginning and
-            # end of src_bin
-            src_bin_begin = (src_bin - next_start_bin) * bin_width
-            src_bin_end = src_bin_begin + bin_width
-
-            # which dst_bins the beginning and end of src_bin belong to?
-            dst_bin_of_begin = torch.clamp(src_bin_begin // dst_bin_width, 0, self.dst_nbins - 1)
-            dst_bin_of_begin_center = (dst_bin_of_begin + 0.5) * dst_bin_width
-
-            dst_bin_of_end = torch.clamp(src_bin_end // dst_bin_width, 0, self.dst_nbins - 1)
-            dst_bin_of_end_center = (dst_bin_of_end + 0.5) * dst_bin_width
-
-            density = self.histogram / bin_width
-
-            norm = torch.zeros(self.bins)
-
-            delta_begin = src_bin_begin - dst_bin_of_begin_center
-            delta_end = dst_bin_width / 2
-            norm += _get_norm(delta_begin, delta_end, density, norm_type)
-
-            norm += (dst_bin_of_end - dst_bin_of_begin - 1) * _get_norm(
-                -dst_bin_width / 2, dst_bin_width / 2, density, norm_type
-            )
-
-            dst_bin_of_end_center = (
-                dst_bin_of_end * dst_bin_width + dst_bin_width / 2
-            )
-
-            delta_begin = -dst_bin_width / 2
-            delta_end = src_bin_end - dst_bin_of_end_center
-            norm += _get_norm(delta_begin, delta_end, density, norm_type)
-
-            return norm.sum()
-
         assert self.histogram.size()[0] == self.bins, "bins mistmatch"
         bin_width = (self.max_val - self.min_val) / self.bins
 
         # cumulative sum
-        total = sum(self.histogram)
+        total = torch.sum(self.histogram).item()
         cSum = torch.cumsum(self.histogram, dim=0)
 
         stepsize = 1e-5  # granularity
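A quick aside on the closed form in _get_norm above: the L2 error of values spread uniformly over [delta_begin, delta_end] is the density times the integral of x^2 over that interval, hence the (end^3 - begin^3) / 3 term. A standalone numeric check (illustrative only, not part of the commit):

    import torch

    def l2_norm_closed_form(delta_begin: torch.Tensor,
                            delta_end: torch.Tensor,
                            density: float) -> torch.Tensor:
        # density * integral_{begin..end} x^2 dx = density * (end^3 - begin^3) / 3
        return density * (delta_end ** 3 - delta_begin ** 3) / 3

    begin, end, density = -0.5, 1.5, 2.0
    xs = torch.linspace(begin, end, 100001)
    numeric = density * torch.trapz(xs ** 2, xs)   # brute-force quadrature
    closed = l2_norm_closed_form(torch.tensor(begin), torch.tensor(end), density)
    assert torch.isclose(numeric, closed, atol=1e-4)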
@@ -825,7 +835,7 @@ def _compute_quantization_error(next_start_bin, next_end_bin, norm_type):
                 continue
 
             # calculate the quantization error using next_start_bin and next_end_bin
-            norm = _compute_quantization_error(next_start_bin, next_end_bin, "L2")
+            norm = self._compute_quantization_error(next_start_bin, next_end_bin)
 
             if norm > norm_min:
                 break
@@ -837,27 +847,28 @@ def _compute_quantization_error(next_start_bin, next_end_bin, norm_type):
         new_max = self.min_val + bin_width * (end_bin + 1)
         return new_min, new_max
 
-    @torch.jit.ignore
-    def _adjust_min_max(self,
-                        combined_min: torch.Tensor,
-                        combined_max: torch.Tensor,
-                        upsample_rate: int) -> Tuple[torch.Tensor, torch.Tensor, int, int]:
+    def _adjust_min_max(
+        self,
+        combined_min: torch.Tensor,
+        combined_max: torch.Tensor,
+        upsample_rate: int
+    ) -> Tuple[torch.Tensor, torch.Tensor, int, int]:
         # We ensure that:
         # (combined_max - combined_min)/(downsample_rate*Nbins) = (max - min)/(upsample_rate*Nbins)
         # This allows us to have a common grid of resolution s, where we can align
         # the input histogram
         # start_idx maps min_val to the histogram bin index.
 
         hist_bin_width = (self.max_val - self.min_val) / (self.bins * upsample_rate)
-        downsample_rate = int(torch.ceil((combined_max - combined_min) / (self.bins * hist_bin_width)).item())
+        downsample_rate = int(torch.ceil(
+            (combined_max - combined_min) / (self.bins * hist_bin_width)).item())
         e = downsample_rate * (self.bins * hist_bin_width) - (combined_max - combined_min)
         # Relax only the max, not the min, so that for one sided distributions, min stays at zero
         combined_max = combined_max + e
         combined_min = combined_min
         start_idx = int(torch.round((self.min_val - combined_min) / hist_bin_width).item())
         return combined_min, combined_max, downsample_rate, start_idx
 
-    @torch.jit.ignore
     def _combine_histograms(self,
                             orig_hist: torch.Tensor,
                             new_hist: torch.Tensor,
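To make the invariant in _adjust_min_max concrete: the fine grid step is (max_val - min_val) / (bins * upsample_rate), and downsample_rate is the smallest integer number of bins-sized spans of that grid covering the combined range; only the max is then padded out to a whole multiple. A toy calculation with hypothetical numbers, not from the commit:

    import torch

    bins, upsample_rate = 2048, 128
    min_val, max_val = torch.tensor(0.0), torch.tensor(1.0)
    combined_min, combined_max = torch.tensor(0.0), torch.tensor(2.5)

    hist_bin_width = (max_val - min_val) / (bins * upsample_rate)
    downsample_rate = int(torch.ceil(
        (combined_max - combined_min) / (bins * hist_bin_width)).item())
    e = downsample_rate * (bins * hist_bin_width) - (combined_max - combined_min)
    print(downsample_rate, e.item())   # 3 0.5: the max is relaxed from 2.5 to 3.0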
@@ -915,7 +926,8 @@ def forward(self, x_orig: torch.Tensor) -> torch.Tensor:
             assert combined_min.numel() == 1 and combined_max.numel() == 1, (
                 "histogram min/max values must be scalar."
             )
-            combined_histogram = torch.histc(x, self.bins, min=int(combined_min), max=int(combined_max))
+            combined_histogram = torch.histc(
+                x, self.bins, min=int(combined_min), max=int(combined_max))
             if combined_min == min_val and combined_max == max_val:
                 combined_histogram += self.histogram
             else:
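The fast path above accumulates counts directly when the combined range equals the stored one, since identical (min, max, bins) means the bin edges line up. A small standalone illustration of that accumulation with torch.histc (hypothetical values, not from the commit):

    import torch

    bins = 8
    histogram = torch.zeros(bins)
    for batch in (torch.rand(3, 4), torch.rand(5)):
        # a shared fixed range keeps bin edges aligned, so counts just add
        histogram += torch.histc(batch, bins, min=0., max=1.)
    print(int(histogram.sum().item()))   # 17, the total number of observed values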
