🐛 [Bug] IndexError encountered when using bmm in FX aten path #1789
Closed
Labels: bug
Bug Description
When compiling the small model below through the FX aten path (`is_aten=True`), an `IndexError` is raised in the `compose_bmm` lowering pass.
```python
def forward(self, x, y):
    out = torch.bmm(x, y)
    return out
```

ERROR:
```
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 316, in __call__
    return do_lower(module, inputs)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py", line 118, in pass_with_validation
    processed_module = pass_(module, input, *args, **kwargs)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 313, in do_lower
    lower_result = pm(module)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/passes/pass_manager.py", line 246, in __call__
    out = _pass(out)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py", line 68, in wrapped_fn
    return fn(gm, input)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 262, in <lambda>
    trace_func=lambda module, inputs: aten_tracer.opt_trace(
  File "~/TensorRT/py/torch_tensorrt/fx/utils.py", line 136, in function_wrapper
    return f(*args, **kwargs)
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py", line 158, in opt_trace
    pr: PassResult = passes(fx_module)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py", line 420, in compose_bmm
    input_input_n = input_n.all_input_nodes[0]
IndexError: list index out of range
```

To Reproduce
Steps to reproduce the behavior:
- Run the code sample below
```python
import torch
import torch_tensorrt


class Sample(torch.nn.Module):
    def __init__(self):
        super(Sample, self).__init__()

    def forward(self, x, y):
        # A single batched matrix multiply on the two graph inputs
        out = torch.bmm(x, y)
        return out


def main():
    model = Sample().cuda().eval()
    input_data = torch.zeros((5, 5, 5), dtype=torch.float, device="cuda:0")
    input_data_2 = torch.ones((5, 5, 5), dtype=torch.float, device="cuda:0")
    out_torch = model(input_data, input_data_2)

    # Compiling via the aten path (is_aten=True) raises the IndexError
    # in the compose_bmm lowering pass shown in the traceback above
    mod = torch_tensorrt.fx.compile(model, [input_data, input_data_2],
                                    lower_precision=torch_tensorrt.fx.utils.LowerPrecision.FP32,
                                    min_acc_module_size=1, is_aten=True)
    out_trt = mod(input_data, input_data_2)
    print(out_trt)


if __name__ == "__main__":
    main()
```

Expected behavior
The model should compile without errors.
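For reference, the result a successfully compiled module should reproduce can be computed in eager mode. With the inputs from the reproduction script (a batch of zero matrices times a batch of ones), every element of the batched product is zero. This sketch runs on CPU, which is an assumption made only so the check does not depend on CUDA or Torch-TensorRT.

```python
import torch

# Same shapes and values as the reproduction script, but on CPU and in eager mode.
x = torch.zeros((5, 5, 5), dtype=torch.float)
y = torch.ones((5, 5, 5), dtype=torch.float)
expected = torch.bmm(x, y)

print(expected.shape)               # torch.Size([5, 5, 5])
print(bool((expected == 0).all()))  # True
```

Once compilation succeeds, `torch.testing.assert_close(out_trt, out_torch)` is one way to compare the TensorRT output against the eager output inside `main()`.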
Environment
- Torch-TensorRT Version (e.g. 1.0.0): ad5e764
- PyTorch Version (e.g. 1.0): 2.1.0.dev20230314+cu117
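A possible explanation for the failure, offered as a hypothesis rather than a confirmed diagnosis: in this model both `torch.bmm` operands are graph inputs, which FX represents as `placeholder` nodes, and a placeholder has no input nodes of its own. If `input_n` in `compose_bmm` is one of those operands, `input_n.all_input_nodes[0]` indexes an empty list and raises exactly this `IndexError`. The sketch below checks this with `torch.fx.symbolic_trace` on CPU. It does not use the aten tracer from the traceback, so node targets differ (`torch.bmm` here rather than an aten op), but placeholders behave the same way. The helper `bmm_operand_inputs` is hypothetical and not part of Torch-TensorRT.

```python
from typing import List, Tuple

import torch
import torch.fx


class Sample(torch.nn.Module):
    def forward(self, x, y):
        return torch.bmm(x, y)


def bmm_operand_inputs(module: torch.nn.Module) -> List[Tuple[str, str, int]]:
    """Hypothetical helper: for every torch.bmm call in the traced graph,
    report each operand's name, node kind, and number of input nodes."""
    gm = torch.fx.symbolic_trace(module)
    report = []
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target is torch.bmm:
            for operand in node.all_input_nodes:
                # Graph inputs are placeholders, whose all_input_nodes is empty,
                # so indexing [0] on them would raise IndexError.
                report.append((operand.name, operand.op, len(operand.all_input_nodes)))
    return report


print(bmm_operand_inputs(Sample()))
# [('x', 'placeholder', 0), ('y', 'placeholder', 0)]
```

If this reading is right, a pass that inspects the producers of the bmm operands would presumably need to handle operands with an empty `all_input_nodes`, for example by leaving such a bmm node unchanged.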