Skip to content

[Fusion] Allow shallow fusion of cublas operator#407

Merged
yaoyaoding merged 1 commit into hidet-org:main from
yaoyaoding:fusion-cublas
Jan 4, 2024
Merged

[Fusion] Allow shallow fusion of cublas operator#407
yaoyaoding merged 1 commit into hidet-org:main from
yaoyaoding:fusion-cublas

Conversation

@yaoyaoding
Copy link
Copy Markdown
Member

This PR enhances the fusion capability of hidet. Previously, if a task generated an IRModule that could not fuse its prologue and epilogue (e.g., when there are usages other than TensorElement and BufferStore), hidet refused to fuse and threw an error. This PR provides a fallback mechanism for that case by generating standalone kernels for those prologues and epilogues. With this enhancement, we can put opaque kernels (like cublas) and hidet's own kernels in a single search space and enable auto-tuning across vendor kernels and hidet kernels.

For the following code

# Problem sizes: batch, M, N, K for the batched matmul example.
bs, m, n, k = 2, 1024, 1024, 1024
# `a` is symbolic, so it becomes a runtime input of the traced graph;
# `b` is a concrete random tensor captured by the trace.
a = hidet.symbol(shape=[bs, m, k], dtype='float32', device='cuda')
b = hidet.randn(shape=[bs, k, n], dtype='float32', device='cuda')

def optimize_and_build(op):
    """Trace ``op(a + 1.0, b) + 1.0``, optimize the graph, and build it.

    The surrounding additions give the fusion pass a prologue and an
    epilogue to attach to the matmul-like operator ``op``.
    """
    out = op(a + 1.0, b) + 1.0
    traced = hidet.trace_from(out)
    optimized = hidet.graph.optimize(traced)
    return optimized.build()

# Build the same computation twice: once routed through cuBLAS and once
# through hidet's own batched-matmul kernel.
graph_2 = optimize_and_build(hidet.ops.matmul_cublas)
graph_1 = optimize_and_build(hidet.ops.batch_matmul)

# Replace the symbolic input with concrete random data for execution.
a = hidet.randn_like(a)

y1 = graph_1(a)
y2 = graph_2(a)

# Both pipelines must agree within floating-point tolerance.
hidet.utils.assert_close(y1, y2, atol=1e-5, rtol=1e-5)

We can generate a fused kernel that looks like the following:

#include <hidet/runtime/cuda/context.h>
#include <hidet/runtime/logging.h>
#include <hidet/runtime/cuda/cublas.h>


// Elementwise prologue kernel: y[i] = x[i] + 1.0f, one thread per element
// (launched as 16384 blocks x 512 threads; the grid exactly covers the
// tensor, so no tail guard is emitted).
static __global__ void __launch_bounds__(512) hidet_module_0_fused_sub_graph_compute_y(float * __restrict__ x, float * __restrict__ y) {
  int tid = ((int)blockIdx.x * 512) + (int)threadIdx.x;
  // Decompose the flat thread id into per-axis offsets of the contiguous
  // [*, 1024, 1024] layout and re-sum them into the linear element index.
  int idx = ((tid / 1048576) * 1048576) + (((tid / 1024) % 1024) * 1024) + (tid % 1024);
  y[idx] = (x[idx] + 1.0f);
}

// Elementwise epilogue kernel: y[i] = x[i] + 1.0f, one thread per element
// (launched as 16384 blocks x 512 threads; the grid exactly covers the
// tensor, so no tail guard is emitted).
static __global__ void __launch_bounds__(512) hidet_module_2_fused_sub_graph_compute_y(float * __restrict__ x, float * __restrict__ y) {
  int tid = ((int)blockIdx.x * 512) + (int)threadIdx.x;
  // Decompose the flat thread id into per-axis offsets of the contiguous
  // [*, 1024, 1024] layout and re-sum them into the linear element index.
  int idx = ((tid / 1048576) * 1048576) + (((tid / 1024) % 1024) * 1024) + (tid % 1024);
  y[idx] = (x[idx] + 1.0f);
}

// Host-side wrapper: launch the module-0 prologue kernel on hidet's current
// CUDA stream, logging (but not propagating) any launch-configuration error.
DLL void hidet_module_0_launch(float * __restrict__ x, float * __restrict__ y) {
  const dim3 grid(16384, 1, 1);
  const dim3 block(512, 1, 1);
  hidet_module_0_fused_sub_graph_compute_y<<<grid, block, 0, (cudaStream_t)get_cuda_stream()>>>(x, y);
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    LOG(ERROR) << "CUDA error: " << cudaGetErrorString(err) << "\n";
  }
}

// Host-side wrapper delegating the matmul to hidet's cuBLAS strided-batched
// GEMM helper; a and b are the inputs, c receives the result.
// NOTE(review): the integer arguments presumably encode batch count,
// m/n/k, dtype codes, per-matrix strides (1048576 = 1024*1024 elements),
// transpose flags, and a compute-type code (68) — confirm against the
// declaration in hidet/runtime/cuda/cublas.h before relying on this.
DLL void hidet_module_1_launch(float * __restrict__ a, float * __restrict__ b, float * __restrict__ c) {
  hidet_cublas_strided_gemm(8, 1024, 1024, 1024, 0, 0, 0, a, b, c, 1048576, 1048576, 1048576, false, false, 68);
}

// Host-side wrapper: launch the module-2 epilogue kernel on hidet's current
// CUDA stream, logging (but not propagating) any launch-configuration error.
DLL void hidet_module_2_launch(float * __restrict__ x, float * __restrict__ y) {
  const dim3 grid(16384, 1, 1);
  const dim3 block(512, 1, 1);
  hidet_module_2_fused_sub_graph_compute_y<<<grid, block, 0, (cudaStream_t)get_cuda_stream()>>>(x, y);
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    LOG(ERROR) << "CUDA error: " << cudaGetErrorString(err) << "\n";
  }
}

// Top-level fused launch: chains prologue -> cuBLAS GEMM -> epilogue.
// A single 64 MiB CUDA workspace is requested and carved into two 32 MiB
// intermediate buffers so the three stages need no separate allocations.
DLL void hidet_launch_0(float * __restrict__ p, float * __restrict__ p_1, float * __restrict__ p_2) {
  uint8_t *workspace = (uint8_t *)(request_cuda_workspace(int64_t(67108864ll), false));
  float *buf = (float *)(&workspace[int64_t(0ll)]);
  float *buf_1 = (float *)(&workspace[int64_t(33554432ll)]);
  hidet_module_0_launch(p_1, buf);      // prologue (elementwise +1.0f)
  hidet_module_1_launch(buf, p, buf_1); // cuBLAS strided-batched GEMM
  hidet_module_2_launch(buf_1, p_2);    // epilogue (elementwise +1.0f)
}

@yaoyaoding yaoyaoding merged commit c54eaa9 into hidet-org:main Jan 4, 2024
@yaoyaoding yaoyaoding deleted the fusion-cublas branch January 4, 2024 16:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant