[cuBLAS] Add cublas_gemm_batched and use cublasSetStream to set stream to the current stream in all cublas API calls #423
e9d69bb to 4a0007f
4a0007f to 3d58fbd
yaoyaoding
left a comment
Thanks @yudi0201 !
It looks good to me.
src/hidet/runtime/cuda/cuda.cpp
Outdated
    CHECK_CUDA(cudaSetDevice(device));
}

DLL void hidet_cuda_malloc(void **devPtr, size_t size) {
Consider returning the allocated memory address directly, like
DLL void* hidet_cuda_malloc(size_t size) {
...
}
Similar to hidet_cuda_get_device(...).
69abe16 to ecd1461
ecd1461 to a446270
src/hidet/runtime/cuda/cublas.cpp
Outdated
// Allocate device memory
// first use synchronous versions of malloc and memcpy, later switch to async versions
if (cur_device_ptr_size != 0 && b > cur_device_ptr_size) {
    hidet_cuda_free((void *)ptr_a_device);
Why not hidet_cuda_free_async?
I find the following logic more readable; offered just as a reference.
if (b > cur_device_ptr_size) {
    if (cur_device_ptr_size > 0) {
        // free the three pointers
    }
    // allocate the three pointers
}
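To make the suggested control flow concrete, here is a minimal sketch of a grow-only buffer that frees and reallocates only when the requested batch count exceeds the current capacity. The names `BatchPtrs`, `ensure_capacity`, `release`, `mock_alloc`, and `mock_free` are hypothetical and not taken from the PR; the allocator is a host-memory stand-in for hidet_cuda_malloc / hidet_cuda_free so the sketch runs without a GPU.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>

// Hypothetical stand-ins for hidet_cuda_malloc / hidet_cuda_free,
// backed by host memory so the sketch runs without a GPU.
static void *mock_alloc(std::size_t bytes) { return std::malloc(bytes); }
static void mock_free(void *p) { std::free(p); }

// Three pointer arrays (for A, B and C), each sized for `capacity` entries.
struct BatchPtrs {
    void **a = nullptr;
    void **b = nullptr;
    void **c = nullptr;
    std::size_t capacity = 0;
};

// Reallocate only when `batch` exceeds the current capacity.
// Returns true if a reallocation happened.
bool ensure_capacity(BatchPtrs &buf, std::size_t batch) {
    if (batch <= buf.capacity) {
        return false;  // existing buffers are large enough
    }
    if (buf.capacity > 0) {
        // Free the three old arrays before growing.
        mock_free(buf.a);
        mock_free(buf.b);
        mock_free(buf.c);
    }
    // Allocate the three arrays at the new size.
    buf.a = static_cast<void **>(mock_alloc(batch * sizeof(void *)));
    buf.b = static_cast<void **>(mock_alloc(batch * sizeof(void *)));
    buf.c = static_cast<void **>(mock_alloc(batch * sizeof(void *)));
    assert(buf.a != nullptr && buf.b != nullptr && buf.c != nullptr);
    buf.capacity = batch;
    return true;
}

// Free everything and reset the capacity to zero.
void release(BatchPtrs &buf) {
    if (buf.capacity > 0) {
        mock_free(buf.a);
        mock_free(buf.b);
        mock_free(buf.c);
    }
    buf = BatchPtrs{};
}
```

With this shape, the outer check decides whether any work is needed and the inner check only guards the free, so the allocation appears once. In the real code the free could presumably be the async variant, as asked above, but that depends on how the stream is managed.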
Thanks for the suggestions! I'll modify these in the next revision.
src/hidet/runtime/cuda/cuda.cpp
Outdated
DLL void* hidet_cuda_malloc(size_t size) {
    lazy_load_cuda_runtime();
    void* devPtr = malloc(sizeof(void*));
- void* devPtr = malloc(sizeof(void*));
+ void* devPtr;
We do not need to allocate a memory region.
Maybe I'm missing something, but wouldn't doing this result in devPtr being created on the stack, and then we'd be returning a local stack variable?
When we call cudaMallocAsync(&devPtr, ...), we pass the address of the pointer to the CUDA API function, which updates the pointer's value. The pointer is a stack variable and remains valid for the duration of the CUDA API call. We return the "value" of the pointer, not the "address" of the pointer, to the caller of hidet_cuda_malloc, so this is fine.
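The point about returning the pointer's value can be illustrated with a small host-only sketch. `mock_device_malloc` is a hypothetical stand-in for an out-parameter allocator such as cudaMallocAsync (it uses `std::malloc` so it runs without CUDA), and `alloc_by_value` is an illustrative name, not the PR's function.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>

// Hypothetical stand-in for an out-parameter allocator: it writes the
// allocated address through `out` and returns 0 on success.
static int mock_device_malloc(void **out, std::size_t size) {
    assert(out != nullptr);
    *out = std::malloc(size);
    return *out != nullptr ? 0 : 1;
}

// devPtr is a local (stack) variable. The allocator fills it in through its
// address, and the function then returns a copy of the address it holds.
// Nothing refers to devPtr itself after the return, so no dangling pointer
// to the stack escapes.
void *alloc_by_value(std::size_t size) {
    void *devPtr = nullptr;
    int status = mock_device_malloc(&devPtr, size);
    return status == 0 ? devPtr : nullptr;
}
```

Returning `&devPtr` would be the unsafe case, since that is the address of the local variable. Returning `devPtr` copies the address of the allocation, which outlives the function.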
src/hidet/runtime/cuda/cuda.cpp
Outdated
DLL void* hidet_cuda_malloc_async(size_t size, cudaStream_t stream) {
    lazy_load_cuda_runtime();
    void* devPtr = malloc(sizeof(void*));