Conversation
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_rand_aten.py 2024-01-16 10:05:15.387018+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_rand_aten.py 2024-01-16 10:07:09.239459+00:00
@@ -73,17 +73,18 @@
             def forward(self):
                 return self.rand_op(self.size)
         grid_model = TestModule(op, shape_or_input)
-        #cannot use self.run_test() since it expects input in form of tensor
-
-        #self.run_test(grid_model, None)
+        # cannot use self.run_test() since it expects input in form of tensor
+
+        # self.run_test(grid_model, None)
         fx_graph = torch.fx.symbolic_trace(grid_model)
         torch._dynamo.reset()
-        optimized_model = torch_tensorrt.compile(fx_graph,
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
             "torch_compile",
             None,
             min_block_size=1,
             pass_through_build_failures=True,
             truncate_long_and_double=True,

Force-pushed 56cbaaf to c31e6a8
    return False

@dynamo_tensorrt_converter(torch.ops.aten.rand.default)
Ensure the validator is referenced in this decorator (capability_validator=rand_validator)
Same for the decorators below
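For illustration, a minimal sketch of the suggested decorator wiring, assuming a rand_validator helper; the validator body, the converter signature, and the import path are assumptions, not code from this PR:

```python
# Sketch only: rand_validator is illustrative, and the import path for
# dynamo_tensorrt_converter is an assumption about the repo layout.
import numpy as np
import torch
from torch.fx.node import Node
from torch_tensorrt.dynamo.conversion import dynamo_tensorrt_converter

def rand_validator(rand_node: Node) -> bool:
    # Claim the node only when the shape argument is fully static;
    # returning False lets the op fall back to PyTorch execution.
    return all(isinstance(dim, int) for dim in rand_node.args[0])

@dynamo_tensorrt_converter(
    torch.ops.aten.rand.default, capability_validator=rand_validator
)
def aten_ops_rand(ctx, target, args, kwargs, name):
    # Shape arrives as args[0] (a tuple), so unpack it for NumPy.
    return np.random.rand(*args[0])
```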
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    device = kwargs.get("device", None)
    return np.random.randn(*args).to(device=device)
    if not isinstance(input, int):
        raise RuntimeError(f"The input must be an integer")
This should be in the validator, since converters should not throw errors
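A hedged sketch of moving that check into a validator instead, so unsupported nodes fall back rather than raise (randperm_validator is an illustrative name):

```python
from torch.fx.node import Node

def randperm_validator(randperm_node: Node) -> bool:
    # Reject non-integer inputs here: returning False leaves the node
    # to PyTorch execution instead of raising inside the converter.
    return isinstance(randperm_node.args[0], int)
```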
    input = args[0]
    if not isinstance(input, int):
        raise RuntimeError(f"The input must be an integer")
    return np.random.randperm(*args).to(device=device)
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    device = kwargs.get("device", None)
    return np.random.rand(*args).to(device=device)
As above, I think you would need to unpack the integers with something like np.random.rand(*args[0]). Additionally, .to would have no effect in NumPy.
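For concreteness, the unpacking suggested above behaves like this (a quick REPL check, not code from the PR):

```python
>>> import numpy as np
>>> args = ((1, 2, 3),)             # converter args: a tuple holding the shape tuple
>>> np.random.rand(*args[0]).shape  # unpack the inner tuple into separate ints
(1, 2, 3)
```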
Thanks for the review. I think np.random.rand(*args) should also work, since it passes locally. Let me cross-check.
Force-pushed 9c1569e to 13a6f94
    if not isinstance(input, int):
        raise RuntimeError(f"The input must be an integer")
Force-pushed f3436bd to 7b82831
gs-olive left a comment:
Added a small comment, otherwise looks good
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    device = kwargs.get("device", None)
Can be removed since it is unused
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    device = kwargs.get("device", None)

    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    device = kwargs.get("device", None)
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return np.random.randn(*args)
This fails on my machine, since np.random.randn does not accept tuples, and args is a tuple containing a tuple:

>>> args = ((1, 2, 3,),)
>>> np.random.randn(*args)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "mtrand.pyx", line 1270, in numpy.random.mtrand.RandomState.randn
  File "mtrand.pyx", line 1431, in numpy.random.mtrand.RandomState.standard_normal
  File "_common.pyx", line 636, in numpy.random._common.cont
TypeError: 'tuple' object cannot be interpreted as an integer
@gs-olive Thanks for pointing this out.
I think this got missed in the tests, since in the tests the input is None, and the compilation is not invoked.
I changed the above to np.random.rand(*args[0]). Also, in the test I modified it to take an input (a hacky way; I could as well use harness.py with the below model):

def test_rand(self, name, op, shape_or_input):
    class TestModule(nn.Module):
        def __init__(self, rand_op, size):
            super().__init__()
            self.rand_op = rand_op
            self.size = size

        def forward(self, x):
            b = x + 1
            self.size[0] = b
            return self.rand_op(self.size)

But right now I am running into a segmentation fault. I am looking into this further.
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return np.random.rand(*args)
Same error as above on my machine
Force-pushed 961fd81 to e099dbb
def run_test_comparator(
    self,
    mod,
    inputs,
    expected_ops,
    comparators: List[Tuple[Callable, List]],
    precision=torch.float32,
    output_dtypes=None,
    use_dynamo_tracer=False,
    enable_passes=False,
):
Does this only test the shapes of the results? If so, can the name be updated to run_test_compare_shapes_only or something similar?

In this case it is checking the shape of the tensors, but it is supposed to be a generic function where the callable can check other attributes of the tensor. For example, in the case of rand and randn we compare the data type as well. So I think the name should be run_test_comparator, but let me know if you think otherwise.
Ok, that makes sense; I think the name is reasonable then. If possible, maybe run_test_compare_tensor_attributes_only, just to delineate that the actual values in the tensors are ignored, but it is also okay as is.

Ok, makes sense. I will rename it to the above then.
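As a hedged illustration of the comparator idea discussed here (the call convention and all names besides run_test_comparator are assumptions for the example, not code from the PR):

```python
# Each entry pairs a callable with its extra arguments; the harness is
# assumed to invoke comparator(output, *extra_args) and assert the result.
comparators = [
    (lambda out, shape: list(out.shape) == shape, [[2, 3, 4]]),  # shape check
    (lambda out, dtype: out.dtype == dtype, [torch.float32]),    # dtype check
]
self.run_test_comparator(
    grid_model, None, {torch.ops.aten.rand.default}, comparators
)
```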
Force-pushed b2bd47a to c3c768c
Force-pushed c3c768c to 2ee77fa
Force-pushed 7d77bd2 to 3710b86
@apbose Can you open a cherry-pick PR to release/2.3?

Sure, I will do that today @peri044
Covers the converters: aten.rand, aten.randn, aten.randperm.
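Pulling the thread's feedback together, the converter bodies would plausibly end up along these lines (a sketch under the assumptions discussed above, with illustrative names; note that NumPy has no randperm, its closest equivalent being np.random.permutation):

```python
import numpy as np

def rand_impl(args):
    # aten.rand and aten.randn receive the shape as args[0], so unpack it.
    return np.random.rand(*args[0])

def randn_impl(args):
    return np.random.randn(*args[0])

def randperm_impl(args):
    # aten.randperm takes a single int n; np.random.permutation(n)
    # returns a shuffled arange(n), matching randperm's semantics.
    return np.random.permutation(args[0])
```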