fix: Allow full model compilation with collection outputs #1599
gs-olive merged 3 commits into pytorch:main
Conversation
force-pushed from a1dd53c to 655ce22
force-pushed from ef828ad to 08bb906
narendasan left a comment:
Looks mostly fine, just some UX and dev stuff.
force-pushed from 08bb906 to 17753fc
Thanks for the comments and review @narendasan - I have incorporated the feedback and updated two of the user warnings to compilation-halting errors. One note I wanted to make: despite the collection-op schemas allowed in core/partitioning/partitioning.cpp (lines 31 to 48 at 86a998e), any intermediate packing/unpacking is handled by the evaluators and does not cause a graph segmentation, since those nodes are not directly graph outputs.
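A minimal sketch of that distinction, assuming the standard torch::jit graph API; the function name and structure are illustrative, not code from this PR:

```cpp
#include <torch/csrc/jit/ir/ir.h>

// Illustrative only: a collection node should force segmentation only when one
// of its outputs is a direct output of the graph. Intermediate pack/unpack
// nodes (not feeding graph outputs) are left to the evaluators.
bool produces_graph_output(const torch::jit::Node* node, const torch::jit::Graph& graph) {
  for (const torch::jit::Value* out : node->outputs()) {
    for (const torch::jit::Value* graph_out : graph.outputs()) {
      if (out == graph_out) {
        return true; // packing/unpacking feeds a graph output directly
      }
    }
  }
  return false; // intermediate packing/unpacking: handled by evaluators
}
```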
f43ef07 to
2b35c09
Compare
```cpp
// executed in TRT, regardless of the size of the graph
if (expect_full_compilation) {
  // If minimum block size is different from the default, the user must have specified it
  if (ctx->settings.min_block_size != 3) {
```
Create an issue to centralize defaults somewhere in the core
What if a user sets min_block_size = 3 as well? Do they still get the warning message?
No, the user would not get a warning message in that case. We currently have no way of knowing whether the user provided a value, since the defaults are not centralized. There is an issue (#1644) to address this, but as of now, your statement is correct. Additionally, it is worth noting that prior to this PR, if a user specified min_block_size with require_full_compilation=True, we would still ignore the min_block_size, but without any warning.
force-pushed from 2b35c09 to 1209225
narendasan left a comment:
LGTM, @bowang007 can you take a look?
core/compiler.cpp (outdated):
```cpp
(!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
   cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
 outputIsCollection || user_requested_long)) ||
requires_collection_handling) {
```
Could this if statement be optimized? Seems like isBlockConvertible and outputIsCollection overlap with requires_collection_handling.
I've updated this statement to make the conditions clearer by removing the ! and distributing it over the conditionals inside. Other than this, the statement cannot be reduced any further, since the requires_collection_handling boolean is independent of cfg.partitioning_info.enabled (we want to partition in this case regardless of require_full_compilation=True).
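For reference, the rewrite is a straightforward application of De Morgan's laws; the snippet below is a generic, self-contained illustration of the transformation, not the PR diff itself:

```cpp
#include <cassert>

// Original shape: negation over a conjunction.
bool original(int a, int b, bool convertible) {
  return !(a == 0 && b == 0 && convertible);
}

// After distributing the negation over the conditionals inside.
bool distributed(int a, int b, bool convertible) {
  return a != 0 || b != 0 || !convertible;
}

int main() {
  // The two forms are equivalent for all inputs.
  for (int a : {0, 1}) {
    for (int b : {0, 1}) {
      for (bool c : {false, true}) {
        assert(original(a, b, c) == distributed(a, b, c));
      }
    }
  }
}
```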
```cpp
// If full compilation is expected, cannot have more than 2 Torch segments
// (one for preprocessing inputs, one for post-processing outputs) and 1 TRT segment
if (expect_full_compilation && !(num_torch_segments <= 2 && num_trt_segments == 1)) {
```
Do we have edge cases like two Torch segments for inputs or outputs? Does merge_adjacent_segments_of_same_type always merge them into one?
There should not be a case where multiple Torch segments appear for inputs/outputs, since merge_adjacent_segments_of_same_type addresses this case, as you mentioned. Since the tensors in question are inputs, it should not arise that segment.do_not_merge() is true, because the only approved operators falling into these segments are for collection construction, and only the prim::If or prim::Loop operators can induce a non-merge situation.
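A minimal sketch of the merging invariant described here, with assumed names (Segment, Target); this illustrates the behavior, not the actual partitioning code:

```cpp
#include <vector>

enum class Target { kTorch, kTensorRT };

struct Segment {
  Target target;
  bool do_not_merge = false; // e.g., set for segments containing prim::If / prim::Loop
};

// Adjacent segments with the same target collapse into one unless either side
// is marked do-not-merge, so at most one Torch segment can precede or follow
// the TRT segment in the fully-compiled-with-collections case.
std::vector<Segment> merge_adjacent_same_type(const std::vector<Segment>& in) {
  std::vector<Segment> out;
  for (const Segment& seg : in) {
    if (!out.empty() && out.back().target == seg.target &&
        !out.back().do_not_merge && !seg.do_not_merge) {
      continue; // absorbed into the previous segment of the same type
    }
    out.push_back(seg);
  }
  return out;
}
```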
force-pushed from 1209225 to 91306c9
core/compiler.cpp (outdated):
```cpp
if ((cfg.partitioning_info.enabled &&
     (cfg.lower_info.forced_fallback_modules.size() != 0 ||
      cfg.partitioning_info.forced_fallback_operators.size() != 0 || !isBlockConvertible || outputIsCollection ||
      user_requested_long)) ||
    requires_collection_handling) {
```
Note the updates to the conditional logic to make it clearer.
- Update graph-building in the compiler to account for the case where all operations are supported by Torch-TRT, but the output is a collection.
- Enable 'pseudo-partitioning' for nearly-fully-compiled models where the only unsupported aspect of the model is the format of the output (TRT cannot output complex collections).
- Define a small subset of operation schemas which are allowed despite the flag `require_full_compilation`. These operations are packing and unpacking of Tuples/Lists, and some are already used in cases of `require_full_compilation`.
- Display warnings to users if any portion of the pseudo-partitioning is unexpected, for example if the partitioned model ends up in more than 3 segments (maximally: a Torch segment to preprocess collection inputs, a TRT segment to perform model logic, a Torch segment to post-process collection outputs) or if schemas falling outside of the collection subset are encountered in a Torch segment.
- Add end-to-end test case with a minimal reproducing example of a failing model, repaired with the changes to the compiler.
- Add minor fix to lowering to remediate a C++ compiler warning.
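A hedged sketch of what such a collection-schema allowlist can look like; the real list lives in core/partitioning/partitioning.cpp (the lines 31 to 48 referenced above), and the entries below are a plausible subset rather than a verbatim copy:

```cpp
#include <set>
#include <string>

// Plausible subset (assumption): TorchScript tuple/list pack and unpack ops
// that are permitted even when require_full_compilation=True.
const std::set<std::string> kCollectionSchemas = {
    "prim::TupleConstruct",
    "prim::TupleUnpack",
    "prim::ListConstruct",
    "prim::ListUnpack",
};

bool is_allowed_collection_op(const std::string& node_kind) {
  return kCollectionSchemas.count(node_kind) != 0;
}
```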
- Add function to check the equivalence of two collection-based outputs for comparison across Torch-TRT and Torch outputs.
- Improve test robustness in end-to-end tests to check for equivalent output schemas in addition to successful compilation.
- Add test case to elicit behavior where full compilation is requested but TRT engine size falls below the default `min_block_size=3`.
- Move `min_block_size` condition to a narrower scope.
- Coalesce logic to improve code readability.
force-pushed from 91306c9 to 00f1a3a
```cpp
// Partitioning is required if:
// 1. User requested some modules/operators fallback
// 2. The block (graph) cannot be converted due to operator coverage
// 3. The output of the graph is a collection
// 4. The user requested a non-TRT data type input
auto isPartitioningRequired =
    (isFallbackRequested || !isBlockConvertible || outputIsCollection || user_requested_long);
```
Coalesced the partitioning logic for readability.
```cpp
bool userRequestedFallback(CompileSpec& cfg) {
  return cfg.lower_info.forced_fallback_modules.size() != 0 ||
      cfg.partitioning_info.forced_fallback_operators.size() != 0;
}
```
Added a helper function to determine whether the user's input specifications imply fallback.
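Putting the two quoted snippets together, the partitioning decision plausibly composes as below; this is a sketch assuming the surrounding variables (isBlockConvertible, outputIsCollection, user_requested_long, requires_collection_handling) from the diffs above, not the verbatim compiler code:

```cpp
// Sketch: how the helper and the coalesced predicate combine.
bool isFallbackRequested = userRequestedFallback(cfg);
bool isPartitioningRequired =
    (isFallbackRequested || !isBlockConvertible || outputIsCollection || user_requested_long);

// Partition when partitioning is enabled and required, or when the output is a
// collection TRT cannot emit directly (even with require_full_compilation=True).
if ((cfg.partitioning_info.enabled && isPartitioningRequired) || requires_collection_handling) {
  // ... run partitioning / pseudo-partitioning and stitch engines ...
}
```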
Description
- Update graph-building in the compiler to account for the case where all operations are supported by Torch-TRT, but the output is a collection.
- Enable 'pseudo-partitioning' for nearly-fully-compiled models where the only unsupported aspect of the model is the format of the output (TRT cannot output complex collections).
- Define a small subset of operation schemas which are allowed despite the flag require_full_compilation. These operations are packing and unpacking of Tuples/Lists, and some are already used in cases of require_full_compilation.
- Display warnings to users if any portion of the pseudo-partitioning is unexpected, for example if the partitioned model ends up in more than 3 segments (maximally: a Torch segment to preprocess collection inputs, a TRT segment to perform model logic, a Torch segment to post-process collection outputs) or if schemas falling outside of the collection subset are encountered in a Torch segment.

This fix was designed to minimally alter the existing phases of model conversion. It does not manually flatten/reconstruct complex collection outputs, but instead uses the existing partitioning infrastructure and engine-stitching paradigm to accomplish this.
Fixes #1598
Fixes #1368