
add CropAndResize layer for CUDA backend#16069

Merged
alalek merged 2 commits into opencv:master from YashasSamaga:cuda4dnn-crop_and_resize
Dec 9, 2019

Conversation

@YashasSamaga
Contributor

@YashasSamaga YashasSamaga commented Dec 5, 2019

This pull request changes

  • add CropAndResize support for CUDA backend

Timings (on GTX 1050):

| Model | Before | Now |
|---|---|---|
| Inception v2 Faster RCNN | 124ms | 75ms |
| Inception v2 Mask RCNN | 224ms | 121ms |
| ResNet50 Faster RCNN | 232ms | 122ms |

^ more information can be found here

force_builders=Custom
buildworker:Custom=linux-4
docker_image:Custom=ubuntu-cuda:18.04

@YashasSamaga
Contributor Author

YashasSamaga commented Dec 5, 2019

Is there a test for CropAndResize layer? I am not able to find one.

Otherwise, the PR is ready and not WIP.

@tompollok
Contributor

Have you tried MaskRCNN with this patch?

Last time, without CUDA CropAndResize, you used a 7700HQ and a GTX 1050, and your results were:

Inception v2 Mask RCNN
OCV CPU Time: 3280ms
CUDA Total Time: 407ms
Relative Error >> Total: 0.787171, Average: 1.94363e-07, Max: 1.65757e-06

@YashasSamaga
Contributor Author

YashasSamaga commented Dec 5, 2019

@tompollok

Results below are from the same device configuration (OS is Ubuntu 18.04 instead of Windows). The relative error is between OCV CPU output and CUDA output. The target used was DNN_TARGET_CUDA. Further speedup can be obtained on some devices with DNN_TARGET_CUDA_FP16.

Differences between previous setup and current setup:

| What | Before | Now |
|---|---|---|
| OS | Windows 10 | Ubuntu 18.04 |
| CUDA | 10.1 | 10.2 |
| cuDNN | 7.6.1 | 7.6.5 |

I am not sure why there is such a huge difference between the timings without CropAndResize back then and now. The timing difference for OCV CPU confuses me too. The compilers are different (MSVC 19.16 vs GCC 7.4), but that shouldn't be so dramatic.

Inception v2 Mask RCNN
OCV CPU Time:   1090ms
IE CPU Time:    583ms
CUDA Total Time: 120.399ms
Relative Error >> Total: 0.63299, Average: 3.12588e-07, Max: 2.26777e-06


ResNet50 Faster RCNN
OCV CPU Time:   916ms
IE CPU Time:    586ms
CUDA Total Time: 133.605ms
Relative Error >> Total: 0.000495682, Average: 7.08118e-07, Max: 2.37544e-05


Inception v2 Faster RCNN
OCV CPU Time:   378ms
IE CPU Time:    251ms
CUDA Total Time: 75.324ms
Relative Error >> Total: 0.000538239, Average: 7.68913e-07, Max: 1.70529e-05

@tompollok
Contributor

Didn't you say that using CPU fallback results in copying data back and forth between GPU and CPU?
Also, there has been some optimization work lately; I don't know if and how that affects the OpenCV CPU backend. I'd say running the latest commit on your Windows machine would show whether there is actually a big difference between these platforms. Anyway, great job!

Btw: are there still many CPU fallbacks, or are the models listed above now running completely on the CUDA backend?

It would be interesting to have some debug function that shows the backend forwarding path for a model, so that it's easy to see which layers are not on the same (CUDA) backend.

@YashasSamaga
Contributor Author

YashasSamaga commented Dec 6, 2019

@tompollok

Set the OPENCV_LOG_LEVEL environment variable to INFO and you'll be able to see which layers are using fallbacks.

[ INFO:0] global /home/yashas/Desktop/gsoc/opencv/modules/dnn/src/dnn.cpp (2204) initCUDABackend CUDA backend will fallback to the CPU implementation for the layer "_input" of type __NetInputLayer__

[ INFO:0] global /home/yashas/Desktop/gsoc/opencv/modules/dnn/src/dnn.cpp (2204) initCUDABackend CUDA backend will fallback to the CPU implementation for the layer "detection_out" of type DetectionOutput

[ INFO:0] global /home/yashas/Desktop/gsoc/opencv/modules/dnn/src/dnn.cpp (2204) initCUDABackend CUDA backend will fallback to the CPU implementation for the layer "detection_out_final" of type DetectionOutput

The only missing layer now is DetectionOutput (the input layer is skipped as it's a NOP for the models mentioned in this PR). Given the nature of the computations involved, I don't think it's worth porting it to the GPU fully. It might help to move part of DetectionOutput to the GPU and perform the final steps, such as NMS, on the CPU (the way it's done for the region layer). I'll have to dig through the code and think about it.
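The NMS step mentioned above works roughly as follows; this is a minimal greedy sketch for illustration, not OpenCV's actual DetectionOutput implementation:

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Minimal greedy non-maximum suppression: keep the highest-scoring box,
// drop every remaining box whose IoU with a kept box exceeds iouThresh.
struct Box { float x1, y1, x2, y2, score; };

static float iou(const Box& a, const Box& b)
{
    // intersection rectangle (empty if the boxes don't overlap)
    float ix1 = std::max(a.x1, b.x1), iy1 = std::max(a.y1, b.y1);
    float ix2 = std::min(a.x2, b.x2), iy2 = std::min(a.y2, b.y2);
    float inter = std::max(0.f, ix2 - ix1) * std::max(0.f, iy2 - iy1);
    float areaA = (a.x2 - a.x1) * (a.y2 - a.y1);
    float areaB = (b.x2 - b.x1) * (b.y2 - b.y1);
    return inter / (areaA + areaB - inter);
}

std::vector<Box> nms(std::vector<Box> boxes, float iouThresh)
{
    std::sort(boxes.begin(), boxes.end(),
              [](const Box& a, const Box& b) { return a.score > b.score; });
    std::vector<Box> kept;
    for (const Box& b : boxes) {
        bool suppressed = false;
        for (const Box& k : kept)
            if (iou(b, k) > iouThresh) { suppressed = true; break; }
        if (!suppressed) kept.push_back(b);
    }
    return kept;
}
```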

Other optimizations are possible:

  1. Fuse ReLU with convolution
  2. Skip concat (by having the parent layers write directly into the concat output block)
  3. DetectionOutput involves three device-to-host copies before the layer starts working. Two of its inputs are available well before the DetectionOutput layer runs, so the device-to-host copies for those two inputs can be started as early as possible. These memory transfers can then overlap with the computation of the layers that follow, hiding the two D2H transfers completely or at least partially.
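The concat-skipping idea in item 2 can be sketched like this (an illustration of the concept, not the cuda4dnn code): instead of each parent layer writing to its own buffer and a concat layer copying them together, each parent is handed a view (offset plus length) into the final concatenated buffer and writes its output there directly.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// A non-owning view into a slice of the concatenated output buffer.
struct View { float* data; std::size_t size; };

// Carve one contiguous output buffer into per-parent views laid out
// back-to-back, so each parent layer can write its output in place and
// the explicit concat copy becomes unnecessary.
std::vector<View> carveConcatOutput(std::vector<float>& concatOut,
                                    const std::vector<std::size_t>& parentSizes)
{
    std::vector<View> views;
    std::size_t offset = 0;
    for (std::size_t sz : parentSizes) {
        views.push_back({concatOut.data() + offset, sz});
        offset += sz;
    }
    return views;
}
```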

This is what I get on Windows.

Inception v2 Mask RCNN
OCV CPU Time:   3219ms
CUDA Total Time: 362ms
Relative Error >> Total: 0.790324, Average: 1.95142e-07, Max: 1.47339e-06

@YashasSamaga YashasSamaga changed the title [WIP] add CropAndResize layer for CUDA backend add CropAndResize layer for CUDA backend Dec 6, 2019
@dkurt
Member

dkurt commented Dec 9, 2019

Is there a test for CropAndResize layer? I am not able to find one.

Otherwise, the PR is ready and not WIP.

@YashasSamaga, thanks for this feature! There is no single-layer test, but there are tests for Faster R-CNNs that use it.

dkurt
dkurt previously approved these changes Dec 9, 2019
Member

@dkurt dkurt left a comment


👍 Thanks!

@dkurt dkurt self-assigned this Dec 9, 2019
@YashasSamaga
Contributor Author

YashasSamaga commented Dec 9, 2019

@dkurt I just realized, after the notification about this PR's approval, that the optimizations added in PR16097 to the resize kernel may be applicable here too. I would like to profile and check. If it doesn't cause trouble and you're ok with it, I can add a commit to this PR to optimize the kernel. Otherwise, I will make another PR with the optimization.

I am sorry for the last-minute change.

EDIT: got around 3.3x improvement with around 5.5ms reduction in Inception v2 Mask RCNN inference time on GTX 1050. Should I add the commit here or make another PR?

@dkurt dkurt dismissed their stale review December 9, 2019 12:12

for extra commits

@dkurt
Member

dkurt commented Dec 9, 2019

@YashasSamaga, no problem. Let's push it here.

@YashasSamaga YashasSamaga force-pushed the cuda4dnn-crop_and_resize branch from 04ac019 to ce13070 Compare December 9, 2019 12:48
@alalek alalek merged commit 3fddd3b into opencv:master Dec 9, 2019
@JTzhuang

@dkurt @YashasSamaga I encountered a strange problem. When I tested with the following code, I found that the two inference times of the model differed: time1 = 700ms, time2 = 3ms. And when I run this code on the CPU, the inference time is 300ms. It's so weird.

```cpp
double start = getTickCount();
net2.forward(outs, outNames);
double time1 = 1000 * (getTickCount() - start) / getTickFrequency();
cout << time1 << endl;

vector<double> layersTimings2;  // was missing the <double> template argument
double freq1 = getTickFrequency() / 1000;
double time2 = net2.getPerfProfile(layersTimings2) / freq1;
cout << time2 << endl;
```

@YashasSamaga
Contributor Author

@JTzhuang The first forward pass has an initialization cost; you should ignore it. The subsequent forward passes should be faster.
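The warm-up advice above is a standard benchmarking pattern; here is a generic sketch (the callable stands in for something like `net.forward()`, and the warm-up/iteration structure is the point, not the workload):

```cpp
#include <cassert>
#include <chrono>

// Time one invocation of a callable in milliseconds.
template <typename F>
double timeMs(F&& f)
{
    auto t0 = std::chrono::steady_clock::now();
    f();
    auto t1 = std::chrono::steady_clock::now();
    return std::chrono::duration<double, std::milli>(t1 - t0).count();
}

// Run one warm-up pass (which pays any one-time initialization cost and is
// discarded), then return the mean time of the subsequent iterations.
template <typename F>
double benchmark(F&& f, int iterations)
{
    timeMs(f);  // warm-up: result intentionally ignored
    double total = 0.0;
    for (int i = 0; i < iterations; i++)
        total += timeMs(f);
    return total / iterations;
}
```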

@JTzhuang

@YashasSamaga The first forward pass is 800ms, and the subsequent forward passes are about 700ms.

@YashasSamaga
Contributor Author

@JTzhuang Please share the model you are using.

@JTzhuang

@YashasSamaga
https://drive.google.com/open?id=1Wsx4JOKvn6Xn2Rr7m0VX2HTXEvUBvWm3
This is my code and .pb file, you can test it on CPU and GPU.

@YashasSamaga
Contributor Author

YashasSamaga commented May 22, 2020

@JTzhuang Sorry, I misread your initial post. The CUDA backend doesn't support getPerfProfile. If you need to check layerwise timings, you can use NVIDIA Nsight Systems. It provides nice visualizations and an easy-to-use GUI.

It might be difficult to identify which operations correspond to which layer. CUDA allows profiling data to be annotated in the code using NVTX, but the CUDA backend currently doesn't have any profiler annotation capabilities. If you think this could be a useful feature, please open a feature request issue.

CPU:

CPU output
0.667441[1.318402e-05, 1.318402e-05, 0.99996042, 1.318402e-05]
653.698
0.198471[1.318402e-05, 1.318402e-05, 0.99996042, 1.318402e-05]
190.316
0.197608[1.318402e-05, 1.318402e-05, 0.99996042, 1.318402e-05]
189.356
0.197534[1.318402e-05, 1.318402e-05, 0.99996042, 1.318402e-05]
188.999
0.19357[1.318402e-05, 1.318402e-05, 0.99996042, 1.318402e-05]
184.509
0.195504[1.318402e-05, 1.318402e-05, 0.99996042, 1.318402e-05]
187.31
0.194946[1.318402e-05, 1.318402e-05, 0.99996042, 1.318402e-05]
186.271
0.197041[1.318402e-05, 1.318402e-05, 0.99996042, 1.318402e-05]
188.171
0.196134[1.318402e-05, 1.318402e-05, 0.99996042, 1.318402e-05]
186.883
0.19249[1.318402e-05, 1.318402e-05, 0.99996042, 1.318402e-05]
184.133
0.19389[1.318402e-05, 1.318402e-05, 0.99996042, 1.318402e-05]
184.95
0.195315[1.318402e-05, 1.318402e-05, 0.99996042, 1.318402e-05]
186.562

CUDA:

CUDA output
0.493976[1.3184058e-05, 1.3184058e-05, 0.99996042, 1.3184058e-05]
3.0786
0.522221[1.3184058e-05, 1.3184058e-05, 0.99996042, 1.3184058e-05]
2.73693
0.52053[1.3184058e-05, 1.3184058e-05, 0.99996042, 1.3184058e-05]
3.00782
0.52222[1.3184058e-05, 1.3184058e-05, 0.99996042, 1.3184058e-05]
2.77092
0.525574[1.3184058e-05, 1.3184058e-05, 0.99996042, 1.3184058e-05]
2.69902
0.523383[1.3184058e-05, 1.3184058e-05, 0.99996042, 1.3184058e-05]
3.3333
0.521101[1.3184058e-05, 1.3184058e-05, 0.99996042, 1.3184058e-05]
2.81944
0.528263[1.3184058e-05, 1.3184058e-05, 0.99996042, 1.3184058e-05]
2.75951
0.528537[1.3184058e-05, 1.3184058e-05, 0.99996042, 1.3184058e-05]
2.76679
0.529091[1.3184058e-05, 1.3184058e-05, 0.99996042, 1.3184058e-05]
2.72331
0.54491[1.3184058e-05, 1.3184058e-05, 0.99996042, 1.3184058e-05]
3.07419

@JTzhuang

@YashasSamaga So you think the inference cost on the GPU is right? It's obvious that inference on the GPU is much slower than on the CPU here.
Is there something wrong with my testing method?

@YashasSamaga
Contributor Author

YashasSamaga commented May 22, 2020

Without color_input set:

Test Network
[OCV CPU]
        init >> 185.475ms
        inference >> min = 178.937ms, max = 185.759ms, mean = 182.892ms, stddev = 1.91519ms
[CUDA FP32]
        init >> 491.033ms
        inference >> min = 41.773ms, max = 42.021ms, mean = 41.9014ms, stddev = 0.0698771ms

With color_input set:

Test Network
[OCV CPU]
        init >> 192.439ms
        inference >> min = 191.679ms, max = 198.518ms, mean = 194.69ms, stddev = 2.36098ms
[CUDA FP32]
        init >> 486.104ms
        inference >> min = 532.097ms, max = 547.586ms, mean = 538.318ms, stddev = 4.02725ms

I tried to view your model in Netron. No layer seems to be using color_input. I need to find out why providing color_input kills the performance.

Does your model have depthwise convolutions?

@JTzhuang

@YashasSamaga No. My network has two paths. One path extracts feature maps from the color image (color_input), and the other path extracts feature maps from the height image in the same way; ResNet101 is used as the backbone in both paths. So there is no depthwise convolution in my model.

@JTzhuang

@YashasSamaga It's just like a Siamese network. I don't think it would produce the same output for the two inputs if color_input is not set.

@YashasSamaga
Contributor Author

YashasSamaga commented May 22, 2020

I added getPerfProfile support to the CUDA backend here.

OCV Time: 193ms, getPerfProfile: 183.962ms
CUDA Time: 532ms, getPerfProfile: 42.5451ms
Code
```cpp
#include <iostream>
#include <chrono>
#include <vector>

#include <opencv2/core/cuda.hpp>
#include <opencv2/dnn.hpp>
#include <opencv2/imgcodecs.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>

int main()
{
    auto color_input = cv::imread("9-c.png");
    auto height_input = cv::imread("9-h.png");
    // note: the original used 1 / 255, which is integer division and yields a scale factor of 0
    auto color_input_blob = cv::dnn::blobFromImage(color_input, 1 / 255.0, cv::Size(224, 224), cv::Scalar(), true, false);
    auto height_input_blob = cv::dnn::blobFromImage(height_input, 1 / 255.0, cv::Size(224, 224), cv::Scalar(), false, false);

    auto net_ocv = cv::dnn::readNetFromTensorflow("DP_concat.pb");
    net_ocv.setPreferableBackend(cv::dnn::DNN_BACKEND_OPENCV);
    net_ocv.setPreferableTarget(cv::dnn::DNN_TARGET_CPU);

    // warm-up pass (pays the one-time initialization cost)
    {
        net_ocv.setInput(color_input_blob, "color_input");
        net_ocv.setInput(height_input_blob, "height_input");
        net_ocv.forward("feature_concat/Softmax");
    }

    auto start = std::chrono::steady_clock::now();
    {
        net_ocv.setInput(color_input_blob, "color_input");
        net_ocv.setInput(height_input_blob, "height_input");
        net_ocv.forward("feature_concat/Softmax");
    }
    auto end = std::chrono::steady_clock::now();
    auto ocv_time = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();

    std::vector<double> ocvLayerTimings;
    auto ocv_getPerfProfile = net_ocv.getPerfProfile(ocvLayerTimings) / cv::getTickFrequency() * 1000;
    std::cout << "OCV Time: " << ocv_time << "ms, getPerfProfile: " << ocv_getPerfProfile << "ms" << std::endl;

    auto net_cuda = cv::dnn::readNetFromTensorflow("DP_concat.pb");
    net_cuda.setPreferableBackend(cv::dnn::DNN_BACKEND_CUDA);
    net_cuda.setPreferableTarget(cv::dnn::DNN_TARGET_CUDA);

    // warm-up pass
    {
        net_cuda.setInput(color_input_blob, "color_input");
        net_cuda.setInput(height_input_blob, "height_input");
        net_cuda.forward("feature_concat/Softmax");
    }

    start = std::chrono::steady_clock::now();
    {
        net_cuda.setInput(color_input_blob, "color_input");
        net_cuda.setInput(height_input_blob, "height_input");
        net_cuda.forward("feature_concat/Softmax");
    }
    end = std::chrono::steady_clock::now();
    auto cuda_time = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();

    std::vector<double> cudaLayerTimings;
    auto cuda_getPerfProfile = net_cuda.getPerfProfile(cudaLayerTimings) / cv::getTickFrequency() * 1000;
    std::cout << "CUDA Time: " << cuda_time << "ms, getPerfProfile: " << cuda_getPerfProfile << "ms" << std::endl;

    return 0;
}
```

The first number in each line is the actual time net.forward() took to finish executing. The second number is the layerwise sum of timings reported by getPerfProfile.

It appears the CUDA backend takes 42ms to compute the outputs; the remaining ~500ms is coming from somewhere else. This looks like a bug. Please open an issue.


I profiled using Nsight Systems. It looks like the CUDA backend is completely reinitialized for the second forward pass. The reinitialization is totally unnecessary; it's as if every pass were the first forward pass.

I need to dig deeper and check, but I think the bug is not in the CUDA backend. It is probably in the general initialization logic.

@JTzhuang

@YashasSamaga Thank you for your support. Waiting for your solution.

@YashasSamaga
Contributor Author

@JTzhuang Please open an issue so that this bug (if it is one) is tracked.

@JTzhuang

@YashasSamaga I have opened a new issue.
#17358

a-sajjad72 pushed a commit to a-sajjad72/opencv that referenced this pull request Mar 30, 2023
…esize

add CropAndResize layer for CUDA backend

* add CropAndResize layer

* process multiple channels per iteration