feat: support aten.index_select converter#2710
Conversation
| index: TRTTensor, | ||
| ) -> TRTTensor: | ||
| # The axis parameter specifies the dimension along which to index. | ||
| gather_layer = ctx.net.add_gather(input, index, axis=dim) |
There was a problem hiding this comment.
dim likely needs to be corrected using get_positive_dim to ensure the value is positive for add_gather
There was a problem hiding this comment.
I have modified it. Thanks!
| ("2d_input_dim_0", (10, 3), 0, (0, 2)), | ||
| ("2d_input_dim_1", (5, 10), 1, (1, 2, 3)), | ||
| ("3d_input_dim_0", (10, 5, 10), 0, (0, 5)), | ||
| ("3d_input_dim_2", (10, 5, 10), 2, (3, 3, 4)), |
There was a problem hiding this comment.
Add a test case for a negative dim input
There was a problem hiding this comment.
I have added a test case for a negative dim input and verified a test case. Thank you!
| kwargs: Dict[str, Argument], | ||
| name: str, | ||
| ) -> Union[TRTTensor, Sequence[TRTTensor]]: | ||
| return impl.index.index_select( |
There was a problem hiding this comment.
It seems that the index_select function could be put into select.py
There was a problem hiding this comment.
I moved index_select inside select.py. Thank you!
| elementwise, | ||
| embedding, | ||
| grid, | ||
| index, |
There was a problem hiding this comment.
This can likely be removed - it seems to be causing a circular import error in CI
There was a problem hiding this comment.
Thanks! It seems I overlooked removing an unnecessary import.
Description
New feature to support aten.index_select converter. I also add test case for different dimensions.
Fixes # (#2708)
Type of change
Checklist: