Conversation
There was a problem hiding this comment.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py 2023-10-13 16:23:31.589620+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py 2023-10-13 16:26:06.957081+00:00
@@ -178,12 +178,11 @@
rhs_val = cast_trt_tensor(ctx, rhs_val, trt.float32, name)
return [lhs_val, rhs_val]
def broadcastable(
- a: Union[TRTTensor, np.ndarray],
- b: Union[TRTTensor, np.ndarray]
+ a: Union[TRTTensor, np.ndarray], b: Union[TRTTensor, np.ndarray]
) -> bool:
"Check if two tensors are broadcastable according to torch rules"
a_shape = tuple(a.shape)
b_shape = tuple(b.shape)
    rank = len(input_shape)
    adv_indx_count = len(adv_indx_indices)
    dim_tensor_list = []
    dim_list = []
There was a problem hiding this comment.
If this is unused, dim_list = [] can be removed
gs-olive
left a comment
There was a problem hiding this comment.
Pending the above comment, it looks good to me!
gs-olive
left a comment
There was a problem hiding this comment.
Works well on the SD usecase, and tests pass locally - could a test case be added specifically for this addition, which tests the functionality when all inputs are constants?
apbose
left a comment
There was a problem hiding this comment.
Hi @gs-olive. Regarding adding the test for this feature: the earlier tests all used numpy indices, where is_numpy is set to True, so those tests should already cover that case. Is there a way to add tests in which the indices are ITensors?
Yes, there is. You would just need to pass the indices as inputs to the
There was a problem hiding this comment.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_aten.py 2023-10-30 19:32:26.833431+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_aten.py 2023-10-30 19:35:28.665002+00:00
@@ -24,11 +24,11 @@
input = [torch.randn(2, 2)]
self.run_test(
TestModule(),
input,
)
-
+
def test_index_zero_two_dim_ITensor(self):
class TestModule(nn.Module):
def forward(self, x, index0):
indices = [None, index0]
out = torch.ops.aten.index.Tensor(x, indices)
@@ -56,25 +56,22 @@
input = [torch.randn(2, 2, 2)]
self.run_test(
TestModule(),
input,
)
-
+
def test_index_zero_index_three_dim_ITensor(self):
class TestModule(nn.Module):
def forward(self, x, index0):
indices = [None, index0, None]
out = torch.ops.aten.index.Tensor(x, indices)
return out
input = torch.randn(2, 2, 2)
index0 = torch.randint(0, 1, (1, 1))
index0 = index0.to(torch.int32)
- self.run_test(
- TestModule(),
- [input, index0]
- )
+ self.run_test(TestModule(), [input, index0])
def test_index_zero_index_one_index_two_three_dim(self):
class TestModule(nn.Module):
def __init__(self):
self.index0 = torch.randint(0, 1, (1, 1))

import torch
import torch.nn as nn
from harness import DispatchTestCase
There was a problem hiding this comment.
Requires switch back to .harness to pass CI
There was a problem hiding this comment.
OK, corrected — will wait for CI, then merge.
This PR addresses the cases where the indices passed to aten::index are a list of numpy arrays, or numpy values. #2394