addresses the case when shape of upsample tensor contains ITensor #3841
Conversation
peri044 left a comment
LGTM. I think the functionality of to_trt_shape_tensor is probably reimplemented manually in a couple of places (e.g. concat, IIRC). Do you think we could unify all this?
Yeah, that's true. cat does it for its inputs, while the above is for shape tensors. I guess we could unify this. Should
Force-pushed 9403b0f to 3d3a8ee
```python
# promote remaining ints to TRT consts before concat
for i, t in enumerate(trt_tensors):
    if isinstance(t, int):
        const = ctx.net.add_constant((1,), np.array([t], dtype=np.int32))
        set_layer_name(const, target, f"{name}_static_{i}_const")
        trt_tensors[i] = const.get_output(0)

concat = ctx.net.add_concatenation(trt_tensors)
```
If trt_tensors has a mix of scalar integers and ITensors of dtype int64, would this work (since you're casting the scalar integers to int32 explicitly)?
In the case of shape tensors, the ints will always be int32, so this should work there.
As for the cat case: the concat tensors will be either torch.Tensor or TRTTensor; they cannot be int. So I think the above should cover all the cases. Can you think of any other case?
So my thought is: how are we explicitly ensuring all trt_tensors have the same datatype before concatenating here? Otherwise this will error out.
This could either be an assertion check or explicit type promotion of the tensors within trt_tensors.
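The promotion the reviewers are asking for can be sketched in plain NumPy (this is illustrative only, not the real TensorRT network API): wrap bare Python ints, pick a single common integer dtype across all inputs, and cast everything to it before concatenating, so the concat layer never sees mixed int32/int64 inputs.

```python
import numpy as np

def unify_and_concat(values):
    """Sketch (NumPy stand-in for the TRT converter logic) of explicit
    type promotion before concatenation, as suggested in the review."""
    # Wrap scalars as 1-element arrays (mirrors add_constant on shape (1,)).
    arrays = [np.atleast_1d(np.asarray(v)) for v in values]
    # Pick the widest integer dtype present (int64 wins over int32).
    common = np.result_type(*[a.dtype for a in arrays])
    # Explicit promotion step -- the "assertion check or explicit type
    # promotion" the reviewer asks for.
    arrays = [a.astype(common) for a in arrays]
    return np.concatenate(arrays)

out = unify_and_concat([2, np.array([3], dtype=np.int64), np.array([4], dtype=np.int32)])
print(out, out.dtype)
```

With this in place, a scalar int cast to int32 can safely coexist with an int64 ITensor, because everything is promoted to the common dtype before the concat.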
Embedding bag looks like it is failing. Need to look into it.
Force-pushed 1771071 to 91a2519
```python
elif isinstance(cast_dtype, np.dtype):
    final_dtype = _enums.dtype._from(cast_dtype).to(trt.DataType)
else:
    final_dtype = cast_dtype  # already trt.DataType
```
Should we also check the torch.dtype case?
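The concern is that an `isinstance(..., np.dtype)` check misses torch.dtype values, which would silently fall through to the "already trt.DataType" else-branch. A minimal NumPy-only sketch of the same branching shape (the real code resolves to trt.DataType via `_enums.dtype._from`; names here are hypothetical):

```python
import numpy as np

def resolve_dtype(cast_dtype):
    """Hypothetical sketch of per-source-type dtype resolution.
    A torch.dtype is NOT an np.dtype instance, so without an explicit
    torch branch it would wrongly hit the final fallback."""
    if isinstance(cast_dtype, np.dtype):
        return cast_dtype  # already a resolved NumPy dtype
    # A torch.dtype branch (as the reviewer asks) would go here.
    if isinstance(cast_dtype, type) and issubclass(cast_dtype, np.generic):
        return np.dtype(cast_dtype)  # NumPy scalar types like np.int32
    return np.dtype(cast_dtype)  # strings, Python builtin types, etc.

print(resolve_dtype(np.int32), resolve_dtype("float32"))
```

The point of the sketch: every distinct source of dtype objects needs its own explicit branch, or it ends up in a fallback that assumes the wrong thing.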
```python
# optional cast
if cast_dtype and isinstance(t, TRTTensor):
    t = cast_trt_tensor(ctx, t, cast_dtype, f"{name}_cast_{i}")
```
Is this necessary if we are also casting at line 69 onwards?
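The redundancy question comes down to whether the cast helper is guarded: if a cast is a no-op when the dtype already matches, a second cast downstream is harmless but the first one is dead weight. A NumPy stand-in (illustrative only; `cast_if_needed` is a hypothetical name, not the Torch-TensorRT helper):

```python
import numpy as np

def cast_if_needed(arr, dtype):
    """Guarded cast: returns the input unchanged when the dtype already
    matches, so repeated casts along a path do no extra work."""
    dtype = np.dtype(dtype)
    return arr if arr.dtype == dtype else arr.astype(dtype)

a = np.array([1, 2], dtype=np.int32)
b = cast_if_needed(a, np.int64)   # first cast: actually converts
c = cast_if_needed(b, np.int64)   # second cast: no-op, returns b itself
print(b.dtype, c is b)
```

If the real `cast_trt_tensor` behaves like this, keeping only the later cast would simplify the converter without changing results.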
Force-pushed c8a070c to 1b1dfed
Force-pushed 1b1dfed to dd05c7e
Force-pushed dd05c7e to 736aeef
Addresses #3783