feat(torchlib): Implement aten::index (#862) #883
Conversation
Remove the duplicate check_model message when displaying
WIP: need more tests

Gather op:
- https://github.com/openxla/xla/blob/main/docs/operation_semantics.md?rgh-link-date=2023-07-13T01%3A09%3A16Z#gather
- https://www.pathpartnertech.com/gather-scatter-operation-in-deep-learning-framework/

Co-authored-by: BowenBao <bowbao@microsoft.com>
---

**This change implements the logic for `aten::index` and adds tests for different nd index combinations and permutations.**

## Understanding `aten::index`

For `arg0` with shape `[7, 3, 4, 5, 6]`, the indexing operation `arg0[0, :, 1:2, tensor([[4,5]])]` will be translated to

```
select: i64[3, 4, 5, 6] = torch.ops.aten.select.int(arg0, 0, 0);
slice_1: i64[3, 4, 5, 6] = torch.ops.aten.slice.Tensor(select, 0, 0, 9223372036854775807);
slice_2: i64[3, 1, 5, 6] = torch.ops.aten.slice.Tensor(slice_1, 1, 1, 2);
index: i64[3, 1, 1, 2, 6] = torch.ops.aten.index.Tensor(slice_2, [None, None, arg1]);
```

Here,

- `indices = [None, None, arg1]` is equivalent to `indices = [None, None, arg1, None]`
- The operation `arg0[0, :, 1:2, tensor([[4,5]])]` is equivalent to `arg0[0, :, 1:2, tensor([[4,5]]), :]`

`None` entries in `indices` act as fillers for dimensions that cannot be removed in the process.

## Gather op reference

- https://github.com/openxla/xla/blob/main/docs/operation_semantics.md?rgh-link-date=2023-07-13T01%3A09%3A16Z#gather
- https://www.pathpartnertech.com/gather-scatter-operation-in-deep-learning-framework/

---------

Co-authored-by: BowenBao <bowbao@microsoft.com>
BowenBao
left a comment
🎉 LGTM. Do you plan to add bool mask indices support in this or another PR?
    sequence_input.append(input)
    ort_inputs[input_name] = subarg
else:
    sequence_input.append(subarg)
Let's put a comment here explaining why this is needed, in case we forget?
Should we handle `None` in an explicit `elif` branch? That makes it easier to catch inputs we are not expecting.
Good point. I wonder if there are things we don't expect that can sneak in? Nested lists?
No idea. Just in case.
I plan to do it in another PR to keep this one manageable.
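For reference on the deferred bool-mask support: a boolean mask index selects the positions where the mask is `True`, which is equivalent to integer indexing with the mask's nonzero coordinates. A minimal NumPy sketch (illustrative only, not this PR's code):

```python
import numpy as np

a = np.arange(12).reshape(3, 4)
mask = a % 2 == 0  # boolean mask with the same shape as `a`

masked = a[mask]                 # boolean-mask indexing
via_nonzero = a[mask.nonzero()]  # equivalent integer indexing

# Both select the even entries in row-major order.
assert (masked == via_nonzero).all()
```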
Codecov Report
@@ Coverage Diff @@
## main #883 +/- ##
==========================================
- Coverage 76.59% 76.51% -0.09%
==========================================
Files 112 112
Lines 13408 13408
Branches 1348 1348
==========================================
- Hits 10270 10259 -11
- Misses 2801 2812 +11
Partials 337 337
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):

* #883
* #882
* __->__ #881

The TorchScript ONNX graph generator creates numeric value names by default (`0`, `1`). These are not legal ONNX tensor names, since ONNX requires names to be valid C variable names. This change updates the names by prepending a prefix, `_val_` or `_const_`, to make them valid ONNX names. It also improves readability by making the names less likely to be confused with shape values. I decided to use the `_` prefix to reduce the chance of name collision with FX names.

After:

```
<
   ir_version: 8,
   opset_import: ["" : 18],
   producer_name: "pytorch",
   producer_version: "2.1.0"
>
torch_jit (float[5,5,5,5] input_0, int64[2] input_1_3) => (float[5,5,5,2] _val_10) {
   _val_2 = Transpose <perm = [0, 1, 2, 3]> (input_0)
   _val_3 = Max (input_1_3)
   _val_4 = Shape <start = 0> (_val_3)
   _val_5 = Expand (input_1_3, _val_4)
   _const_6 = Constant <value = int64 {-1}> ()
   _val_7 = Unsqueeze (_val_5, _const_6)
   _val_8 = Concat <axis = -1> (_val_7)
   _val_9 = GatherND <batch_dims = 0> (_val_2, _val_8)
   _val_10 = Transpose <perm = [0, 1, 2, 3]> (_val_9)
}
```

Before:

```
<
   ir_version: 8,
   opset_import: ["" : 18],
   producer_name: "pytorch",
   producer_version: "2.1.0"
>
torch_jit (float[5,5,5,5] input_0, int64[2] input_1_3) => (float[5,5,5,2] 10) {
   2 = Transpose <perm = [0, 1, 2, 3]> (input_0)
   3 = Max (input_1_3)
   4 = Shape <start = 0> (3)
   5 = Expand (input_1_3, 4)
   6 = Constant <value = int64 {-1}> ()
   7 = Unsqueeze (5, 6)
   8 = Concat <axis = -1> (7)
   9 = GatherND <batch_dims = 0> (2, 8)
   10 = Transpose <perm = [0, 1, 2, 3]> (9)
}
```
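The renaming rule can be sketched as follows. `sanitize_name` is a hypothetical helper for illustration, not the PR's actual code; the validity check assumes the "valid C identifier" requirement described above.

```python
import re

def sanitize_name(name: str, prefix: str = "_val_") -> str:
    """Prefix graph-value names that are not valid C identifiers (e.g. "0", "1")."""
    if re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", name):
        return name  # already a legal ONNX tensor name
    return f"{prefix}{name}"

assert sanitize_name("2") == "_val_2"
assert sanitize_name("6", prefix="_const_") == "_const_6"
assert sanitize_name("input_0") == "input_0"  # valid names pass through unchanged
```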
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):

* #883
* __->__ #882
* #881

Remove the duplicate check_model message when displaying
index_1d = common_methods_invocations.index_variable(2, s, device=device)
index_2d = common_methods_invocations.index_variable((s + 1, 2), s, device=device)
index_3d = common_methods_invocations.index_variable((s + 2, s + 1, 2), s, device=device)
test_args = [
Would `itertools.{product, permutations, combinations}` help ensure all combinations are covered, and make the code shorter?
Good point. Let me try that
I tested with itertools, but some combinations turn out to be invalid, so we cannot enumerate them all that way. For the sake of clarity I propose we keep the current explicit tests. I also added more test cases and comments.
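For context on why plain enumeration falls short, here is a hedged sketch of the itertools approach, with NumPy standing in for torch. Candidate index tuples are generated with `itertools.product`, and the invalid ones (here, an out-of-range index array) have to be filtered out, which is part of why explicit test cases stay clearer.

```python
import itertools
import numpy as np

a = np.zeros((5, 5, 5, 5))
# Candidate index components; the last one is deliberately out of range.
components = [slice(None), 0, np.array([0, 1]), np.array([5])]

valid, invalid = [], []
for length in range(1, 4):
    for combo in itertools.product(components, repeat=length):
        try:
            a[combo]  # raises IndexError for out-of-range combinations
            valid.append(combo)
        except IndexError:
            invalid.append(combo)

# 4 + 16 + 64 candidates in total; every combo containing the
# out-of-range array is rejected.
assert len(valid) + len(invalid) == 84
assert len(invalid) > 0
```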
def convert_kwargs_for_onnx(kwargs: dict[str, Any]) -> dict[str, Any]:
    """Converts kwargs to be compatible with ONNX Runtime.

    ONNX Runtime doesn't support torch.bool, so we convert them to torch.uint8.
For my knowledge, does ORT support torch.bool now? Or was this docstring already outdated?
I actually don't know if ORT supports bool (it should?), but I think this message was a mistake by Copilot, because we don't actually have this conversion logic in the code. If we see issues with ORT I will make adjustments.
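If such a conversion were ever needed, it might look like the sketch below. This is hypothetical (per the discussion above, the PR contains no such code) and uses NumPy in place of torch; `convert_kwargs_for_ort` is an invented name.

```python
import numpy as np

def convert_kwargs_for_ort(kwargs: dict) -> dict:
    """Hypothetical sketch: cast boolean arrays to uint8, pass everything else through."""
    return {
        name: value.astype(np.uint8)
        if isinstance(value, np.ndarray) and value.dtype == np.bool_
        else value
        for name, value in kwargs.items()
    }

converted = convert_kwargs_for_ort({"mask": np.array([True, False]), "n": 3})
assert converted["mask"].dtype == np.uint8
assert converted["n"] == 3
```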
index_1d = common_methods_invocations.index_variable(2, s, device=device)
index_2d = common_methods_invocations.index_variable((s + 1, 2), s, device=device)
index_3d = common_methods_invocations.index_variable((s + 2, s + 1, 2), s, device=device)
test_args = [
Could `itertools.product` (or friends) with lengths 1 to 4 shorten this list and ensure no combination is left out?
As above. It turns out some combinations are invalid in torch, so it may be better to specify them explicitly.