
Lower quant/dequant torch op to StableHLO#5763

Merged
lsy323 merged 23 commits into master from lsiyuan/quant-dequant-dispatch
Nov 28, 2023
Conversation

@lsy323
Collaborator

@lsy323 lsy323 commented Nov 2, 2023

The following torch ops can be lowered to StableHLO with this diff:

  • quantize_per_tensor
  • quantize_per_channel
  • dequantize_per_tensor
  • dequantize_per_channel
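For reference, the affine quantization math behind the per-tensor ops above can be sketched in plain Python. The function names and signatures here are illustrative only, not the actual torch op signatures:

```python
# Illustrative sketch of affine per-tensor (de)quantization.
# Names are hypothetical; the real torch ops operate on tensors.

def quantize_per_tensor(x, scale, zero_point, qmin, qmax):
    """q = clamp(round(x / scale) + zero_point, qmin, qmax)."""
    return [min(max(round(v / scale) + zero_point, qmin), qmax) for v in x]

def dequantize_per_tensor(q, scale, zero_point):
    """x ~= (q - zero_point) * scale (lossy: rounding/clamping discard info)."""
    return [(v - zero_point) * scale for v in q]

x = [0.0, 0.5, 1.0, -0.5]
q = quantize_per_tensor(x, scale=0.01, zero_point=0, qmin=-128, qmax=127)
print(q)  # [0, 50, 100, -50]
```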

User Experience

  • The GraphModule generated by PT2E quantization can be exported to StableHLO, or to tf.saved_model, using the existing export APIs with no additional changes to model code or export scripts. STABLEHLO_BYTECODE_FROM_PRETTYPRINT needs to be set to 1 to work around a StableHLO bytecode serialization issue.
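The environment variable mentioned above would be set before running the export, e.g.:

```shell
# Work around the StableHLO bytecode serialization issue by serializing
# from the pretty-printed form instead.
export STABLEHLO_BYTECODE_FROM_PRETTYPRINT=1
```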

Current workflow

  1. Register the xla qdq ops under the 'XLA' dispatch key, so the qdq ops are dispatched to the xla implementation during LTC tracing.
  2. During lowering, qdq ops are lowered to a custom call to stablehlo.uniform_quantize/dequantize in HLO. The qparams are stored in the custom call's config string, which can be deserialized directly into an MLIR DictAttr.
  3. The HLO->StableHLO converter converts the custom call to stablehlo.uniform_quantize/dequantize.
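As an illustration of step 2, a qparam config string shaped like an MLIR DictAttr might be assembled as follows. This is a hypothetical sketch; the exact key names and format the PR serializes may differ:

```python
# Hypothetical sketch of packing qparams into a DictAttr-style config
# string for the custom call (the PR's exact format may differ).

def qparams_to_config_str(scales, zero_points, qmin, qmax, storage_type):
    scale_str = ", ".join(str(s) for s in scales)
    zp_str = ", ".join(str(z) for z in zero_points)
    return (f"{{scale = [{scale_str}], zero_point = [{zp_str}], "
            f"quant_min = {qmin}, quant_max = {qmax}, "
            f"storage_type = {storage_type}}}")

cfg = qparams_to_config_str([0.02], [0], -128, 127, "si8")
print(cfg)
```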

Changes

  1. Allow save_torch_module_as_tf_saved_model to also take a GraphModule, since PT2E outputs a GraphModule.
  2. Added 2 patches. One adds support for stablehlo.uniform_quantize/dequantize conversion to the HLO->StableHLO converter, originally authored by @sdasgup3. The other works around the StableHLO bytecode serialization issue mentioned above. Neither will be needed once a quantized dtype representation is added to HLO.
  3. Added new xla quantize_tensor/dequantize_tensor ops for qdq op lowering. The xla quantize/dequantize ops lower to custom calls to stablehlo.uniform_quantize/dequantize.
  4. Added test scripts covering export of per-tensor/per-channel qdq ops and a PT2E-quantized resnet18 model.
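The per-channel variant exercised by the tests applies a separate scale and zero_point along one axis. A minimal pure-Python sketch (names are illustrative, not the torch_xla op signatures):

```python
# Minimal sketch of per-channel quantization: each row (channel along
# axis 0) gets its own scale and zero_point. Illustrative only.

def quantize_per_channel(rows, scales, zero_points, qmin, qmax):
    out = []
    for row, s, zp in zip(rows, scales, zero_points):
        # Same affine formula as per-tensor, but with per-row qparams.
        out.append([min(max(round(v / s) + zp, qmin), qmax) for v in row])
    return out

w = [[0.1, -0.2], [1.0, 2.0]]
q = quantize_per_channel(w, scales=[0.1, 1.0], zero_points=[0, 0],
                         qmin=-128, qmax=127)
print(q)  # [[1, -2], [1, 2]]
```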

Future Work

  1. When a quantized dtype is added to HLO, the lowering logic will need to be updated, and it will be more concise than the current approach.

cc @sdasgup3 @GleasonK @paulinesho

@lsy323 lsy323 requested review from JackCaoG, miladm and qihqi November 2, 2023 16:59
Comment thread WORKSPACE
@lsy323 lsy323 force-pushed the lsiyuan/quant-dequant-dispatch branch from e70be80 to c329b63 Compare November 18, 2023 00:39
@lsy323 lsy323 requested a review from JackCaoG November 27, 2023 18:19
Comment thread torch_xla/tf_saved_model_integration.py Outdated
Comment thread torch_xla/csrc/runtime/stablehlo_helper.cc
# Step 1: export resnet18
args = (torch.randn(1, 3, 224, 224),)
m = torchvision.models.resnet18().eval()
m = capture_pre_autograd_graph(m, args)
Collaborator

is there a reason we use this instead of torch.export?

Collaborator

OK, I saw the export below, but I'm still confused about what this function does to the module.

Collaborator Author

@lsy323 lsy323 Nov 28, 2023

Here the graph is captured for PT2E to process further. PT2E doesn't work with a graph captured by torch.export (just tried locally); it needs the graph captured this way.

The export down below is for the PyTorch -> StableHLO export; our API only works on an ExportedProgram.

Comment thread test/stablehlo/test_pt2e_qdq.py Outdated
Comment thread torch_xla/csrc/init_python_bindings.cpp
Comment thread torch_xla/csrc/ops/dequant_tensor.h Outdated
Comment thread torch_xla/csrc/quant_util.h
Comment thread test/stablehlo/test_pt2e_qdq.py Outdated
Comment thread test/stablehlo/test_pt2e_qdq.py Outdated
@lsy323
Collaborator Author

lsy323 commented Nov 28, 2023

Update:

  • Addressed review comments
  • Enhanced the test script to check the qparams and the number of qdq StableHLO ops
  • Added more assertions to the torch_xla qdq ops: scale and zero_point shapes, zero_point dtype matching the integer dtype of the quantized type, and all scale values being positive
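The added assertions could be sketched as a standalone qparam validation. This is a hypothetical Python helper for illustration; the actual checks in the PR live in the torch_xla C++ ops, and the dtype check there compares against the quantized type's integer storage dtype:

```python
# Hypothetical sketch of the qparam sanity checks described above:
# matching scale/zero_point shapes, integral zero_point, positive scales.

def check_qparams(scale, zero_point):
    if len(scale) != len(zero_point):
        raise ValueError("scale and zero_point must have the same shape")
    if not all(isinstance(z, int) for z in zero_point):
        raise ValueError("zero_point must be an integer dtype")
    if not all(s > 0 for s in scale):
        raise ValueError("scale values must be positive")
    return True

print(check_qparams([0.1, 0.2], [0, 128]))  # True
```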

@lsy323 lsy323 requested a review from JackCaoG November 28, 2023 07:50
@lsy323 lsy323 added the stablehlo StableHLO related work label Nov 28, 2023
@lsy323 lsy323 merged commit a3b0c6e into master Nov 28, 2023
Comment thread test/stablehlo/test_pt2e_qdq.py
ManfeiBai pushed a commit to ManfeiBai/PyTorchXLA that referenced this pull request Dec 1, 2023
(de)quantize_per_tensor/channel ops from PT2E quantization workflow are lowered to stablehlo uniform_dequantize/quantize.
---------

Co-authored-by: Siyuan Liu <lsiyuan@google.coim>
chunnienc pushed a commit to chunnienc/xla that referenced this pull request Dec 14, 2023
(de)quantize_per_tensor/channel ops from PT2E quantization workflow are lowered to stablehlo uniform_dequantize/quantize.
---------

Co-authored-by: Siyuan Liu <lsiyuan@google.coim>
golechwierowicz pushed a commit that referenced this pull request Jan 12, 2024
(de)quantize_per_tensor/channel ops from PT2E quantization workflow are lowered to stablehlo uniform_dequantize/quantize.
---------

Co-authored-by: Siyuan Liu <lsiyuan@google.coim>
@lsy323 lsy323 deleted the lsiyuan/quant-dequant-dispatch branch March 4, 2024 19:12
bhavya01 pushed a commit that referenced this pull request Apr 22, 2024
(de)quantize_per_tensor/channel ops from PT2E quantization workflow are lowered to stablehlo uniform_dequantize/quantize.
---------

Co-authored-by: Siyuan Liu <lsiyuan@google.coim>
Labels

stablehlo StableHLO related work

4 participants