[Operator] Add the support of using external kernels in hidet #128

Merged
yaoyaoding merged 4 commits into hidet-org:main from yaoyaoding:external
Mar 6, 2023

@yaoyaoding
Member

This PR allows Hidet to launch kernels from external sources (e.g., manually written CUDA code, or vendor libraries such as cuDNN and cuBLAS).

Example

Define the external kernel and build dynamic library

Store the code in naive_matmul.cu:

#include <stdio.h>
#include <stdint.h>  // int32_t used by the host wrapper
#include <cuda.h>

__global__ void naive_matmul_kernel(float* a, float *b, float *c, int M, int N, int K) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    int j = blockIdx.y * blockDim.y + threadIdx.y;
    if (i < M && j < N) {
        float sum = 0.0f;
        for (int k = 0; k < K; k++) {
            sum += a[i * K + k] * b[k * N + j];
        }
        c[i * N + j] = sum;
    }
}


extern "C" {

__host__ void naive_matmul(int32_t num_args, int32_t * __restrict__ arg_types, void* * __restrict__ args) {
    float* __restrict__ a = (float *)args[0];
    float* __restrict__ b = (float *)args[1];
    float* __restrict__ c = (float *)args[2];
    int32_t M = *(int32_t *)args[3];
    int32_t N = *(int32_t *)args[4];
    int32_t K = *(int32_t *)args[5];
    auto block_size = dim3(16, 16);
    auto grid_size = dim3((M + block_size.x - 1) / block_size.x, (N + block_size.y - 1) / block_size.y);
    naive_matmul_kernel<<<grid_size, block_size>>>(a, b, c, M, N, K);
}
}

Then compile it to naive_matmul.so:

$ /usr/local/cuda/bin/nvcc --compiler-options '-fPIC' --shared ./naive_matmul.cu -o ./naive_matmul.so
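The host wrapper above receives its arguments as an array of `void*`: the three tensor pointers are stored directly, while `M`, `N`, and `K` are stored as pointers to `int32_t` values that the wrapper dereferences. The sketch below mimics that layout with Python's `ctypes` to make the convention concrete. It is only an illustration: `pack_args` is a hypothetical helper, the pointer values are dummy integers rather than real device addresses, and the encoding of `arg_types` is not shown in this PR, so it is left out. In practice Hidet presumably performs this packing itself when the function is loaded.

```python
import ctypes


def pack_args(a_ptr, b_ptr, c_ptr, m, n, k):
    """Build a void*[6] laid out the way the naive_matmul wrapper reads it.

    Hypothetical helper for illustration only; not part of Hidet.
    """
    # Scalars are passed by address, so the c_int32 objects must stay
    # alive for as long as the args array is in use.
    scalars = [ctypes.c_int32(v) for v in (m, n, k)]
    # Slots 0..2 hold the tensor pointers themselves;
    # slots 3..5 hold the addresses of the int32 scalars.
    slots = [a_ptr, b_ptr, c_ptr] + [ctypes.addressof(s) for s in scalars]
    args = (ctypes.c_void_p * len(slots))(*slots)
    return args, scalars


def read_int32(args, index):
    """Mirror `*(int32_t *)args[index]` from the C wrapper."""
    return ctypes.c_int32.from_address(args[index]).value


# Dummy addresses stand in for device pointers (assumption: never dereferenced).
args, keep_alive = pack_args(0x1000, 0x2000, 0x3000, 3, 5, 4)
print(len(args), read_int32(args, 3), read_int32(args, 4), read_int32(args, 5))
```

Passing scalars by pointer keeps every slot the same width, which is why the wrapper can take a single `void**` regardless of the argument types.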

Define the Task and Operator in Hidet

See our documentation (here) for more information about Task and Operator definition.

import hidet
from hidet.graph.operator import Operator
from hidet.ir.compute import TensorInput, compute, reduce
from hidet.ir.task import Task, task_compiled_func_type
from hidet.runtime import CompiledFunction
from hidet.graph.ops.definitions.utils import input_like


class NaiveMatmulTask(Task):
    def __init__(self, a: TensorInput, b: TensorInput):
        if a.ndim != 2 or b.ndim != 2:
            raise ValueError('Only support matrix multiplication')
        c = compute(
            name='c',
            shape=[a.shape[0], b.shape[1]],
            fcompute=lambda i, j: reduce(
                shape=[a.shape[1]],
                fcompute=lambda k: a[i, k] * b[k, j],
                reduce_type='sum'
            )
        )
        super().__init__(
            name='naive_matmul', inputs=[a, b], outputs=[c],
            arguments=[
                a,
                b,
                c,
                a.shape[0],
                b.shape[1],
                a.shape[1]
            ]
        )

    def build(self, target: str) -> CompiledFunction:
        from hidet.backend import load_lib_func
        return load_lib_func(
            lib_path='./naive_matmul.so',
            func_name='naive_matmul',
            func_type=task_compiled_func_type(self)
        )


class NaiveMatmulOp(Operator):
    def __init__(self, a: hidet.Tensor, b: hidet.Tensor):
        super().__init__(
            inputs=[a, b],
            task=NaiveMatmulTask(input_like(a, 'a'), input_like(b, 'b'))
        )


def naive_matmul(a, b):
    return NaiveMatmulOp(a, b).get_output(0)


def main():
    a = hidet.ones([3, 4]).cuda()
    b = hidet.ones([4, 5]).cuda()
    c = naive_matmul(a, b)
    hidet.cuda.synchronize()
    print(a)
    print(b)
    print(c)


if __name__ == '__main__':
    main()

Run the model

$ python main.py
Tensor(shape=(3, 4), dtype='float32', device='cuda:0')
[[1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]]
Tensor(shape=(4, 5), dtype='float32', device='cuda:0')
[[1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]]
Tensor(shape=(3, 5), dtype='float32', device='cuda:0')
[[4. 4. 4. 4. 4.]
 [4. 4. 4. 4. 4.]
 [4. 4. 4. 4. 4.]]
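As a sanity check on the output above, the following plain-Python sketch repeats the same triple loop as `naive_matmul_kernel` on the CPU and recomputes the launch grid with the ceiling division used in the host wrapper. The helper names (`ceil_div`, `launch_grid`, `reference_matmul`) are illustrative and do not come from the PR.

```python
def ceil_div(n, d):
    """Smallest integer >= n / d, as in (M + block.x - 1) / block.x."""
    return (n + d - 1) // d


def launch_grid(m, n, block=(16, 16)):
    """Grid dimensions that cover an m x n output with 16 x 16 thread blocks."""
    return (ceil_div(m, block[0]), ceil_div(n, block[1]))


def reference_matmul(a, b):
    """CPU reference: c[i][j] = sum over k of a[i][k] * b[k][j]."""
    m, k, n = len(a), len(b), len(b[0])
    c = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            total = 0.0
            for kk in range(k):
                total += a[i][kk] * b[kk][j]
            c[i][j] = total
    return c


a = [[1.0] * 4 for _ in range(3)]  # same shape as hidet.ones([3, 4])
b = [[1.0] * 5 for _ in range(4)]  # same shape as hidet.ones([4, 5])
print(reference_matmul(a, b))      # every entry is 4.0, matching the run above
print(launch_grid(3, 5))           # (1, 1): a single block covers the 3 x 5 output
```

Each entry sums four products of ones, which is why the 3 x 5 result is filled with 4.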

@yaoyaoding yaoyaoding merged commit f5a36b7 into hidet-org:main Mar 6, 2023
@yaoyaoding yaoyaoding deleted the external branch March 6, 2023 22:38