
[pt1][quant] Add QInt32 ScalarType and qint32 data type #19816

Closed
jerryzh168 wants to merge 24 commits into master from export-D15094174

Conversation

@jerryzh168 (Contributor) commented Apr 27, 2019:

Stack:
    ○ #20107 [pt1][quant] Add dequantize_linear for JIT pass
    ○ #19984 [pt1][quant] Add qint8 type (int8_t)
    ○ #19932 [pt1][quant] Rename qint8 data type
    ● #19816 [pt1][quant] Add QInt32 ScalarType and qint32 data type

We need this for quantization of bias. This adds a third `ScalarType` argument to `quantize_linear`.

Differential Revision: D15094174
@pytorchbot added the labels module: internals (Related to internal abstractions in c10 and ATen), module: operators, and oncall: quantization (Quantization support in PyTorch) on Apr 27, 2019
@jerryzh168 requested review from gchanan and li-roy, April 27, 2019 00:23
@jerryzh168 (Contributor Author):

@gchanan #6593 says it does not support tracing the dtype version of the function, does this still apply?

@jerryzh168 (Contributor Author):

@dzhulgakov looks like we need to define the template function at every call site?

May 03 19:22:31 CMakeFiles/quantized_test.dir/__/aten/src/ATen/test/quantized_test.cpp.o: In function `TestQTensor_QuantDequantAPIs_Test::TestBody()':
May 03 19:22:31 quantized_test.cpp:(.text+0x1f56): undefined reference to `c10::qint8 at::quantize_val<c10::qint8>(float, int, float)'

@jerryzh168 requested review from dzhulgakov and z-a-f, May 4, 2019 00:38
jerryzh168 added 5 commits May 6, 2019 11:08
@jerryzh168 (Contributor Author):

Looks like the CI is passing. @dzhulgakov @gchanan @raghuramank100, please review again.

for (auto i = 0; i < qtensor.numel(); ++i) {
  // We need to convert the qint8 value to float to ensure the subtraction
  // subexpression returns a float
  rd[i] = (static_cast<float>(qd[i].val_) - zero_point) * scale;
}
Contributor:
why force either of these to be contiguous?

Contributor Author:

Right now I'm just assuming everything is contiguous; we can change it later if there is a need.

Contributor Author:

Not assuming; we have a `.contiguous()` call earlier for the input tensor.

Contributor:

OK, but we already have a bunch of code for copying and for doing type conversions. Did you check what that code does?

Contributor Author:

Do you mean TensorIterators?

Collaborator:

Since the fbgemm implementation is contiguous-only (and we expect that one to be enabled by default), I think it's OK to force contiguous.

Contributor:

It seems fine to make the output contiguous (although you still need to check in the ops themselves, because unlike say, MKLDNN, there is no guarantee you didn't do a view after getting the contiguous tensor).

But this forces the input to be contiguous too.

Contributor:

And I still would like to understand how the existing copy/type conversion mechanisms fit into this.

For example, we almost certainly want memory_order support for quantized tensors, right? If it's possible to use the same mechanism, we should get it "for free", but in this way we certainly won't.

Contributor Author:

I think this should be a separate discussion. Also, I don't have enough context yet on how we should handle copy/type conversion for qtensor. I do plan to implement permute and view on QTensor, though.

@jerryzh168 requested a review from gchanan, May 9, 2019 01:11
jerryzh168 added 2 commits May 9, 2019 15:04
@jerryzh168 requested a review from gchanan, May 9, 2019 22:30
@dzhulgakov (Collaborator) left a comment:

Looks good modulo a few comments


struct CAFFE2_API Quantizer : public c10::intrusive_ptr_target {
  const QScheme qscheme_;
  explicit Quantizer(QScheme qscheme) : qscheme_(qscheme) {}
  const ScalarType scalar_type_;
Collaborator:

why do you need to keep scalar type in quantizer if it's already present in the tensor?

Contributor Author:

The quantize function in Quantizer only takes a float Tensor as its argument, so we need this info in the Quantizer as well; the scalar type lives on the output Tensor, not the input Tensor.

@pytorchbot added the module: nn (Related to torch.nn) label on May 14, 2019
zdevito pushed a commit to zdevito/ATen that referenced this pull request May 16, 2019
Summary:
Pull Request resolved: pytorch/pytorch#19816

We need this for quantization of bias. Add a third `ScalarType` argument to `quantize_linear`.

Differential Revision: D15094174

fbshipit-source-id: f19ec8f4716cf5fe0aa21b38d45af6d27c9ab377
@facebook-github-bot (Contributor):
This pull request has been merged in abb3698.

facebook-github-bot pushed a commit that referenced this pull request May 23, 2019
Summary:
Close #20642

Possibly broken by #19816
Pull Request resolved: #20853

Differential Revision: D15474620

Pulled By: jerryzh168

fbshipit-source-id: 99b52d92a93bac7cab52537f1ebdbd286d4b2cfe
@ezyang deleted the export-D15094174 branch, May 30, 2019 16:02

Labels: module: internals · module: nn · oncall: quantization


5 participants