- OpenCV => 4.2.0
- Operating System / Platform => Windows 10 64 Bit
- Compiler => Visual Studio 2017
Detailed description
I am using OpenCV to load an ONNX model exported from Python (PyTorch), but it throws the following exception:
OpenCV(4.2.0) C:\build\master_winpack-build-win64-vc15\opencv\modules\dnn\src\layers\eltwise_layer.cpp:209: error: (-215:Assertion failed) inputs[0][j] == inputs[i][j] in function 'cv::dnn::EltwiseLayerImpl::getMemoryShapes'
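Going only by the assertion text, `inputs[0][j] == inputs[i][j]`, the element-wise layer seems to require that its inputs have identical dimensions, whereas ONNX element-wise operators such as Add and Mul allow broadcasting. The minimal sketch below contrasts the two rules in plain Python. The shapes are hypothetical (a feature map multiplied by a per-channel factor); they are not read from the exported model, and the helper names are made up for illustration.

```python
def eltwise_shapes_match(shapes):
    """Strict rule: every input shape must equal the first one, dimension by dimension."""
    first = shapes[0]
    return all(
        len(s) == len(first) and all(a == b for a, b in zip(first, s))
        for s in shapes[1:]
    )


def broadcastable(a, b):
    """NumPy/ONNX-style rule: aligned from the right, dims must be equal or one of them is 1."""
    return all(x == y or x == 1 or y == 1 for x, y in zip(reversed(a), reversed(b)))


# Hypothetical NCHW shapes, chosen only for illustration.
feature_map = (1, 64, 64, 32)
channel_scale = (1, 64, 1, 1)

print(eltwise_shapes_match([feature_map, feature_map]))    # True
print(eltwise_shapes_match([feature_map, channel_scale]))  # False
print(broadcastable(feature_map, channel_scale))           # True
```

If the exported graph contains such a broadcasting Add or Mul, ONNX Runtime would accept it while a strict same-shape check would reject it, which would match the behaviour described here. This is a guess about the cause, not a confirmed diagnosis.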
Steps to reproduce
- Download the model from here; it is one of the models from pytorchreid
- Install PyTorch following their instructions; remember to change conda create --name torchreid python=3.7 to conda create --name torchreid python=3.6, since onnx does not support Python 3.7.
- Install onnx and update onnx
- Convert it to an ONNX model with the following code
import torchreid
torchreid.models.show_avai_models()
model = torchreid.models.build_model(name='osnet_ain_x1_0', num_classes=1041)
torchreid.utils.load_pretrained_weights(model, "osnet_ain_x1_0_msmt17_256x128_amsgrad_ep50_lr0.0015_coslr_b64_fb10_softmax_labsmth_flip_jitter.pth")
model.eval()
from torch.autograd import Variable
import torch
import onnx
# An example input you would normally provide to your model's forward() method.
input = torch.ones(1, 3, 256, 128)
raw_output = model(input)
torch.onnx.export(model, input, 'osnet_ain_x1_0.onnx', verbose=False, export_params=True)
print("-------------------------check model---------------------------------------\n")
try:
    onnx_model = onnx.load("osnet_ain_x1_0.onnx")
    onnx.checker.check_model(onnx_model)
    # Dump a human-readable description of the graph for inspection
    graph_output = onnx.helper.printable_graph(onnx_model.graph)
    with open("graph_output.txt", mode="w") as fout:
        fout.write(graph_output)
except Exception as e:
    print("Something went wrong:", e)
import onnxruntime
import numpy as np
ort_session = onnxruntime.InferenceSession("osnet_ain_x1_0.onnx")
def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(input)}
ort_outs = ort_session.run(None, ort_inputs)
# compare ONNX Runtime and PyTorch results
np.testing.assert_allclose(to_numpy(raw_output), ort_outs[0], rtol=1e-03, atol=1e-03)
print("Exported model has been tested with ONNXRuntime, and the result looks good!")
- Load the model with OpenCV 4.2.0
#include <opencv2/opencv.hpp>
#include <opencv2/dnn.hpp>
#include <iostream>
#include <iomanip>
using namespace cv;
using namespace std;
int main() try
{
    dnn::readNetFromONNX("osnet_ain_x1_0.onnx");
}
catch(exception& e)
{
    cout << e.what() << endl;
}
Here is a link to the ONNX model exported from PyTorch; with it you can skip the download and conversion steps above
The output of the graph
The difference between the prediction results after converting to ONNX
Mismatched elements: 7 / 512 (1.37%)
Max absolute difference: 0.00070047
Max relative difference: 0.02552379
x: array([[9.402755e-05, 2.179507e+00, 3.834043e-01, 2.136301e-01,
0.000000e+00, 0.000000e+00, 8.341767e-01, 2.012278e+00,
3.984281e-01, 5.461890e-02, 2.032685e-01, 7.745304e-06,...
y: array([[9.402186e-05, 2.179636e+00, 3.834159e-01, 2.135554e-01,
0.000000e+00, 0.000000e+00, 8.342320e-01, 2.012077e+00,
3.984202e-01, 5.461205e-02, 2.032657e-01, 7.745568e-06,...
Not as close as expected, but maybe not a big deal.
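For reading the report above: `np.testing.assert_allclose` treats an element as matching when `|actual - desired| <= atol + rtol * |desired|`. The sketch below re-implements that check in plain Python and applies it to the reported maximum absolute difference. One thing it suggests: with `atol=1e-03`, as written in the script, a difference of 0.00070047 is within tolerance for any `desired` value, so the 7 mismatches were presumably reported with a tighter `atol`. The value `1e-05` used below is a hypothetical stand-in, not something stated in this report.

```python
def within_tolerance(actual, desired, rtol, atol):
    """Element-wise criterion used by numpy.testing.assert_allclose."""
    return abs(actual - desired) <= atol + rtol * abs(desired)


max_abs_diff = 0.00070047  # "Max absolute difference" from the report

# Worst case for the bound is desired == 0, where the tolerance is just atol.
print(within_tolerance(max_abs_diff, 0.0, rtol=1e-3, atol=1e-3))  # True
print(within_tolerance(max_abs_diff, 0.0, rtol=1e-3, atol=1e-5))  # False

# First pair of values printed in the report (x[0][0] and y[0][0]).
x0, y0 = 9.402755e-05, 9.402186e-05
print(within_tolerance(x0, y0, rtol=1e-3, atol=1e-5))  # True
```

Either way, the absolute differences stay below 1e-03, so the small mismatch is probably unrelated to the OpenCV loading failure.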