feat: support flatten and reshape via shuffle_layer #2354
gs-olive merged 11 commits into pytorch:main
Conversation
When I test …
Force-pushed from b414928 to d212e51
gs-olive
left a comment
Is there a specific reason, performance or otherwise, why flatten should need a different implementation than reshape when using static shapes? Specifically, we can comment out the flatten implementation for now, and any converters needing flatten for static shapes can just use a reshape and flatten the dimensions themselves.
As an alternative, @zewenli98, you can add a utility flatten_dims, which will flatten the dimensions of an input tensor into a reshape-usable form, then you can have @bowang007's converter test that utility.
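The claim above, that a static-shape flatten reduces to a plain reshape with a precomputed shape, can be checked directly in PyTorch:

```python
import torch

x = torch.arange(24).reshape(2, 3, 4)

# flatten(start_dim=1, end_dim=2) collapses dims 1..2 into one (3 * 4 = 12)
flat = torch.flatten(x, start_dim=1, end_dim=2)

# The same result via a plain reshape with the collapsed shape
reshaped = x.reshape(2, 12)

assert flat.shape == (2, 12)
assert torch.equal(flat, reshaped)
```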
Thanks for the advice! I did this because I noticed there's a …
Generally, the focus is to cover as much of this operation set as possible: https://pytorch.org/docs/stable/ir.html#core-aten-ir, though if there are operators that show up which we can directly convert, as opposed to lowering, that is certainly a good thing to have.
gs-olive
left a comment
I do still think flatten_dims can be a utility which gives the shape to pass to reshape. That way, it can get tested as a utility and not as a converter (see tests/py/dynamo/conversion/test_converter_utils.py). Added a suggestion on syntax.
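A standalone sketch of how such a utility could be unit-tested, in the spirit of tests/py/dynamo/conversion/test_converter_utils.py (the helper below is hypothetical and operates on a plain shape tuple, not a TRT tensor):

```python
import unittest


def flatten_dims(shape, start_dim, end_dim):
    # Hypothetical pure-shape helper: collapse dims [start_dim, end_dim]
    # into a single dimension, supporting negative indices.
    n = len(shape)
    start_dim, end_dim = start_dim % n, end_dim % n
    num_elements = 1
    for i in range(start_dim, end_dim + 1):
        num_elements *= shape[i]
    return tuple(shape[:start_dim]) + (num_elements,) + tuple(shape[end_dim + 1:])


class TestFlattenDims(unittest.TestCase):
    def test_flatten_middle_dims(self):
        self.assertEqual(flatten_dims((2, 3, 4, 5), 1, 2), (2, 12, 5))

    def test_flatten_all_dims(self):
        self.assertEqual(flatten_dims((2, 3, 4), 0, -1), (24,))

    def test_negative_dims(self):
        self.assertEqual(flatten_dims((2, 3, 4, 5), -3, -2), (2, 12, 5))


if __name__ == "__main__":
    unittest.main()
```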
I tried implementing …
@zewenli98 I see, thanks for the details. To clarify, I was intending for:

```python
def flatten_dims(
    input: Sequence[Union[TRTTensor, torch.Tensor, np.ndarray]],
    start_dim: int,
    end_dim: int,
) -> Tuple[int, ...]:
    shape = input.shape
    dim_size = len(shape)
    start_dim = get_positive_dim(start_dim, dim_size)
    end_dim = get_positive_dim(end_dim, dim_size)

    # Collapse dimensions [start_dim, end_dim] into a single dimension
    num_elements = 1
    for i in range(start_dim, end_dim + 1):
        num_elements *= shape[i]

    new_shape = tuple(shape[:start_dim]) + (num_elements,) + tuple(shape[end_dim + 1:])
    return new_shape
```

Then, the user can call …
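As a minimal standalone sketch of that call pattern (pure shapes, no TensorRT; `get_positive_dim` is assumed here to simply normalize negative indices), the utility's output can be fed straight into a reshape:

```python
import numpy as np


def get_positive_dim(dim, dim_size):
    # Assumed behavior: map negative indices into [0, dim_size)
    return dim % dim_size


def flatten_dims(input, start_dim, end_dim):
    # Pure-shape version of the suggested utility
    shape = input.shape
    n = len(shape)
    start_dim = get_positive_dim(start_dim, n)
    end_dim = get_positive_dim(end_dim, n)
    num_elements = 1
    for i in range(start_dim, end_dim + 1):
        num_elements *= shape[i]
    return tuple(shape[:start_dim]) + (num_elements,) + tuple(shape[end_dim + 1:])


x = np.zeros((2, 3, 4, 5))
new_shape = flatten_dims(x, 1, 2)
assert new_shape == (2, 12, 5)
assert x.reshape(new_shape).shape == (2, 12, 5)
```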
@gs-olive Thanks a lot! This makes more sense. Modified!
Force-pushed from 5847be9 to d48d611
Force-pushed from 4fc3a8b to 9846f23
```python
self.run_test_with_dynamic_shape(
    TestModule(target_shape),
    input_specs,
    expected_ops={torch.ops.aten.view.default},
```
This line can be removed, in accordance with the new testing PR
gs-olive
left a comment
Looks good to me, pending CI pass.
Relevant tests pass locally on Torch 2.1.0. Merging to …
@bowang007 Please consult this PR for the …

@zewenli98 Thanks! Let me update the PR accordingly.
Description
Support `flatten` and `reshape` via `shuffle_layer`.

Fixes #2214
Type of change
Please delete options that are not relevant and/or add your own.
Checklist: