Skip to content

Combining libtorch C++ api with "pure" cuda code #40393

@Niohori

Description

@Niohori

I want to perform non-maximum suppression (NMS) on the GPU, on the output of a Darknet/YOLO CNN.
The following test code works fine:

//Credits: adapted from https://github.com/pprp to test nms parallel algo with at::Tensor as input
#include <stdio.h>
#include <stdlib.h>
#include <cstring>
#include <iostream>
#include "string"
//#include <windows.h>//needed for LoadLibrary
#include <cuda_runtime.h>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include "opencv2/imgproc/imgproc.hpp"
#include "device_launch_parameters.h"

#include "device_functions.h"

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THC.h>
#include <THC/THCDeviceUtils.cuh>

#include "torch/torch.h"//->including torch gives problems!!!???

using namespace std;

/**
 * Reports a failed CUDA runtime call and optionally terminates the process.
 *
 * @param code  Status returned by a CUDA runtime API call.
 * @param file  Source file of the call site (supplied by HANDLE_ERROR).
 * @param line  Source line of the call site (supplied by HANDLE_ERROR).
 * @param abort When true (the default), exit the process with `code` as the
 *              exit status after printing the message.
 *
 * Does nothing when `code` is cudaSuccess.
 */
inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true)
{
    if (code == cudaSuccess)
    {
        return;
    }
    fprintf(stderr,"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
    if (abort) exit(code);
}

// Wraps a CUDA runtime call and reports the call site on failure.
// __FILE__/__LINE__ are the standard predefined macros; the do/while(0)
// wrapper makes the macro behave as a single statement, so it is safe in
// unbraced if/else bodies.
#define HANDLE_ERROR(ans) do { gpuAssert((ans), __FILE__, __LINE__); } while (0)

// Host-side record describing one detection box.
typedef struct
{
    double x;       // box origin, x coordinate
    double y;       // box origin, y coordinate
    double w;       // box width
    double h;       // box height
    char s[100];    // score as a NUL-terminated text string
    char cls[100];  // class label as a NUL-terminated text string
    double cmps;    // score as a number
} box;

/**
 * Overlap ratio (intersection over union) of two boxes.
 *
 * Each box is read as {x, y, w, h}: elements [0] and [1] are the origin and
 * elements [2] and [3] the extent. Any further elements are ignored.
 *
 * Callable from both host and device code so the arithmetic can be checked
 * on the CPU.
 *
 * NOTE(review): the box areas use (w + 1) * (h + 1) while the intersection
 * uses plain w * h. That asymmetry comes from the original code and is kept
 * unchanged here, so two identical boxes score slightly below 1.
 *
 * @param b1 Pointer to at least four floats describing the first box.
 * @param b2 Pointer to at least four floats describing the second box.
 * @return   intersection / (area1 + area2 - intersection); 0 when the boxes
 *           do not overlap.
 */
__host__ __device__ inline float devIoU(float const* const b1, float const* const b2) {
    const float area1 = (b1[2] + 1.0f) * (b1[3] + 1.0f);
    const float area2 = (b2[2] + 1.0f) * (b2[3] + 1.0f);

    // Corners of the intersection rectangle.
    const float left   = fmaxf(b1[0], b2[0]);
    const float top    = fmaxf(b1[1], b2[1]);
    const float right  = fminf(b1[0] + b1[2], b2[0] + b2[2]);
    const float bottom = fminf(b1[1] + b1[3], b2[1] + b2[3]);

    // Clamp to zero so disjoint boxes yield an empty intersection.
    const float w = fmaxf(0.0f, right - left);
    const float h = fmaxf(0.0f, bottom - top);
    const float overlap = w * h;

    return overlap / (area1 + area2 - overlap);
}

/**
 * Non-maximum suppression: clears d_res[i] for every box i that overlaps a
 * higher-scored box by more than nms_overlap_thresh.
 *
 * Layout: dev_boxes holds n_boxes rows of five floats {x, y, w, h, score}.
 * d_res holds n_boxes flags; this kernel only ever writes `false`, so the
 * caller must initialise the flags to `true` before the launch.
 *
 * Launch: indexing is 1D over x with a grid-stride loop, so any grid/block
 * size in x is valid. blockIdx.y/z are not used; extra blocks in y or z only
 * repeat the same work. No shared memory is used.
 *
 * Several threads may store `false` to the same d_res[i] concurrently. All
 * writers store the same value, so the result does not depend on ordering.
 *
 * @param n_boxes            Number of boxes in dev_boxes / flags in d_res.
 * @param nms_overlap_thresh Overlap ratio above which the lower-scored box
 *                           is suppressed.
 * @param dev_boxes          Device pointer to n_boxes * 5 floats.
 * @param d_res              Device pointer to n_boxes keep/suppress flags.
 */
__global__ void NMS_GPU(const int n_boxes, const float nms_overlap_thresh,
                        const float* dev_boxes, bool* d_res) {
    const int stride = gridDim.x * blockDim.x;

    for (int cur = blockIdx.x * blockDim.x + threadIdx.x; cur < n_boxes; cur += stride)
    {
        // Private copy of this thread's box.
        float cur_box[5];
        for (int k = 0; k < 5; ++k)
        {
            cur_box[k] = dev_boxes[cur * 5 + k];
        }

        for (int i = 0; i < n_boxes; ++i)
        {
            if (i == cur)
            {
                continue;
            }
            const float* other = dev_boxes + i * 5;
            // Only a strictly lower-scored box can be suppressed by this one.
            if (other[4] < cur_box[4] &&
                devIoU(other, cur_box) > nms_overlap_thresh)
            {
                d_res[i] = false;
            }
        }
    }
}

/**
 * Test driver: runs NMS_GPU on 19 hard-coded detection boxes that are passed
 * to the kernel through an at::Tensor, then draws the surviving boxes on
 * "Cow_45.jpg" and shows the image in a window.
 *
 * @return 0 on success or when no GPU is available, 1 when the image cannot
 *         be read. CUDA runtime failures terminate via HANDLE_ERROR.
 */
int main()
{
    // 64: number of bits in an unsigned long long (width of one mask word).
    int const threadsPerBlock = sizeof(unsigned long long) * 8;
    //LoadLibrary(TEXT("D:\dev\Cpp\dependencies\torchnew\lib\torch_cuda.dll"));
    at::DeviceType device_type;

    if (at::cuda::is_available()) {
        device_type = at::kCUDA;
    }
    else {
        device_type = at::kCPU;
        std::cout << "No GPU avalaible, sorry ..." << std::endl;
        return 0;
    }
    at::Device device(device_type);
    std::cout << "Device : " << device_type << std::endl;

    const int count = 19;
    cv::Mat temp = cv::imread("Cow_45.jpg",1);
    if (temp.empty())
    {
        // imread returns an empty matrix when the file is missing or unreadable;
        // drawing on it or showing it later would fail.
        std::cerr << "Could not read image Cow_45.jpg" << std::endl;
        return 1;
    }

    // Keep/suppress flag per box on the host; the kernel only clears flags.
    bool h_res[count];
    for (int i = 0; i < count; i++)
    {
        h_res[i] = true;
    }

    box b[count];
    b[0].x = 996.000000;b[0].y = 2566.420000;b[0].w = 170.793000;b[0].h=172.580000;
    strcpy(b[0].cls,"nose");strcpy(b[0].s,"0.983194");b[0].cmps=0.983194;
    b[1].x = 4238.937000;b[1].y = 1594.513000;b[1].w = 160.063000;b[1].h=148.487000;
    strcpy(b[1].cls,"eye");strcpy(b[1].s,"0.992166");b[1].cmps=0.992166;
    b[2].x = 4656.389000;b[2].y = 2175.186000;b[2].w = 316.180000;b[2].h=221.552000;
    strcpy(b[2].cls,"nose");strcpy(b[2].s,"0.994816");b[2].cmps=0.994816;
    b[3].x = 4316.000000;b[3].y = 1660.000000;b[3].w = 127.474000;b[3].h=113.452000;
    strcpy(b[3].cls,"eye");strcpy(b[3].s,"0.990833");b[3].cmps=0.990833;
    b[4].x = 997.013000;b[4].y = 2664.408000;b[4].w = 222.214000;b[4].h=229.068000;
    strcpy(b[4].cls,"nose");strcpy(b[4].s,"0.985067");b[4].cmps=0.985067;
    b[5].x = 666.069000;b[5].y = 2029.219000;b[5].w = 135.689000;b[5].h=160.833000;
    strcpy(b[5].cls,"eye");strcpy(b[5].s,"0.993240");b[5].cmps=0.993240;
    b[6].x = 4653.547000;b[6].y = 2324.000000;b[6].w = 338.125000;b[6].h=133.902000;
    strcpy(b[6].cls,"nose");strcpy(b[6].s,"0.982858");b[6].cmps=0.982858;
    b[7].x = 4476.556000;b[7].y = 2131.557000;b[7].w = 253.402000;b[7].h=273.601000;
    strcpy(b[7].cls,"nose");strcpy(b[7].s,"0.959098");b[7].cmps=0.959098;
    b[8].x = 754.326000;b[8].y = 2571.066000;b[8].w = 324.674000;b[8].h=161.605000;
    strcpy(b[8].cls,"nose");strcpy(b[8].s,"0.993699");b[8].cmps=0.993699;
    b[9].x = 729.962000;b[9].y = 2658.741000;b[9].w = 349.038000;b[9].h=192.046000;
    strcpy(b[9].cls,"nose");strcpy(b[9].s,"0.986209");b[9].cmps=0.986209;
    b[10].x = 1271.863000;b[10].y = 2058.679000;b[10].w = 138.781000;b[10].h=137.553000;
    strcpy(b[10].cls,"eye");strcpy(b[10].s,"0.989965");b[10].cmps=0.989965;
    b[11].x = 4316.000000;b[11].y = 1601.751000;b[11].w = 134.204000;b[11].h=141.249000;
    strcpy(b[11].cls,"eye");strcpy(b[11].s,"0.988307");b[11].cmps=0.988307;
    b[12].x = 650.901000;b[12].y = 2032.621000;b[12].w = 91.484000;b[12].h=42.112000;
    strcpy(b[12].cls,"eye");strcpy(b[12].s,"0.969982");b[12].cmps=0.969982;
    b[13].x = 1328.000000;b[13].y = 2058.692000;b[13].w = 103.849000;b[13].h=136.518000;
    strcpy(b[13].cls,"eye");strcpy(b[13].s,"0.987316");b[13].cmps=0.987316;
    b[14].x = 214.809000;b[14].y = 1599.809000;b[14].w = 1553.705000;b[14].h=1319.679000;
    strcpy(b[14].cls,"head");strcpy(b[14].s,"0.997623");b[14].cmps=0.997623;
    b[15].x = 3826.177000;b[15].y = 1072.206000;b[15].w = 1254.063000;b[15].h=1412.903000;
    strcpy(b[15].cls,"head");strcpy(b[15].s,"0.997487");b[15].cmps=0.997487;
    b[16].x = 729.632000;b[16].y = 2578.523000;b[16].w = 442.495000;b[16].h=302.378000;
    strcpy(b[16].cls,"nose");strcpy(b[16].s,"0.960093");b[16].cmps=0.960093;
    b[17].x = 655.430000;b[17].y = 2031.151000;b[17].w = 91.570000;b[17].h=148.691000;
    strcpy(b[17].cls,"eye");strcpy(b[17].s,"0.993275");b[17].cmps=0.993275;
    b[18].x = 4251.712000;b[18].y = 1660.000000;b[18].w = 147.288000;b[18].h=105.309000;
    strcpy(b[18].cls,"eye");strcpy(b[18].s,"0.992576");b[18].cmps=0.992576;

    //***************************************************************************************************************************************
    // Copy the boxes into an at::Tensor with one row {x, y, w, h, score} per box.
    at::Tensor boxes = at::zeros({ count,5 });

    for (int i = 0; i < count; i++) {
        boxes[i][0] = b[i].x;
        boxes[i][1] = b[i].y;
        boxes[i][2] = b[i].w;
        boxes[i][3] = b[i].h;
        boxes[i][4] = b[i].cmps;
    }

    boxes  = boxes.to(device);
    // Raw device pointer to the tensor storage. `boxes` must stay alive for
    // as long as the pointer is used. data_ptr<T>() replaces the deprecated
    // data<T>() accessor.
    float* boxes_as_array = boxes.data_ptr<float>();

    float nms_overlap_thresh = 0.1f;

    //Comment: this piece of code is apparently necessary to assign the Torch cuda context to the global cuda context
    THCState* state = at::globalContext().lazyInitCUDA(); // TODO replace with getTHCState
    const int col_blocks = THCCeilDiv(count, threadsPerBlock);
    // NOTE(review): mask_dev is not passed to NMS_GPU and is never released
    // in this function; kept only because the author reports this sequence is
    // needed for context setup. Confirm whether lazyInitCUDA() alone suffices.
    unsigned long long* mask_dev = NULL;
    mask_dev = (unsigned long long*) THCudaMalloc(state,count * col_blocks * sizeof(unsigned long long));
    //-------------------------------------------------------------------------------------------------------------

    // Device copy of the keep/suppress flags.
    bool *d_res;
    HANDLE_ERROR(cudaMalloc((void**)&d_res, count*sizeof(bool)));
    HANDLE_ERROR(cudaMemcpy(d_res, h_res,sizeof(bool)*count, cudaMemcpyHostToDevice));

    // One block with one thread per box. The kernel indexes along x only,
    // so additional blocks in y would merely repeat the same work.
    NMS_GPU<<<1, count>>>(count,nms_overlap_thresh, boxes_as_array,d_res);
    // A kernel launch returns no status; launch-configuration errors must be
    // fetched explicitly.
    HANDLE_ERROR(cudaGetLastError());

    // Blocking copy back to the host; execution errors from the kernel
    // surface here.
    HANDLE_ERROR(cudaMemcpy(h_res, d_res, sizeof(bool)*count, cudaMemcpyDeviceToHost));
    HANDLE_ERROR(cudaFree(d_res));

    // Draw label, score and rectangle for every box that survived.
    for (int i = 0; i < count; i++)
    {
        if (!h_res[i])
        {
            continue;
        }
        //printf("GPU Draw: %d--%d\n",i,*(h_res+i));
        cv::putText(temp,b[i].cls,cv::Point((int)b[i].x,(int)b[i].y-5),cv::FONT_HERSHEY_SIMPLEX,1.7,cv::Scalar(255,255,255),5,8,0);
        cv::putText(temp,b[i].s,cv::Point((int)b[i].x+120,(int)b[i].y-5),cv::FONT_HERSHEY_SIMPLEX,1.7,cv::Scalar(255,255,255),5,8,0);
        // NOTE(review): Scalar(92.185,194) has two components; possibly meant
        // to be (92,185,194). Left unchanged to preserve the drawn colour.
        cv::rectangle(temp,cv::Point((int)b[i].x,(int)b[i].y),cv::Point((int)b[i].x + (int)b[i].w,(int)b[i].y + (int)b[i].h),cv::Scalar(92.185,194),8,8,0);
    }
    cv::namedWindow("Window",0);
    cv::resizeWindow("Window",1064,800);
    cv::imshow("Window",temp);
    cv::waitKey(0);
    return 0;

}

Problem: When I include "torch/torch.h" (uncomment line 23), which I need for the rest of the project, I get the following error (Visual Studio 2019):

"Error member "torch::jit::detail::ParameterPolicy::all_slots" may not be initialized acudaNMSTEST D:\dev\Cpp\dependencies\torchnew\include\torch\csrc\jit\api\module.h 490 "

I struggled for a couple of days to find a solution, but in vain. Any idea what is going on, or what I did wrong?

cc @malfet @yf225 @glaringlee @peterjc123 @nbcsm @guyang3532

Metadata

Metadata

Assignees

No one assigned

    Labels

    module: build (Build system issues) · module: cpp (Related to C++ API) · module: windows (Windows support for PyTorch) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions