Speed up HistogramObserver by vectorizing critical path#41041
durumu wants to merge 6 commits into gh/durumu/9/base
Conversation
[ghstack-poisoned]
💊 CI failures summary and remediations: As of commit 11ce38b (more details on the Dr. CI page): 💚 Looks good so far! There are no failures yet. 💚 (This comment was automatically generated by Dr. CI and has been revised 17 times.)
Nice! In general this looks good. Can we just add to the test plan:
Differential Revision: [D22400755](https://our.internmc.facebook.com/intern/diff/D22400755) [ghstack-poisoned]
```python
norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
return norm
```

```python
src_bin = torch.arange(self.bins).numpy()
```
PyTorch doesn't have a NumPy dependency for its functionality (although we do for some tests), and we shouldn't use NumPy functionality in lieu of our own. Uses of NumPy should be restricted to testing and NumPy interop.
Thanks for the feedback -- I changed my code to get rid of the numpy dependency.
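For reference, a minimal sketch of what the torch-native change amounts to; `bins` here is a hypothetical stand-in for `self.bins` from the quoted snippet, and the default of 2048 bins is taken from the PR summary:

```python
import torch

bins = 2048  # HistogramObserver's default bin count; stands in for self.bins

# Before (flagged in review): hops out of torch into NumPy.
# src_bin = torch.arange(bins).numpy()

# After: stay on torch tensors so downstream indexing and arithmetic
# remain vectorized torch ops with no NumPy dependency.
src_bin = torch.arange(bins)
```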
```python
delta_end = src_bin_end - dst_bin_of_end_center
norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
return norm
```
This can be optimized further by the following approximation:

Quantization error ≈ (StepSize^2 / 12) * Q + sum(P[i] * (BinCenter[i] - next_start_bin)^2) + sum(P[i] * (BinCenter[i] - next_end_bin)^2)

Q = sum(hist[next_start_bin:next_end_bin])

where the first sum runs over the bins below next_start_bin and the second sum over the bins above next_end_bin. With this approximation, we only need to compute two indices: where next_start_bin and next_end_bin map to in terms of the original histogram indices.
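A restatement of the approximation above in LaTeX, with assumed symbol names (Δ for StepSize, c_i for BinCenter[i], s and e for the indices that next_start_bin and next_end_bin map to in the original histogram):

```latex
% Quantization-error approximation from the review comment above.
% \Delta = step size, c_i = center of bin i, P_i = hist[i],
% s = next_start_bin index, e = next_end_bin index,
% Q = \sum_{i=s}^{e-1} P_i  (mass inside the clipped range).
E \;\approx\; \frac{\Delta^2}{12}\, Q
  \;+\; \sum_{i < s} P_i \,\bigl(c_i - c_s\bigr)^2
  \;+\; \sum_{i > e} P_i \,\bigl(c_i - c_e\bigr)^2
```

The first term is the standard uniform-quantization error for the mass kept inside the range; the two sums charge the clipped bins the squared distance to the nearest range edge.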
Summary: 22x speedup over the code this replaces. Tested on ResNet18 on a devvm using CPU only, using default parameters for HistogramObserver (i.e. 2048 bins).

Pull Request resolved: pytorch#41041

Test Plan: To run the test against the reference (old) implementation, you can use `python test/test_quantization.py TestRecordHistogramObserver.test_histogram_observer_against_reference`. To run the benchmark, while in the folder `benchmarks/operator_benchmark`, you can use `python -m benchmark_all_quantized_test --operators HistogramObserverCalculateQparams`.

Benchmark results before speedup:

```
# ----------------------------------------
# PyTorch/Caffe2 Operator Micro-benchmarks
# ----------------------------------------
# Tag : short

# Benchmarking PyTorch: HistogramObserverCalculateQparams
# Mode: Eager
# Name: HistogramObserverCalculateQparams_C3_M512_N512_dtypetorch.quint8_cpu_qschemetorch.per_tensor_affine
# Input: C: 3, M: 512, N: 512, dtype: torch.quint8, device: cpu, qscheme: torch.per_tensor_affine
Forward Execution Time (us) : 185818.566

# Benchmarking PyTorch: HistogramObserverCalculateQparams
# Mode: Eager
# Name: HistogramObserverCalculateQparams_C3_M512_N512_dtypetorch.quint8_cpu_qschemetorch.per_tensor_symmetric
# Input: C: 3, M: 512, N: 512, dtype: torch.quint8, device: cpu, qscheme: torch.per_tensor_symmetric
Forward Execution Time (us) : 165325.916
```

Benchmark results after speedup:

```
# ----------------------------------------
# PyTorch/Caffe2 Operator Micro-benchmarks
# ----------------------------------------
# Tag : short

# Benchmarking PyTorch: HistogramObserverCalculateQparams
# Mode: Eager
# Name: HistogramObserverCalculateQparams_C3_M512_N512_dtypetorch.quint8_cpu_qschemetorch.per_tensor_affine
# Input: C: 3, M: 512, N: 512, dtype: torch.quint8, device: cpu, qscheme: torch.per_tensor_affine
Forward Execution Time (us) : 12242.241

# Benchmarking PyTorch: HistogramObserverCalculateQparams
# Mode: Eager
# Name: HistogramObserverCalculateQparams_C3_M512_N512_dtypetorch.quint8_cpu_qschemetorch.per_tensor_symmetric
# Input: C: 3, M: 512, N: 512, dtype: torch.quint8, device: cpu, qscheme: torch.per_tensor_symmetric
Forward Execution Time (us) : 12655.354
```

Reviewed By: raghuramank100

Differential Revision: D22400755

Pulled By: durumu

fbshipit-source-id: 639ac796a554710a33c8a930c1feae95a1148718
Roughly a 22x speedup over the code this replaces when tested on ResNet18 on a devvm using CPU only, using default parameters for HistogramObserver (i.e. 2048 bins). The script I ran to test this is here.
Roughly a 14x speedup when tested using the benchmark from #42138 (also CPU only).
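The flavor of the vectorization can be sketched as below. This is an illustrative toy, not the PR's actual code: the function names are hypothetical, and the integral of x²·density over [delta_begin, delta_end] follows the `_get_norm` L2 case quoted in the review snippets above.

```python
import torch

def norm_loop(density, delta_begin, delta_end):
    # Reference scalar path: one Python iteration per histogram bin,
    # accumulating density * (end^3 - begin^3) / 3 for each bin.
    total = 0.0
    for d, b, e in zip(density.tolist(), delta_begin.tolist(), delta_end.tolist()):
        total += d * (e ** 3 - b ** 3) / 3
    return total

def norm_vectorized(density, delta_begin, delta_end):
    # Same computation as tensor ops: the whole per-bin loop collapses
    # into a few fused passes over the tensors, which is the kind of
    # rewrite the PR applies to HistogramObserver's critical path.
    return torch.sum(density * (delta_end ** 3 - delta_begin ** 3) / 3).item()

bins = 2048  # HistogramObserver default
density = torch.rand(bins)
delta_begin = torch.rand(bins)
delta_end = delta_begin + torch.rand(bins)
```

The two functions agree up to float32 rounding; the speedup comes entirely from removing per-bin Python interpreter overhead, not from changing the math.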
Stack from ghstack:
Differential Revision: D22400755