Add fx passes to support unbounded dynamism in torch op arg #6653
Merged
Conversation
qihqi
approved these changes
Mar 20, 2024
Add FX passes to support dynamism on torch ops that take symbolic dim sizes as parameters (e.g. `aten.view(x, [sym_size, -1])`). The FX passes group the generation of the symbolic dim size and the torch op that consumes it into a new XLA op. Grouping them together is necessary for lowering these ops with dynamism, because operations on symbolic sizes cannot be traced in LTC (context in #6393). Once the source of the symbolic size and the consuming torch op are captured in a single XLA op, it becomes feasible to lower them to HLO/StableHLO with dynamism semantics.
The FX passes will run automatically if the exported program has symbolic shape input.
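As a rough illustration of that conditional dispatch (the helper names below are hypothetical, not the actual torch_xla API), a pass runner would only fire when some input dim is symbolic. Here a symbolic dim is modeled as a string like `"s0"`, while concrete dims are plain ints:

```python
# Sketch: run the dynamism FX passes only when the exported program has a
# symbolic (unbounded) dim in any input shape. Illustrative names only.

def has_symbolic_input(input_shapes):
    """True if any input shape contains a symbolic dim (modeled as a str)."""
    return any(isinstance(d, str) for shape in input_shapes for d in shape)

def maybe_run_dynamism_passes(input_shapes, passes):
    """Apply each pass only when symbolic dims are present."""
    if not has_symbolic_input(input_shapes):
        return []  # fully static program: the passes are skipped
    return [p(input_shapes) for p in passes]
```

For example, `maybe_run_dynamism_passes([("s0", 3)], passes)` would run the passes, while `maybe_run_dynamism_passes([(2, 3)], passes)` would skip them.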
The following ops are fused into XLA ops:

- `sym_size.int` + `aten.expand` => `xla.dynamic_expand`
- `sym_size.int` + (`mul`) + `aten.view` => `xla.dynamic_view`

Some torch ops generate `aten.view` with a symbolic dim size during decomposition in upstream PyTorch, or in the existing lowering logic in torch_xla. FX passes are introduced to handle these ops as well:

- `aten.native_layer_norm`
- `aten.group_norm`
- `aten.select`
- `aten.unsqueeze`

Other changes:

- Support dynamism for `rsub`, `mean` and `var` in torch_xla.
- Move `test/stablehlo/utils.py` to `torch_xla/utils/stablehlo_test_utils.py`, to fix a path-not-found issue when running tests with pytest.

Test:
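The `sym_size.int` + `aten.view` => `xla.dynamic_view` fusion described above can be sketched as a small graph rewrite. This is illustrative only: the real passes operate on `torch.fx` graphs inside torch_xla, whereas here a graph is just a list of `(name, op, args)` tuples and the node names are made up:

```python
# Sketch of the fusion idea: find sym_size.int nodes whose results feed an
# aten.view node, and replace the pair with one fused xla.dynamic_view node,
# so the symbolic-size computation and its consumer live in a single op.

def fuse_dynamic_view(graph):
    """Fuse each sym_size.int node into the aten.view node that consumes it."""
    sym_sizes = {name for name, op, _ in graph if op == "sym_size.int"}
    # sym_size.int results that are actually consumed by some aten.view.
    used_by_view = {
        a for _, op, args in graph if op == "aten.view"
        for a in args if a in sym_sizes
    }
    fused = []
    for name, op, args in graph:
        if op == "sym_size.int" and name in used_by_view:
            continue  # its computation is captured by the fused op below
        if op == "aten.view" and any(a in used_by_view for a in args):
            fused.append((name, "xla.dynamic_view", args))
        else:
            fused.append((name, op, args))
    return fused
```

For instance, a graph `[("s0", "sym_size.int", ("x", 0)), ("v", "aten.view", ("x", "s0", -1))]` collapses to a single `xla.dynamic_view` node, while a fully static view is left untouched.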