-
-
Notifications
You must be signed in to change notification settings - Fork 56.5k
ONNX: layerId != layer_id.end() in function 'handleNode' #19359
Copy link
Copy link
Closed
Description
System information (version)
- OpenCV => 4.5.1
- Operating System / Platform => Ubuntu 20.04
- Compiler => Python 3.8 and C++
- PyTorch => 1.6 and 1.7
Detailed description
I am trying to convert a PyTorch model to ONNX for usage with OpenCV. The ONNX file is working with onnxruntime but not with OpenCV. Tested on both C++ and Python version of OpenCV. The error is
[ERROR:0] global /tmp/pip-req-build-ms668fyv/opencv/modules/dnn/src/onnx/onnx_importer.cpp (1876) handleNode DNN/ONNX: ERROR during processing node with 2 inputs and 1 outputs: [Div]:(105)
Traceback (most recent call last):
File "load.py", line 6, in <module>
net = cv.dnn.readNet('UNet.onnx')
cv2.error: OpenCV(4.5.1) /tmp/pip-req-build-ms668fyv/opencv/modules/dnn/src/onnx/onnx_importer.cpp:1887: error: (-2:Unspecified error) in function 'handleNode'
> Node [Div]:(105) parse error: OpenCV(4.5.1) /tmp/pip-req-build-ms668fyv/opencv/modules/dnn/src/onnx/onnx_importer.cpp:1193: error: (-215:Assertion failed) layerId != layer_id.end() in function 'handleNode'
The error goes away when changing the num_layers to 1. Changing bilinear's values does not seem to have any effect. Code for generating the .onnx file is below.
Steps to reproduce
Code for the UNet model. Taken from an external implementation (the original hyperlink was lost in extraction).
import torch
import torch.nn as nn
import torch.nn.functional as F
class UNet(nn.Module):
    """U-Net semantic-segmentation network.

    Architecture based on "U-Net: Convolutional Networks for Biomedical
    Image Segmentation" — https://arxiv.org/abs/1505.04597
    """

    def __init__(
        self,
        num_classes: int = 19,
        num_layers: int = 2,
        features_start: int = 64,
        bilinear: bool = True,
    ):
        """
        Args:
            num_classes: Number of output classes required (default 19 for KITTI dataset)
            num_layers: Number of layers in each side of U-net
            features_start: Number of features in first layer
            bilinear: Whether to use bilinear interpolation or transposed convolutions for upsampling.
        """
        super().__init__()
        self.num_layers = num_layers

        blocks = [DoubleConv(3, features_start)]
        channels = features_start
        # Contracting path: double the channel count at each level.
        for _ in range(num_layers - 1):
            blocks.append(Down(channels, channels * 2))
            channels *= 2
        # Expanding path: halve the channel count at each level.
        for _ in range(num_layers - 1):
            blocks.append(Up(channels, channels // 2, bilinear))
            channels //= 2
        # Final 1x1 convolution maps the features to per-class logits.
        blocks.append(nn.Conv2d(channels, num_classes, kernel_size=1))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Encoder, then decoder with skip connections, then the output head."""
        feature_maps = [self.layers[0](x)]
        # Down path: each stage consumes the previous stage's output.
        for down_stage in self.layers[1:self.num_layers]:
            feature_maps.append(down_stage(feature_maps[-1]))
        # Up path: fuse with the matching encoder feature map (skip connection).
        for offset, up_stage in enumerate(self.layers[self.num_layers:-1]):
            feature_maps[-1] = up_stage(feature_maps[-1], feature_maps[-2 - offset])
        return self.layers[-1](feature_maps[-1])
class DoubleConv(nn.Module):
    """Two successive (3x3 conv -> BatchNorm -> ReLU) stages.

    (3x3 conv -> BN -> ReLU) ** 2
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        stages = []
        # First stage maps in_ch -> out_ch; second stage keeps out_ch.
        for src, dst in ((in_ch, out_ch), (out_ch, out_ch)):
            stages.extend([
                nn.Conv2d(src, dst, kernel_size=3, padding=1),
                nn.BatchNorm2d(dst),
                nn.ReLU(inplace=True),
            ])
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)
class Down(nn.Module):
    """Halve the spatial resolution with max-pooling, then apply DoubleConv."""

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        pool = nn.MaxPool2d(kernel_size=2, stride=2)
        conv = DoubleConv(in_ch, out_ch)
        self.net = nn.Sequential(pool, conv)

    def forward(self, x):
        return self.net(x)
class Up(nn.Module):
"""
Upsampling (by either bilinear interpolation or transpose convolutions)
followed by concatenation of feature map from contracting path,
followed by double 3x3 convolution.
"""
def __init__(self, in_ch: int, out_ch: int, bilinear: bool = False):
super().__init__()
self.upsample = None
if bilinear:
self.upsample = nn.Sequential(
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
nn.Conv2d(in_ch, in_ch // 2, kernel_size=1),
)
else:
self.upsample = nn.ConvTranspose2d(in_ch, in_ch // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_ch, out_ch)
def forward(self, x1, x2):
x1 = self.upsample(x1)
# Pad x1 to the size of x2
diff_h = x2.shape[2] - x1.shape[2]
diff_w = x2.shape[3] - x1.shape[3]
x1 = F.pad(x1, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2])
# Concatenate along the channels axis
x = torch.cat([x2, x1], dim=1)
return self.conv(x)Code used for exporting to ONNX
import torch
from unet import UNet
## Taken from PyTorch ONNX Samples
def main():
model = UNet()
x = torch.randn(1,3, 256, 256, requires_grad=True)
torch_out = model(x)
torch.onnx.export(model, # model being run
x, # model input (or a tuple for multiple inputs)
"UNet.onnx", # where to save the model (can be a file or file-like object)
export_params=True, # store the trained parameter weights inside the model file
opset_version=11)
if __name__ == "__main__":
main()Code used for importing in OpenCV
import numpy as np
import cv2 as cv
print(cv.__version__)
net = cv.dnn.readNet('UNet.onnx')Code used for onnxruntime
import onnxruntime
import numpy as np
ort_session = onnxruntime.InferenceSession("UNet.onnx")
input_name = ort_session.get_inputs()[0].name
ort_inputs = {input_name: np.random.randn(1,3,256, 256).astype(np.float32)}
ort_outs = ort_session.run(None, ort_inputs)
print(ort_outs[0].shape)Issue submission checklist
- I report the issue, it's not a question
- I checked the problem with documentation, FAQ, open issues, forum.opencv.org, Stack Overflow, etc. and have not found a solution
- I updated to the latest OpenCV version and the issue is still there
- There is reproducer code and related data files: videos, images, onnx, etc
Reactions are currently unavailable