🐛 [Bug] global partitioner does not work while compiling with dynamo #3157
Status: Open
Labels: bug, story: Dynamo Frontend & Partitioning
Description
Hi all! Using the global partitioner (`use_fast_partitioner=False`) fails with the Dynamo backend, and I couldn't pinpoint why (I tried various compilation parameters).
How to Reproduce:
System:
CUDA Driver Version: 535.104.12
GPU: Nvidia Tesla T4
Python: 3.11.10
Dependencies (wheels):
https://download.pytorch.org/whl/cu121/torch-2.4.1%2Bcu121-cp311-cp311-linux_x86_64.whl
https://download.pytorch.org/whl/cu121/torch_tensorrt-2.4.0%2Bcu121-cp311-cp311-linux_x86_64.whl
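The version details above can also be collected programmatically when filing or triaging a report like this. The sketch below is a hypothetical helper (`collect_environment` is not part of PyTorch or Torch-TensorRT). It uses the standard library plus an optional `torch` import. Note that it reports the CUDA version PyTorch was built against, not the driver version, which still has to come from `nvidia-smi`.

```python
import importlib.metadata
import platform


def collect_environment() -> dict:
    """Hypothetical helper: gather version details for a bug report.

    Packages that are not installed are reported as "not installed"
    instead of raising, so it also runs on a partial setup.
    """
    info = {"python": platform.python_version()}
    for dist in ("torch", "torch_tensorrt", "tensorrt"):
        try:
            info[dist] = importlib.metadata.version(dist)
        except importlib.metadata.PackageNotFoundError:
            info[dist] = "not installed"
    try:
        import torch  # optional: only needed for CUDA/GPU details

        # CUDA version torch was *built* with (not the driver version).
        info["cuda_runtime"] = str(torch.version.cuda)
        info["gpu"] = (
            torch.cuda.get_device_name(0)
            if torch.cuda.is_available()
            else "no CUDA device visible"
        )
    except (ImportError, OSError):
        info["cuda_runtime"] = info["gpu"] = "unknown (torch not importable)"
    return info


if __name__ == "__main__":
    for key, value in collect_environment().items():
        print(f"{key}: {value}")
```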
Script to reproduce:
```python
import torch
import torch_tensorrt
import torch.nn as nn
import torch.nn.functional as F


class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        # A 538x538 input is max-pooled twice: 538 -> 269 -> 134.
        self.fc1 = nn.Linear(32 * 134 * 134, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        x = x.view(x.size(0), -1)  # flatten to (batch, 32 * 134 * 134)
        x = self.fc1(x)
        return x


model = SimpleCNN()


def compile_to_tensorrt() -> None:
    batch_size, tile_size = 1, 538
    model = SimpleCNN().to(dtype=torch.float16, device=torch.device("cuda"))
    model.eval()
    with torch.no_grad():
        inputs = torch.randn(
            batch_size, 3, tile_size, tile_size, device="cuda", dtype=torch.float16
        )
        print("Compiling model...")
        _trt_graph_module = torch_tensorrt.compile(
            model,
            ir="dynamo",
            inputs=[inputs],
            enabled_precisions={torch.float16},
            use_fast_partitioner=False,  # False selects the global partitioner
        )


if __name__ == "__main__":
    compile_to_tensorrt()
```

Error and traceback:
```
Traceback (most recent call last):
    compile_to_tensorrt()
  File "reproduce.py", line 37, in compile_to_tensorrt
    _trt_graph_module = torch_tensorrt.compile(
                        ^^^^^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch_tensorrt/_compile.py", line 249, in compile
    trt_graph_module = dynamo_compile(
                       ^^^^^^^^^^^^^^^
  File "site-packages/torch_tensorrt/dynamo/_compiler.py", line 230, in compile
    trt_gm = compile_module(gm, inputs, settings)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch_tensorrt/dynamo/_compiler.py", line 365, in compile_module
    for node in submodule.graph.nodes
                ^^^^^^^^^^^^^^^
  File "site-packages/torch/nn/modules/module.py", line 1729, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'Module' object has no attribute 'graph'
```
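The model in the script is shape-consistent, so the failure is not caused by a mismatched `fc1` size. This can be checked with the standard Conv2d/MaxPool2d output-size formula, out = floor((n + 2p - k) / s) + 1, assuming dilation 1 and the default floor mode. The helper names below are hypothetical and only mirror the layers of `SimpleCNN`.

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    # Output-size formula shared by Conv2d and MaxPool2d (dilation 1, floor mode).
    return (size + 2 * padding - kernel) // stride + 1


def simple_cnn_flat_features(tile_size: int) -> int:
    # Hypothetical helper tracing the spatial size through SimpleCNN.forward.
    s = conv2d_out(tile_size, kernel=3, stride=1, padding=1)  # conv1 keeps the size
    s = conv2d_out(s, kernel=2, stride=2)  # first max_pool2d halves it (floor)
    s = conv2d_out(s, kernel=3, stride=1, padding=1)  # conv2 keeps the size
    s = conv2d_out(s, kernel=2, stride=2)  # second max_pool2d halves it (floor)
    return 32 * s * s  # conv2 has 32 output channels


print(simple_cnn_flat_features(538))  # 574592, i.e. 32 * 134 * 134
```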