Lower quant/dequant torch op to StableHLO#5763
Merged
Conversation
JackCaoG
reviewed
Nov 2, 2023
lowering to HLO custom call.
refactor add quant util rename test script
clean up quant op
force-pushed from e70be80 to c329b63
qihqi
approved these changes
Nov 27, 2023
JackCaoG
reviewed
Nov 27, 2023
```python
# Step 1: export resnet18
args = (torch.randn(1, 3, 224, 224),)
m = torchvision.models.resnet18().eval()
m = capture_pre_autograd_graph(m, args)
```
Collaborator
Is there a reason we use this instead of torch.export?
Collaborator
OK, I saw the export below, but I'm still confused about what this function does to the module.
Collaborator
Author
Here the graph is captured for PT2E to process further. PT2E doesn't work with a graph captured from torch.export (just tried locally); it needs the graph to be captured this way.
The export further down is for the PyTorch -> StableHLO export; our API only works on an exported program.
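For context, the (de)quantize ops that the PT2E workflow inserts implement affine quantization. Below is a minimal pure-Python sketch of the per-tensor numerics; the function names and signatures are illustrative, not the actual torch ops:

```python
def quantize_per_tensor(x, scale, zero_point, qmin=-128, qmax=127):
    """Affine per-tensor quantization: q = clamp(round(x / scale) + zero_point)."""
    q = round(x / scale) + zero_point
    return max(qmin, min(qmax, q))

def dequantize_per_tensor(q, scale, zero_point):
    """Inverse mapping back to float: x ~= (q - zero_point) * scale."""
    return (q - zero_point) * scale

scale, zp = 0.05, 3
q = quantize_per_tensor(1.0, scale, zp)    # round(1.0 / 0.05) + 3 = 23
x = dequantize_per_tensor(q, scale, zp)    # (23 - 3) * 0.05 = 1.0
```

The StableHLO `uniform_quantize`/`uniform_dequantize` ops that this PR lowers to carry the same scale and zero-point parameters.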
Collaborator
Author
Update:
JackCaoG
approved these changes
Nov 28, 2023
miladm
reviewed
Dec 1, 2023
ManfeiBai
pushed a commit
to ManfeiBai/PyTorchXLA
that referenced
this pull request
Dec 1, 2023
(de)quantize_per_tensor/channel ops from PT2E quantization workflow are lowered to stablehlo uniform_dequantize/quantize. Co-authored-by: Siyuan Liu <lsiyuan@google.coim>
chunnienc
pushed a commit
to chunnienc/xla
that referenced
this pull request
Dec 14, 2023
golechwierowicz
pushed a commit
that referenced
this pull request
Jan 12, 2024
bhavya01
pushed a commit
that referenced
this pull request
Apr 22, 2024
The following torch ops can be lowered to StableHLO with this diff:
User Experience
`STABLEHLO_BYTECODE_FROM_PRETTYPRINT` needs to be set to 1 to work around a StableHLO bytecode serialization issue.
Current workflow
- Quant/dequant ops are lowered to HLO custom calls carrying `stablehlo.uniform_quantize/dequantize` in HLO. The qparams are stored in the custom call config string; the config string can be deserialized to an mlir DictAttr directly.
- The custom calls are converted to `stablehlo.uniform_quantize/dequantize`.
Changes
- Update `save_torch_module_as_tf_saved_model` to take a GraphModule as well, since PT2E outputs a GraphModule.
- Two additions: one is the custom call -> `stablehlo.uniform_quantize/dequantize` conversion, originally authored by @sdasgup3; another is to work around the StableHLO bytecode serialization issue mentioned above. Both won't be needed if an HLO qdtype representation is added.
- `stablehlo.uniform_quantize/dequantize`
Future Work
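As a rough illustration of the custom-call approach described above, the qparams can ride along on a custom call as a serialized config string and be parsed back on the other side. This is an illustrative stand-in, not the actual PyTorch/XLA implementation: a plain Python dict substitutes for the MLIR DictAttr, and the field layout is assumed:

```python
# Hypothetical sketch: encode qparams into a custom-call config string
# and parse them back out. The real code deserializes the string into
# an MLIR DictAttr; a plain Python dict stands in for it here.
def encode_qparams(scale, zero_point, dtype):
    # DictAttr-like textual form, e.g. {scale = 0.05 : f32, ...}
    return f'{{scale = {scale} : f32, zero_point = {zero_point} : i64, dtype = "{dtype}"}}'

def decode_qparams(config):
    out = {}
    for field in config.strip("{}").split(","):
        key, value = field.split("=", 1)
        # Drop the trailing type annotation (": f32" / ": i64") and quotes.
        out[key.strip()] = value.split(":")[0].strip().strip('"')
    return out

cfg = encode_qparams(0.05, 3, "qint8")
params = decode_qparams(cfg)   # {'scale': '0.05', 'zero_point': '3', 'dtype': 'qint8'}
```

Carrying the qparams as a self-describing attribute string is what lets a later MLIR pass rebuild proper `stablehlo.uniform_quantize/dequantize` ops without any side-channel state.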
cc @sdasgup3 @GleasonK @paulinesho