Skip to content

Added DaSiamRPN tracker#16554

Merged
alalek merged 28 commits into opencv:3.4 from
ieliz:tracker
Mar 18, 2020
Merged

Added DaSiamRPN tracker#16554
alalek merged 28 commits into opencv:3.4 from
ieliz:tracker

Conversation

@ieliz
Copy link
Copy Markdown
Contributor

@ieliz ieliz commented Feb 11, 2020

Reopen: #16525

WIP
force_builders=Docs

Code to generate ONNX models
import io
import sys
from os.path import realpath, dirname, join

import cv2 as cv
import google.protobuf.text_format
import numpy as np
import onnx
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class Kernel_r1(nn.Module):
    """Standalone wrapper around the conv_r1 layer (256 -> 5120 channels, 3x3)
    so it can be exported to ONNX on its own."""

    def __init__(self, size = 2):
        # `size` is accepted only for signature parity with the other modules;
        # the layer shape is fixed.
        super(Kernel_r1, self).__init__()
        self.conv_r1 = nn.Conv2d(256, 5120, 3)

    def forward(self, features):
        result = self.conv_r1(features)
        return result

class Kernel_cls1(nn.Module):
    """Standalone wrapper around the conv_cls1 layer (256 -> 2560 channels, 3x3)
    so it can be exported to ONNX on its own."""

    def __init__(self, size = 2):
        # `size` is accepted only for signature parity with the other modules;
        # the layer shape is fixed.
        super(Kernel_cls1, self).__init__()
        self.conv_cls1 = nn.Conv2d(256, 2560, 3)

    def forward(self, features):
        result = self.conv_cls1(features)
        return result

class Tracker(nn.Module):
    """DaSiamRPN tracker network: a shared convolutional backbone plus
    regression (4*anchor) and classification (2*anchor) branches.

    Args:
        size: channel multiplier for the backbone (3-channel input is untouched).
        feature_out: channel count of the branch feature maps.
        anchor: number of anchor boxes per spatial location.
    """

    def __init__(self, size = 2, feature_out = 512, anchor = 5):
        super(Tracker, self).__init__()
        configs = [3, 96, 256, 384, 384, 256]
        # Scale every channel count by `size`, except the 3-channel image input.
        configs = list(map(lambda z: 3 if z == 3 else z * size, configs))
        feat_in = configs[-1]
        # AlexNet-like backbone shared by both branches.
        self.featureExtract = nn.Sequential(
            nn.Conv2d(configs[0], configs[1], kernel_size = 11, stride = 2),
            nn.BatchNorm2d(configs[1]),
            nn.MaxPool2d(kernel_size = 3, stride = 2),
            nn.ReLU(inplace = True),
            nn.Conv2d(configs[1], configs[2], kernel_size = 5),
            nn.BatchNorm2d(configs[2]),
            nn.MaxPool2d(kernel_size = 3, stride = 2),
            nn.ReLU(inplace = True),
            nn.Conv2d(configs[2], configs[3], kernel_size = 3),
            nn.BatchNorm2d(configs[3]),
            nn.ReLU(inplace = True),
            nn.Conv2d(configs[3], configs[4], kernel_size = 3),
            nn.BatchNorm2d(configs[4]),
            nn.ReLU(inplace = True),
            nn.Conv2d(configs[4], configs[5], kernel_size = 3),
            nn.BatchNorm2d(configs[5]),
        )
        self.anchor = anchor
        self.feature_out = feature_out

        # For converting large layers (>64MB) to ONNX via torch.onnx.export()
        # use environment variable:
        # export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
        self.conv_r1 = nn.Conv2d(feat_in, feature_out*4*anchor, 3)
        self.conv_r2 = nn.Conv2d(feat_in, feature_out, 3)
        self.conv_cls1 = nn.Conv2d(feat_in, feature_out*2*anchor, 3)
        self.conv_cls2 = nn.Conv2d(feat_in, feature_out, 3)
        self.regress_adjust = nn.Conv2d(4*anchor, 4*anchor, 1)  # 20 -> 20, 1x1
        self.r1_kernel = []
        self.cls1_kernel = []
        # Extra layers with no pretrained weights; the export script below fills
        # them with random values so the state dict loads.
        self.new_layer_1 = nn.Conv2d(256, 20, 4, bias = False)
        self.new_layer_2 = nn.Conv2d(256, 10, 4, bias = False)

    def forward(self, x):
        """Return the (regression, classification) output maps for image x."""
        y = self.featureExtract(x)
        return self.regress_adjust(self.new_layer_1(self.conv_r2(y))), self.new_layer_2(self.conv_cls2(y))

    def temple(self, z):
        """Precompute per-anchor correlation kernels from template image z."""
        z_f = self.featureExtract(z)
        # Bug fix: run conv_r1 once and reuse the result — the original
        # evaluated it twice (once for the kernel size, once for the view).
        r1 = self.conv_r1(z_f)
        kernel_size = r1.data.size()[-1]
        self.r1_kernel = r1.view(self.anchor * 4, self.feature_out, kernel_size, kernel_size)
        self.cls1_kernel = self.conv_cls1(z_f).view(self.anchor * 2, self.feature_out, kernel_size, kernel_size)

