fix: Properly cast intermediate Int8 tensors to TensorRT Engines in Fallback #1549
Merged
peri044 merged 2 commits into pytorch:master on Dec 22, 2022
Conversation
peri044 reviewed on Dec 19, 2022
- Fix compilation error for GPT-2 model arising from Byte-type inputs fed into TensorRT Engine
- Update translation dictionary between Torch and TensorRT types to include `at::kByte`
- Add field to PartitioningInfo specifying whether to cast Int8 inputs to TensorRT Engines to Int, to avoid errors arising from Int8 inputs being fed into non-quantized engines
- Add automatic detection of quantized/calibrated models and disable Int8 => Int32 casting in those cases
- Fix bug where LoweringInfo target device was not being updated for the Python API
- Allow `castNode` to force creation of a new node and avoid searching for an existing one to convert
- Add test to ensure a cast is inserted in the Torch engine preceding a TensorRT engine when the Byte tensor is an output of the Torch engine
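The casting policy described in these bullets can be sketched in plain Python. All names below (`PartitioningConfig`, `maybe_cast_dtype`) are hypothetical illustrations, not the actual Torch-TensorRT API; the real implementation is C++ inside `core/partitioning`:

```python
# Illustrative sketch of the dtype-casting policy described above.
# PartitioningConfig and maybe_cast_dtype are hypothetical names.
from dataclasses import dataclass


@dataclass
class PartitioningConfig:
    truncate_long_and_double: bool = True
    # Per the PR, Int8 => Int32 casting is disabled automatically
    # when a quantized/calibrated model is detected.
    cast_int8_inputs: bool = True


def maybe_cast_dtype(dtype: str, cfg: PartitioningConfig) -> str:
    """Return the dtype a TensorRT engine input should be cast to."""
    if dtype == "int64" and cfg.truncate_long_and_double:
        return "int32"  # Long tensors are truncated for TRT
    if dtype in ("int8", "uint8") and cfg.cast_int8_inputs:
        return "int32"  # Byte/Int8 inputs error on non-quantized engines
    return dtype  # everything else passes through unchanged
```

For a quantized model, the automatic detection would flip `cast_int8_inputs` off so calibrated Int8 data reaches the engine untouched, e.g. `PartitioningConfig(cast_int8_inputs=False)`.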
ee69829 to a4c2d60
peri044 reviewed on Dec 22, 2022
core/partitioning/shape_analysis.cpp (outdated), comment on lines 233 to 245:
if (partitioning_info.truncate_long_and_double) {
  for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
    if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
      auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
      at::ScalarType t = cur_ivalue.toTensor().scalar_type();
      if (t == at::kLong) {
        // we add a cast operation to cast the type to Int64
        auto cast_node = createCastNode(seg_block, i, true, target_device);
        seg_block.g()->prependNode(cast_node);
        seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
      }
    }
  }
}
Collaborator
Are these just linter formatting changes?
Contributor (Author)
I manually made the formatting changes to reduce the redundancy of the if statements, but they should be functionally equivalent to the previous version.
peri044 requested changes on Dec 22, 2022
- Address review comments
- Improve documentation and logging messages
- Restructure casting function to allow casting of variable data types
- Add casting for `at::kByte` segment block inputs as well as segment block outputs
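The restructured casting function described above, generalized over data types and applied to both segment-block inputs and outputs, might look like the following sketch. This is a hypothetical Python illustration (`insert_casts` and its dtype strings are invented names); the actual change is C++ code in `core/partitioning/shape_analysis.cpp`:

```python
# Hypothetical sketch of a casting helper parameterized by a dtype map,
# mirroring the restructuring described above. Not the real API.
def insert_casts(dtypes, cast_map):
    """Map each dtype through cast_map; return (new_dtypes, cast_positions).

    cast_positions records the indices where a cast node would be
    inserted into the segment block's graph.
    """
    new_dtypes = []
    cast_positions = []
    for i, dtype in enumerate(dtypes):
        if dtype in cast_map:
            new_dtypes.append(cast_map[dtype])
            cast_positions.append(i)
        else:
            new_dtypes.append(dtype)
    return new_dtypes, cast_positions


# One map covers both the kLong and the new kByte handling, and the same
# helper can be applied to a block's inputs and to its outputs.
cast_map = {"int64": "int32", "uint8": "int32"}
block_inputs, _ = insert_casts(["uint8", "float32"], cast_map)
```

The design point is that a single dtype-to-dtype table replaces per-type if branches, so adding `at::kByte` support becomes one more map entry rather than another copy of the casting logic.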
Description
- Allow `castNode` to force creation of a new node (including for `at::kByte`) and avoid searching for an existing one to convert
- Error displayed when passing `Int8` inputs to a non-quantized TRT Engine
- With this PR, GPT-2 now compiles and runs inference successfully
Fixes #1455