class Feature_extractor(nn.Module):
    """Backbone-only module: the same AlexNet-like feature extractor used by
    Tracker, exposed on its own for separate ONNX export."""

    def __init__(self, size = 2, feature_out = 512, anchor = 5):
        super(Feature_extractor, self).__init__()
        # Scale every channel count by `size`, except the 3-channel image input.
        channels = [3, 96, 256, 384, 384, 256]
        channels = [c if c == 3 else c * size for c in channels]
        layers = [
            nn.Conv2d(channels[0], channels[1], kernel_size = 11, stride = 2),
            nn.BatchNorm2d(channels[1]),
            nn.MaxPool2d(kernel_size = 3, stride = 2),
            nn.ReLU(inplace = True),
            nn.Conv2d(channels[1], channels[2], kernel_size = 5),
            nn.BatchNorm2d(channels[2]),
            nn.MaxPool2d(kernel_size = 3, stride = 2),
            nn.ReLU(inplace = True),
            nn.Conv2d(channels[2], channels[3], kernel_size = 3),
            nn.BatchNorm2d(channels[3]),
            nn.ReLU(inplace = True),
            nn.Conv2d(channels[3], channels[4], kernel_size = 3),
            nn.BatchNorm2d(channels[4]),
            nn.ReLU(inplace = True),
            nn.Conv2d(channels[4], channels[5], kernel_size = 3),
            nn.BatchNorm2d(channels[5]),
        ]
        self.featureExtract = nn.Sequential(*layers)

    def forward(self, x):
        return self.featureExtract(x)

def assertExpected(s):
    """Validate that *s* is a string (str, or unicode on Python 2).

    Raises:
        TypeError: if *s* is not a string type.
    """
    # Bug fix: `sys` was used here without ever being imported (see the import
    # block at the top of the file). On Python 3 the short-circuit keeps the
    # bare `unicode` name from being evaluated.
    if not (isinstance(s, str) or (sys.version_info[0] == 2 and isinstance(s, unicode))):
        raise TypeError("assertExpected is strings only")

def assertONNXExpected(binary_pb):
    """Parse serialized ONNX bytes, validate the model, and return it."""
    model_def = onnx.ModelProto.FromString(binary_pb)
    onnx.checker.check_model(model_def)
    onnx.helper.strip_doc_string(model_def)
    # Render to text form and sanity-check that the rendering is a string.
    text_proto = google.protobuf.text_format.MessageToString(model_def, float_format = '.15g')
    assertExpected(text_proto)
    return model_def

def export_to_string(model, inputs):
    f = io.BytesIO()
    with torch.no_grad():
        torch.onnx.export(model, inputs, f)
    return f.getvalue()

def save_data_and_model(inputs, model, model_files):
    """Export *model* (traced on *inputs*) to ONNX, validate it, and write the
    serialized model to the path *model_files*."""
    model.eval()
    serialized = export_to_string(model, inputs)
    checked = assertONNXExpected(serialized)
    with open(model_files, 'wb') as out:
        out.write(checked.SerializeToString())

# Load a sample frame; its pixels drive the network trace used for conversion.
frame = cv.imread("Full path to input image")

# Conversion of the tracker network.
tracker_frame = np.transpose(frame, (2, 0, 1))  # HWC -> CHW
tracker_frame = torch.from_numpy(tracker_frame)
tracker_frame = Variable(tracker_frame.unsqueeze(0)).float()  # add batch dim
tracker_model = Tracker(size = 1, feature_out = 256)
tracker_path = join(realpath(dirname(__file__)), "Path to model")
load_tracker_path = torch.load(tracker_path, map_location = torch.device("cpu"))
# new_layer_1/new_layer_2 are not in the pretrained checkpoint; fill them with
# random weights so load_state_dict() succeeds.
load_tracker_path["new_layer_1.weight"] = Variable(torch.randn(20, 256, 4, 4))
load_tracker_path["new_layer_2.weight"] = Variable(torch.randn(10, 256, 4, 4))
tracker_model.load_state_dict(load_tracker_path)
save_data_and_model(tracker_frame, tracker_model, "Path for saving model file")

# Conversion of the feature extractor.
extractor_frame = torch.from_numpy(frame)
extractor_frame = extractor_frame.type("torch.DoubleTensor")
# NOTE(review): assumes the input image is 360x480 with 3 channels — confirm
# against the actual frame dimensions.
extractor_frame = extractor_frame.view(1, 3, 360, 480)
extractor_model = Feature_extractor(size = 1, feature_out = 256)
# NOTE(review): no checkpoint is loaded into extractor_model here, so these
# features come from randomly initialized backbone weights — verify intended.
features = extractor_model(extractor_frame.float())

# Conversion of the kernels.
# kernel_r1
kernel_r1_model = Kernel_r1()
path = join(realpath(dirname(__file__)), "Path to model")
load_kernel_r1_path = torch.load(path, map_location = torch.device("cpu"))
kernel_r1_model_dict = kernel_r1_model.state_dict()
# Keep only the checkpoint entries that match this model's parameters.
load_kernel_r1_path = {k: v for k, v in load_kernel_r1_path.items() if k in kernel_r1_model_dict}
kernel_r1_model_dict.update(load_kernel_r1_path)
# Bug fix: load the merged dict (model defaults overridden by checkpoint
# values). The original loaded the filtered checkpoint directly, which raises
# missing-key errors whenever the checkpoint does not cover every parameter.
kernel_r1_model.load_state_dict(kernel_r1_model_dict)
save_data_and_model(features, kernel_r1_model, "Path for saving model file")

# kernel_cls1
kernel_cls1_model = Kernel_cls1()
path = join(realpath(dirname(__file__)), "Path to model")
load_kernel_cls1_path = torch.load(path, map_location = torch.device("cpu"))
kernel_cls1_model_dict = kernel_cls1_model.state_dict()
# Keep only the checkpoint entries that match this model's parameters.
load_kernel_cls1_path = {k: v for k, v in load_kernel_cls1_path.items() if k in kernel_cls1_model_dict}
kernel_cls1_model_dict.update(load_kernel_cls1_path)
# Bug fix: load the merged dict (model defaults overridden by checkpoint
# values). The original loaded the filtered checkpoint directly, which raises
# missing-key errors whenever the checkpoint does not cover every parameter.
kernel_cls1_model.load_state_dict(kernel_cls1_model_dict)
save_data_and_model(features, kernel_cls1_model, "Path for saving model file")

@dkurt dkurt self-assigned this Mar 15, 2020
Copy link
Copy Markdown
Member

@dkurt dkurt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 Well done!

@l-bat
Copy link
Copy Markdown
Contributor

l-bat commented Mar 18, 2020

👍 Thanks!

@alalek alalek merged commit 221ddec into opencv:3.4 Mar 18, 2020
@alalek alalek mentioned this pull request Mar 20, 2020
@Olalaye
Copy link
Copy Markdown

Olalaye commented Jun 30, 2020

@l-bat
Copy link
Copy Markdown
Contributor

l-bat commented Jun 30, 2020

@Olalaye why can't you get models? All links are valid.

@Olalaye
Copy link
Copy Markdown

Olalaye commented Jul 1, 2020

@l-bat Sorry, it was my Internet problem. It can be used normally now. Thanks!

@ieliz ieliz deleted the tracker branch August 12, 2020 11:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants