Skip to content

[Enquiry] developing Flash Attention Transformer example using Hidet #281

@keneoneth

Description

@keneoneth

Hello guys, I really appreciate your work on Hidet. It is an awesome tool, and it really makes developers' lives easier when writing custom schedules for their CUDA kernels for performance optimization👍👍!

To test Hidet's features, I am currently writing an example of the Flash Attention Transformer (link to research work: https://arxiv.org/abs/2205.14135) using the Hidet tool stack. I have written my own testing setup (which contains my own host/device memory allocation, performance tracking, and precision comparison code) in "flash_attention_main.cu", and I am trying to call the kernel functions in the Hidet-generated CUDA dynamic library.

Is there a standard way of doing this? I tried using "dlopen" to load the library and launch the kernel functions, but unfortunately it did not work properly. As a workaround, I manually copied the Hidet-generated CUDA source code into two separate header files, "flash_attention_kernel_func.h" and "normal_transformer_kernel_func.h", and included them in "flash_attention_main.cu". When I compile "flash_attention_main.cu" directly, everything works properly.

Let me share some source code below for illustration.

Here is my flash_attention_example.py, which includes the flash attention custom schedule and the normal approach.

import os
import math
import time
import numpy as np
import torch
import torch.nn as nn
# Fix torch's global RNG seed so the random Q/K/V drawn in gen_gold() are reproducible across runs.
torch.manual_seed(123)

# NOTE: this script is a simplified implementation of the following research work using Hidet
# Dao, T., Fu, D., Ermon, S., Rudra, A., & Ré, C. (2022). Flashattention: Fast and memory-efficient exact attention with io-awareness. Advances in Neural Information Processing Systems, 35, 16344-16359.
# link to paper: https://arxiv.org/abs/2205.14135

import hidet
from hidet.ir.compute import compute, reduce
from hidet.ir.task import Task
from hidet.ir.func import IRModule
from hidet.ir.primitives.cuda.atomic import atomic_add
from hidet.lang import f16, spatial, repeat, tensor, attr, grid, printf
from hidet.lang.cuda import blockIdx, threadIdx, syncthreads
from hidet.graph.ops.definitions.utils import input_like
from hidet.ir.expr import cast, address
from hidet.ir.primitives import exp, max, printf

# define Flash Attention Task
class FlashAttentionTask(Task):
    """Hidet task for single-head attention ``O = softmax(Q @ K^T) @ V``.

    Q, K and V are ``float16`` inputs of shape ``[N, d]``. The task's compute
    definition is the unfused pipeline built by ``normal_transformer`` inside
    ``__init__``. Unless ``disable_flash_attention`` is set, ``implement_cuda``
    is replaced by :meth:`flash_attention_implement_cuda`, which lowers the
    task with the hand-written flash attention schedule instead.

    Extra attribute:
        define: nvcc flag passed when compiling ``flash_attention_main.cu``:
            ``"-DRUN_FLASH_ATTN"`` for the flash attention kernel, ``""`` for
            the unfused kernels.
    """

    def allow_epilogue(self) -> bool:
        """Return False, i.e. do not allow an epilogue to be fused into this task.

        NOTE(review): presumably required because the custom schedule writes
        ``O`` directly -- confirm against Hidet's ``Task.allow_epilogue``.
        """
        return False

    def flash_attention_implement_cuda(self, working_dir: str) -> IRModule:
        """Lower this task with the custom flash attention schedule.

        Installed as ``self.implement_cuda`` by ``__init__``; ``working_dir``
        is not used here.
        """
        # override this method to use template-based scheduling
        return flash_attention_schedule(self)
    
    # Require: Matrices Q, K, V (N x d) in HBM, on-chip SRAM of size M.
    # NOTE: typical SRAM size 100 kB, default to 48 kB
    # NOTE: max thread num is set to 1024
    def __init__(self,N=512,d=128,H=16,B=1,M=48*1024,ratio=12,max_thread_num=1024,disable_flash_attention=False):
        """Derive the tiling from the problem size and build the task.

        Args:
            N: number of rows of Q, K and V (sequence length).
            d: number of columns of Q, K and V (head dimension).
            H: number of heads; only recorded in ``attrs`` (used by the
                launcher and the test harness).
            B: batch size; only recorded in ``attrs``.
            M: on-chip SRAM budget used to size the tiles (48 * 1024 by default).
            ratio: divisor in ``Bc = ceil(M / (ratio * d))``.
            max_thread_num: upper bound on the ``Br * Bc`` threads per block,
                checked by ``flash_attention_schedule``.
            disable_flash_attention: if True, ``implement_cuda`` is not
                overridden and the compute definition is lowered by Hidet.
        """

        # 1. set block sizes Bc = ceil(M/(ratio*d)), Br = min(ceil(M/(ratio*d)),d)
        #    (the paper's Algorithm 1 uses ratio = 4); Tr / Tc = number of row tiles of Q / of K,V
        Bc = math.ceil(M/(ratio*d))
        Br = min(math.ceil(M/(ratio*d)),d)
        Tr = math.ceil(N/Br)
        Tc = math.ceil(N/Bc)
        # symbolic [N, d] fp16 task inputs; presumably input_like only takes shape/dtype
        # from the random CUDA tensors -- TODO confirm
        GLOBAL_Q = input_like(hidet.randn([N, d], dtype='float16', device='cuda'),name='GLOBAL_Q')
        GLOBAL_K = input_like(hidet.randn([N, d], dtype='float16', device='cuda'),name='GLOBAL_K')
        GLOBAL_V = input_like(hidet.randn([N, d], dtype='float16', device='cuda'),name='GLOBAL_V')
        
        def normal_transformer():
            # Unfused attention: QK = Q @ K^T, S = QK - rowmax(QK), P = exp(S) / rowsum(exp(S)), O = P @ V
            matmulQK = compute(
                    name = 'GLOBAL_QK',
                    shape = [N, N],
                    fcompute = lambda i, j: reduce(
                        shape=[d],
                        fcompute=lambda k: GLOBAL_Q[i, k] * GLOBAL_K[j, k],
                        reduce_type='sum',
                    )
                )

            # subtract the row max before exp for numerical stability
            max_val = lambda i : reduce(shape=[N], fcompute=lambda j: matmulQK[i,j], reduce_type='max')
            S = compute(
                    name = 'S',
                    shape = [N, N],
                    fcompute = lambda i,j: matmulQK[i,j] - max_val(i)
                )
            exp_s = compute(
                    name = 'exp_s',
                    shape = [N, N],
                    fcompute = lambda i,j: exp(S[i,j])
                )
            exp_sum = lambda i : reduce(shape=[N], fcompute=lambda j: exp_s[i,j], reduce_type='sum')
            softmax = compute('softmax', shape=[N,N], fcompute=lambda i,j: exp_s[i,j] / exp_sum(i))
            matmulPV = compute(
                    name = 'GLOBAL_O',
                    shape = [N, d],
                    fcompute = lambda i, j: reduce(
                        shape=[N],
                        fcompute=lambda k: softmax[i, k] * GLOBAL_V[k, j],
                        reduce_type='sum',
                    )
                )
            return matmulPV
        
        super().__init__(
            name='flash_attention_task',
            inputs=[GLOBAL_Q,GLOBAL_K,GLOBAL_V],
            outputs=[normal_transformer()],
            attributes={
                'B' : B,
                'H' : H,
                'N' : N,
                'd' : d,
                'Bc' : Bc,
                'Br' : Br,
                'Tc' : Tc,
                'Tr' : Tr,
                'BLK' : Tr,  # CUDA grid size of the flash kernel: one block per Br-row tile of Q
                'THD' : Br * Bc,  # CUDA block size of the flash kernel
                'MAX_THD' : max_thread_num
            },
        )
        # select the lowering and the matching kernel launch in flash_attention_main.cu
        if not disable_flash_attention:
            self.implement_cuda = self.flash_attention_implement_cuda
            self.define = "-DRUN_FLASH_ATTN"
        else:
            self.define = ""

# define custom schedule
def flash_attention_schedule(task:FlashAttentionTask) -> IRModule:
    """Build the Hidet IR module implementing the flash attention forward pass.

    The kernel launches ``BLK`` (== ``Tr``) CUDA blocks; block ``b`` owns query
    rows ``[b * Br, (b + 1) * Br)`` and runs ``Br * Bc`` threads. For each of
    the ``Tc`` tiles of ``Bc`` key/value rows it computes the local scores
    ``S = Q_i @ K_j^T``, their row max ``m'`` and row sum ``l'`` of
    ``P = exp(S - m')``, and folds the tile into the running output with the
    online-softmax update of Algorithm 1 in Dao et al. (2022):

        m_new = max(m, m')
        l_new = e^(m - m_new) * l + e^(m' - m_new) * l'
        O     = diag(l_new)^-1 * (diag(l) * e^(m - m_new) * O + e^(m' - m_new) * P @ V_j)

    Args:
        task: the task whose ``attrs`` (``N``, ``d``, ``Br``, ``Bc``, ``Tr``,
            ``Tc``, ``BLK``, ``THD``, ``MAX_THD``, ``B``, ``H``) parameterize
            the kernel.

    Returns:
        An IR module with ``flash_attention_kernel`` (one ``[N, d]`` head) and
        the host function ``flash_attention_launch_func`` that calls it for
        every (batch, head) pair.

    Raises:
        AssertionError: if ``Br * Bc`` exceeds ``MAX_THD``, ``d`` is not a
            multiple of ``Bc`` or ``Br``, or ``Bc`` is not a power of two.
    """

    # set to True to printf intermediate tiles of block 0 (and the output of block 15)
    print_debug = False

    B = task.attrs['B']
    H = task.attrs['H']
    N = task.attrs['N']
    d = task.attrs['d']
    Bc = task.attrs['Bc']
    Br = task.attrs['Br']
    Tr = task.attrs['Tr']
    Tc = task.attrs['Tc']

    dims = ( task.attrs['BLK'] )
    threads = task.attrs['THD']
    assert threads <= task.attrs['MAX_THD'], f'err: {threads} not <= {task.attrs["MAX_THD"]}'
    assert d % Bc == 0, f'err: d={d} is not divisible by Bc={Bc}'
    assert d % Br == 0, f'err: d={d} is not divisible by Br={Br}'
    # the pairwise tree reductions below read T[i, j + k]; that index stays
    # inside the tile only when Bc is a power of two
    assert Bc > 0 and (Bc & (Bc - 1)) == 0, f'err: Bc={Bc} must be a power of two'

    # magnitude of the largest finite fp16 value; its negation seeds the running row max
    largest_fp16_value = 65504

    print(f'task.attrs {task.attrs}')

    # define the tensor program
    with hidet.script_module() as module:
        # Flash attention kernel and its device helpers. Each helper is called by
        # all threads of the block and ends with syncthreads().

        @hidet.script
        def QK_matmul_compute(A:f16[Br,d],B:f16[d,Bc],C:f16[Br,Bc]):
            # C = A @ B, where B holds K_j already transposed to [d, Bc]
            # NOTE(review): each (m, n) is owned by a single thread, so the
            # atomic_add looks unnecessary -- confirm before replacing it.
            for m,n in spatial(Br,Bc).on(threadIdx.x):
                C[m,n] = 0.0
            syncthreads()
            for m,k,n in spatial(Br,1,Bc).repeat(1,d,1).on(threadIdx.x):
                atomic_add(~C[m,n],A[m,k] * B[k,n])
            syncthreads()

        @hidet.script
        def PV_matmul_compute(A:f16[Br,Bc],B:f16[Bc,d],C:f16[Br,d]):
            # C = A @ B, i.e. P_ij @ V_j
            for m,n in spatial(Br,Bc).repeat(1,d//Bc).on(threadIdx.x):
                C[m,n] = 0.0
            syncthreads()
            for m,k,n in spatial(Br,1,Bc).repeat(1,Bc,d//Bc).on(threadIdx.x):
                atomic_add(~C[m,n],A[m,k] * B[k,n])
            syncthreads()

        @hidet.script
        def rowmax_compute(A:f16[Br,Bc],M:f16[Br],T:f16[Br,Bc]):
            # M[i] = max_j A[i, j] via a pairwise tree reduction in the scratch tile T
            for i,j in spatial(Br,Bc).on(threadIdx.x):
                T.write([i,j],A[i,j],protected=True)
            syncthreads()

            for i,j in spatial(Br,Bc).on(threadIdx.x):
                k = 1
                while k < Bc:
                    if j % (k*2) == 0:
                        T.write([i,j],max(T[i,j],T[i,j+k]),protected=True)
                    syncthreads()
                    k *= 2

            for i in spatial(Br).on(threadIdx.x):
                if threadIdx.x < Br:
                    M[i] = T[i,0]
            syncthreads()

        @hidet.script
        def rowsum_compute(A:f16[Br,Bc],L:f16[Br],T:f16[Br,Bc]):
            # L[i] = sum_j A[i, j] via a pairwise tree reduction in the scratch tile T
            for i,j in spatial(Br,Bc).on(threadIdx.x):
                T.write([i,j],A[i,j],protected=True)
            syncthreads()

            for i,j in spatial(Br,Bc).on(threadIdx.x):
                k = 1
                while k < Bc:
                    if j % (k*2) == 0:
                        T.write([i,j],(T[i,j]+T[i,j+k]),protected=True)
                    syncthreads()
                    k *= 2

            for i in spatial(Br).on(threadIdx.x):
                if threadIdx.x < Br:
                    L[i] = T[i,0]
            syncthreads()

        @hidet.script
        def local_softmax_compute(S:f16[Br,Bc],M:f16[Br]):
            # in place: S[i, j] = exp(S[i, j] - M[i])
            for i,j in spatial(Br,Bc).on(threadIdx.x):
                if False and blockIdx.x==0:
                    printf("S[i,j] before %d %d %d %d : %f - %f\n",blockIdx.x,threadIdx.x,i,j,cast(S[i,j],"float32"),cast(M[i],"float32"))
                S[i,j] = exp(S[i,j] - M[i])
                if False and blockIdx.x==0:
                    printf("S[i,j] %d %d %d %d : %f\n",blockIdx.x,threadIdx.x,i,j,cast(S[i,j],"float32"))
            syncthreads()

        @hidet.script
        def local_update_compute(M:f16[Br],M_new:f16[Br],M_local:f16[Br],L:f16[Br],L_new:f16[Br],L_local:f16[Br]):
            # m_new = max(m, m'); l_new = e^(m - m_new) * l + e^(m' - m_new) * l'
            for i in spatial(Br).on(threadIdx.x):
                if threadIdx.x < Br:
                    M_new[i] = max(M[i],M_local[i])
                    L_new[i] = exp(M[i] - M_new[i]) * L[i] + exp(M_local[i] - M_new[i]) * L_local[i]
            syncthreads()

        @hidet.script
        def global_update_compute(PV:f16[Br,d],O:f16[Br,d],M_local:f16[Br],M_new:f16[Br],M:f16[Br],L_new:f16[Br],L:f16[Br]):
            # O = diag(l_new)^-1 * (diag(l) * e^(m - m_new) * O + e^(m' - m_new) * PV)
            # The 1 / l_new factor must scale BOTH terms; otherwise the
            # contribution of the current K/V tile is never normalized.
            for i,j in spatial(Br,Bc).repeat(1,(d//Bc)).on(threadIdx.x):
                O.write(
                    [i,j],
                    (L_new[i]**-1) * ((L[i]*exp(M[i]-M_new[i]) * O[i,j]) + (exp(M_local[i]-M_new[i]) * PV[i,j])),
                    protected=True
                )
            syncthreads()

        @hidet.script
        def flash_attention_kernel(
            Q: f16[N,d],
            K: f16[N,d],
            V: f16[N,d],
            O: f16[N,d]
        ):

            attr.cuda_grid_dim = dims
            attr.cuda_block_dim = threads

            # Init O=(0), N x d in HBM (rows of this block's tile)
            for i,j in spatial(Br,Bc).repeat(1,(d//Bc)).on(threadIdx.x):
                offset_i = blockIdx.x * (Br)
                O[offset_i:,:].write([i,j], 0, protected=True)
            syncthreads()

            smem_q = tensor('shared', 'float16', [Br, d])
            smem_k = tensor('shared', 'float16', [d, Bc]) # transposed
            smem_v = tensor('shared', 'float16', [Bc, d])
            smem_o = tensor('shared', 'float16', [Br, d])

            smem_l = tensor('shared', 'float16', [Br])
            smem_l_local = tensor('shared', 'float16', [Br])
            smem_l_new = tensor('shared', 'float16', [Br])
            smem_m = tensor('shared', 'float16', [Br])
            smem_m_local = tensor('shared', 'float16', [Br])
            smem_m_new = tensor('shared', 'float16', [Br])
            smem_sp = tensor('shared', 'float16', [Br,Bc])
            smem_pv = tensor('shared', 'float16', [Br,d])
            smem_temp = tensor('shared', 'float16', [Br,Bc])

            for a,b in spatial(Br,Bc).repeat(1,(d//Bc)).on(threadIdx.x):
                # load Qi from HBM to on-chip SRAM
                # initialization of o,l,m
                offset_i = blockIdx.x * (Br)
                smem_q[a,b] = Q[offset_i:,:].read([a,b],protected=True)
                smem_o[a,b] = 0
                smem_l[a] = 0
                smem_m[a] = -largest_fp16_value
            syncthreads()

            if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                idx = 0
                for i,j in grid(Br,d):
                    printf("idx: %d, Q val: %f\n",idx,cast(smem_q[i,j],"float32"))
                    idx += 1
            syncthreads()

            # j indexes the K/V tile; loops nested below use other names so j is never shadowed
            for j in grid(Tc):

                for a,b in spatial(Bc,Br).repeat(1,(d//Br)).on(threadIdx.x):
                    # load Kj,Vj from HBM to on-chip SRAM
                    offset_j = j * (Bc)
                    smem_k[b,a] = K[offset_j:,:].read([a,b],protected=True)
                    smem_v[a,b] = V[offset_j:,:].read([a,b],protected=True)
                syncthreads()

                if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                    idx = 0
                    for r,c in grid(d,Bc):
                        printf("idx: %d, K val: %f\n",idx,cast(smem_k[r,c],"float32"))
                        idx += 1
                    for r,c in grid(Bc,d):
                        printf("idx: %d, V val: %f\n",idx,cast(smem_v[r,c],"float32"))
                        idx += 1
                syncthreads()

                # on chip, compute Sij = Qi @ (Kj)^T, Br X Bc
                QK_matmul_compute(smem_q,smem_k,smem_sp)
                if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                    idx = 0
                    for r,c in grid(Br,Bc):
                        printf("idx: %d, S val: %f\n",idx,cast(smem_sp[r,c],"float32"))
                        idx += 1
                syncthreads()

                # on chip, compute m'_ij = rowmax(Sij), Br; Pij = exp(Sij - m'_ij), Br x Bc (pointwise); l'_ij = rowsum(P'ij), Br
                rowmax_compute(smem_sp,smem_m_local,smem_temp)

                if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                    for r in grid(Br):
                        printf("i: %d M val: %f\n",r,cast(smem_m_local[r],"float32"))
                syncthreads()

                local_softmax_compute(smem_sp,smem_m_local)

                if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                    idx = 0
                    for r,c in grid(Br,Bc):
                        printf("idx: %d, P val: %f\n",idx,cast(smem_sp[r,c],"float32"))
                        idx += 1
                syncthreads()

                rowsum_compute(smem_sp,smem_l_local,smem_temp)

                if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                    for r in grid(Br):
                        printf("i: %d L val: %f\n",r,cast(smem_l_local[r],"float32"))
                syncthreads()

                # on chip, compute m_new_i = max(m_i,m'_ij), Br; l_new_i = e^(m_i - m_new_i) * l_i + e^(m'_ij - m_i_new) * l'_ij, Br
                local_update_compute(smem_m,smem_m_new,smem_m_local,smem_l,smem_l_new,smem_l_local)
                if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                    for r in grid(Br):
                        printf("i: %d smem_m val: %f\n",r,cast(smem_m[r],"float32"))
                        printf("i: %d smem_m_new val: %f\n",r,cast(smem_m_new[r],"float32"))
                        printf("i: %d smem_m_local val: %f\n",r,cast(smem_m_local[r],"float32"))
                        printf("i: %d smem_l val: %f\n",r,cast(smem_l[r],"float32"))
                        printf("i: %d smem_l_new val: %f\n",r,cast(smem_l_new[r],"float32"))
                        printf("i: %d smem_l_local val: %f\n",r,cast(smem_l_local[r],"float32"))
                syncthreads()

                # write Oi = diag(l_i_new)^-1 * (diag(l_i)*e^(m_i-m_i_new) @ Oi + e^(m'_ij-m_i_new)*(P'ij @ Vj))
                PV_matmul_compute(smem_sp,smem_v,smem_pv)
                if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                    idx = 0
                    for r,c in grid(Br,d):
                        printf("idx: %d, PV val: %f\n",idx,cast(smem_pv[r,c],"float32"))
                        idx += 1
                syncthreads()

                global_update_compute(smem_pv,smem_o,smem_m_local,smem_m_new,smem_m,smem_l_new,smem_l)

                # after the last K/V tile, store this block's output tile to HBM
                if j + 1 == Tc:
                    for r,c in spatial(Br,Bc).repeat(1,(d//Bc)).on(threadIdx.x):
                        offset_i = blockIdx.x * (Br)
                        O[offset_i:,:].write([r,c], smem_o[r,c], protected=True)
                    syncthreads()

                # write l_i = l_i_new, m_i = m_i_new
                for i in spatial(Br).on(threadIdx.x):
                    if threadIdx.x < Br:
                        smem_m[i] = smem_m_new[i]
                        smem_l[i] = smem_l_new[i]
                syncthreads()

            if print_debug and (blockIdx.x==15 and threadIdx.x==0):
                idx = 0
                for i,j in grid(Br,d):
                    offset_i = blockIdx.x * (Br)
                    printf("blockIdx %d : output idx: %d, val: %f\n",blockIdx.x,idx,cast(O[offset_i+i,j],"float32"))
                    idx += 1
            syncthreads()
            return

        @hidet.script
        def flash_attention_launch_func(
            G_Q: f16[B, H, N, d],
            G_K: f16[B, H, N, d],
            G_V: f16[B, H, N, d],
            G_O: f16[B, H, N, d]
        ):
            # NOTE: this section needs to be written in flash_attention_main.cu
            # one kernel launch per (batch, head) slice of shape [N, d]
            for b,h in grid(B,H):
                flash_attention_kernel(
                    address(G_Q[b,h,0,0]),
                    address(G_K[b,h,0,0]),
                    address(G_V[b,h,0,0]),
                    address(G_O[b,h,0,0])
                )

    # build ir module
    ir_module = module.ir_module()
    return ir_module

# gen Python gold data as reference
def gen_gold(attrs,r1=-3,r2=3):
    """Generate random inputs and the reference attention output as fp16 binaries.

    Draws Q, K and V (in that order) uniformly from ``[r1, r2)`` with shape
    ``[B, H, N, d]`` taken from ``attrs`` and rounds them to float16. The
    reference ``O = softmax(Q @ K^T) @ V`` is computed with float16 operands
    (the row max is subtracted before ``exp``).

    Files written to the current directory (raw float16, C order):
    ``mat_Q.bin``, ``mat_K.bin``, ``mat_V.bin`` and ``gold_mat_O.bin``.

    Args:
        attrs: mapping with integer entries ``'B'``, ``'H'``, ``'N'``, ``'d'``.
        r1: lower bound of the uniform distribution.
        r2: upper bound of the uniform distribution.
    """
    shape = (attrs['B'], attrs['H'], attrs['N'], attrs['d'])
    Q = torch.FloatTensor(*shape).uniform_(r1, r2).half()
    K = torch.FloatTensor(*shape).uniform_(r1, r2).half()
    V = torch.FloatTensor(*shape).uniform_(r1, r2).half()

    # perf_counter measures wall-clock time, which is what "elapsed time"
    # reports; process_time would sum CPU time over all threads instead.
    t = time.perf_counter()
    Q.numpy().tofile('mat_Q.bin')
    K.numpy().tofile('mat_K.bin')
    V.numpy().tofile('mat_V.bin')

    S = torch.from_numpy(Q.numpy() @ K.transpose(-2, -1).numpy())
    row_max = S.max(dim=-1, keepdim=True).values
    S = torch.from_numpy(np.exp((S - row_max).numpy()))
    row_sum = S.sum(dim=-1, keepdim=True)
    P = S / row_sum

    # TODO: also test against a float32 softmax, e.g. nn.Softmax(dim=-1)(S.float()).half()
    O = torch.from_numpy(P.numpy() @ V.numpy())
    elapsed_time = (time.perf_counter() - t)*1000
    print(f"Python gold gen run elapsed time {round(elapsed_time,3)} msec")
    O.half().numpy().tofile('gold_mat_O.bin')

# run task
def run_task(disable_flash_attention=False):
    """Build the task, export its artifacts, generate gold data and run the CUDA harness.

    Steps:
      1. build ``FlashAttentionTask`` for the ``cuda`` target;
      2. copy the generated source and library into the current directory
         with a ``flash_attention_`` prefix;
      3. write the inputs and the reference output via ``gen_gold``;
      4. compile ``flash_attention_main.cu`` with ``nvcc`` and, only if that
         succeeds, run ``./fa.out`` with the task's B/H/BLK/THD attributes.

    Args:
        disable_flash_attention: if True, build the unfused Hidet kernels and
            compile the harness without ``-DRUN_FLASH_ATTN``.

    Prints ``test done`` if compilation and the run both exit with status 0,
    ``test error`` otherwise.
    """
    import shutil
    import subprocess

    # define and build the task
    flash_attention_task = FlashAttentionTask(disable_flash_attention=disable_flash_attention)
    build_result = flash_attention_task.build(target='cuda')

    # copy source file and lib to current directory
    source_path = build_result.src_path
    library_path = build_result.lib_path
    print(f'source_path {source_path} library_path {library_path}')
    for path in (source_path, library_path):
        shutil.copyfile(path, os.path.join("./", "flash_attention_" + os.path.basename(path)))

    # generate golden data
    gen_gold(flash_attention_task.attrs)

    def exe_f(args):
        # Run a command given as an argument list (no shell) and return its exit code.
        print(f'running {" ".join(args)}')
        return subprocess.run(args).returncode

    # Header search paths; each name now matches the directory it points to.
    # CUDA samples helpers (helper_functions.h, helper_cuda.h):
    CUDA_SAMPLES_INCLUDE_PATH = "../cuda-samples-master/Common/"
    # presumably the Hidet runtime headers -- TODO confirm
    HIDET_CUDA_INCLUDE_PATH = "../../include/"

    attrs = flash_attention_task.attrs
    compile_cmd = ['nvcc', 'flash_attention_main.cu']
    if flash_attention_task.define:
        compile_cmd.append(flash_attention_task.define)
    compile_cmd += [
        '-gencode', 'arch=compute_86,code=sm_86',
        '-I', HIDET_CUDA_INCLUDE_PATH,
        '-I', CUDA_SAMPLES_INCLUDE_PATH,
        '-std=c++11', '-o', 'fa.out',
    ]
    run_cmd = [
        './fa.out',
        f'-BATCH={attrs["B"]}',
        f'-HEAD={attrs["H"]}',
        f'-BLK={attrs["BLK"]}',
        f'-THD={attrs["THD"]}',
    ]

    # launch testcase flash_attention_main.cu only if it compiled (like `nvcc ... && ./fa.out`)
    ret = exe_f(compile_cmd)
    if ret == 0:
        ret = exe_f(run_cmd)
    print('test done' if ret == 0 else 'test error')

# main function
if __name__ == '__main__':
    # Run the unfused baseline first, then the flash attention schedule,
    # so both timings are printed back to back.
    for disable in (True, False):
        run_task(disable_flash_attention=disable)

Here is my flash_attention_main.cu, which includes the performance tracking, precision comparison, and memory allocation operations, and launches the test kernels.

// System includes
#include <stdio.h>
#include <sys/stat.h>
#include <dlfcn.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <assert.h>
#include <vector>

// CUDA runtime
#include <cuda_runtime.h>
#include <cuda_profiler_api.h>

// Helper functions and utilities to work with CUDA,
#include <helper_functions.h>
#include <helper_cuda.h>
#include <cuda_fp16.h>

// Import kernel functions
#include "flash_attention_kernel_func.h"
#include "normal_transformer_kernel_func.h"


// test function, execute kernel, compare with gold data
// Runs the attention kernel(s) selected at compile time on every (batch, head)
// slice, times the launches with CUDA events, copies O back to the host and
// compares it element-wise against the reference output.
//
//   B, H         batch and head counts; each buffer is split into B * H
//                equally sized, contiguous slices (one launch per slice).
//   block_size   grid size for flash_attention_kernel; must match the grid the
//                kernel was generated for (run_task passes the task's BLK).
//   thread_size  threads per block for flash_attention_kernel; must match the
//                generated kernel (run_task passes the task's THD).
//   h_Q, h_K, h_V, h_gold_O and their element counts: host inputs and the
//                reference output.
//
// Returns EXIT_SUCCESS if the fraction of mismatching elements is below 1%,
// EXIT_FAILURE otherwise (including when O is empty).
int flash_attention_test(
    unsigned int B, unsigned int H,
    unsigned int block_size, unsigned int thread_size,
    half *h_Q, unsigned int size_Q,
    half *h_K, unsigned int size_K,
    half *h_V, unsigned int size_V,
    half *h_gold_O, unsigned int size_O)
{
    cudaStream_t stream;
    const unsigned int BH = B * H;

    // Pinned host buffer for the result, device buffers for all four matrices
    half *d_Q, *d_K, *d_V, *d_O, *h_O;
    checkCudaErrors(cudaMallocHost(&h_O, size_O * sizeof(half)));

    if (h_O == NULL)
    {
        fprintf(stderr, "Failed to allocate host matrix O!\n");
        exit(EXIT_FAILURE);
    }

    checkCudaErrors(cudaMalloc(reinterpret_cast<void **>(&d_Q), size_Q * sizeof(half)));
    checkCudaErrors(cudaMalloc(reinterpret_cast<void **>(&d_K), size_K * sizeof(half)));
    checkCudaErrors(cudaMalloc(reinterpret_cast<void **>(&d_V), size_V * sizeof(half)));
    checkCudaErrors(cudaMalloc(reinterpret_cast<void **>(&d_O), size_O * sizeof(half)));

#ifndef RUN_FLASH_ATTN
    // Scratch space for the intermediates of the unfused Hidet kernels, at the
    // offsets used by the generated code: GLOBAL_QK, S, exp_s and softmax,
    // 512 KiB each. All launches go to one stream and run in order, so a
    // single buffer is reused for every (b, h) slice. Allocating it once here
    // (instead of once per slice, never freed) also keeps cudaMalloc out of
    // the timed region.
    uint8_t *buffer;
    checkCudaErrors(cudaMalloc(reinterpret_cast<void **>(&buffer), int64_t(2097152ll)));
    half *GLOBAL_QK = reinterpret_cast<half *>(buffer + 0);
    half *S = reinterpret_cast<half *>(buffer + 524288);
    half *exp_s = reinterpret_cast<half *>(buffer + 1048576);
    half *softmax = reinterpret_cast<half *>(buffer + 1572864);
#endif // RUN_FLASH_ATTN

    // Allocate CUDA events that we'll use for timing
    cudaEvent_t start, stop;
    checkCudaErrors(cudaEventCreate(&start));
    checkCudaErrors(cudaEventCreate(&stop));

    checkCudaErrors(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));

    // copy host memory to device
    checkCudaErrors(
        cudaMemcpyAsync(d_Q, h_Q, size_Q * sizeof(half), cudaMemcpyHostToDevice, stream));
    checkCudaErrors(
        cudaMemcpyAsync(d_K, h_K, size_K * sizeof(half), cudaMemcpyHostToDevice, stream));
    checkCudaErrors(
        cudaMemcpyAsync(d_V, h_V, size_V * sizeof(half), cudaMemcpyHostToDevice, stream));

    // Number of elements in one (batch, head) slice of each matrix
    const unsigned int k_size_Q = (size_Q / BH);
    const unsigned int k_size_K = (size_K / BH);
    const unsigned int k_size_V = (size_V / BH);
    const unsigned int k_size_O = (size_O / BH);
    printf("k_size_Q %u k_size_K %u k_size_V %u k_size_O %u\n", k_size_Q, k_size_K, k_size_V, k_size_O);

    // Record the start event
    checkCudaErrors(cudaEventRecord(start, stream));

    for (unsigned int b = 0; b < B; b++)
    {
        for (unsigned int h = 0; h < H; h++)
        {
            const unsigned int offset_index = (b * H) + h;
            half *q_ptr = d_Q + offset_index * k_size_Q;
            half *k_ptr = d_K + offset_index * k_size_K;
            half *v_ptr = d_V + offset_index * k_size_V;
            half *o_ptr = d_O + offset_index * k_size_O;

#ifdef RUN_FLASH_ATTN
            // run flash attention kernel with the launch shape it was generated for
            flash_attention_kernel<<<dim3(block_size, 1, 1), dim3(thread_size, 1, 1), 0, stream>>>(q_ptr, k_ptr, v_ptr, o_ptr);
#else
            // run normal transformer kernels (launch shapes copied from the Hidet-generated code)
            hidet_compute_GLOBAL_QK<<<dim3(512, 1, 1), dim3(512, 1, 1), 0, stream>>>(q_ptr, k_ptr, GLOBAL_QK);
            hidet_compute_S<<<dim3(512, 1, 1), dim3(512, 1, 1), 0, stream>>>(GLOBAL_QK, q_ptr, k_ptr, S);
            hidet_compute_exp_s<<<dim3(512, 1, 1), dim3(512, 1, 1), 0, stream>>>(S, exp_s);
            hidet_compute_softmax<<<dim3(512, 1, 1), dim3(512, 1, 1), 0, stream>>>(S, q_ptr, k_ptr, GLOBAL_QK, exp_s, softmax);
            hidet_compute_GLOBAL_O<<<dim3(128, 1, 1), dim3(512, 1, 1), 0, stream>>>(softmax, v_ptr, q_ptr, k_ptr, GLOBAL_QK, S, exp_s, o_ptr);
#endif // RUN_FLASH_ATTN
        }
    }
    // Surface invalid launch configurations instead of comparing garbage later
    checkCudaErrors(cudaGetLastError());

    // Record the stop event and wait for all queued work to finish
    checkCudaErrors(cudaEventRecord(stop, stream));
    checkCudaErrors(cudaEventSynchronize(stop));
    printf("test done !!!\n");

    float msecTotal = 0.0f;
    checkCudaErrors(cudaEventElapsedTime(&msecTotal, start, stop));

    // Compute and print the performance
#ifdef RUN_FLASH_ATTN
    printf("flash attention elapsed time = %.3f msec\n", msecTotal);
#else
    printf("normal approach elapsed time = %.3f msec\n", msecTotal);
#endif // RUN_FLASH_ATTN

    // Copy result from device to host
    checkCudaErrors(
        cudaMemcpyAsync(h_O, d_O, size_O * sizeof(half), cudaMemcpyDeviceToHost, stream));
    checkCudaErrors(cudaStreamSynchronize(stream));

    printf("Checking computed result for correctness: \n");

    const double eps = 0.01;            // 1% error with python output
    const double max_error_ratio = eps; // tolerated fraction of mismatching elements
    // Floor for the relative-error denominator so that reference values at or
    // near zero neither divide by zero nor turn tiny absolute errors into failures.
    const double denom_floor = 1e-3;
    const unsigned int max_print_count = 100;
    uint32_t total_count = 0;
    uint32_t total_err_count = 0;
    for (unsigned int i = 0; i < size_O; i++)
    {
        // Compare signed values (comparing magnitudes would accept sign flips)
        // and measure the error relative to the reference, not the tested value.
        const double gold_val = (double)h_gold_O[i];
        const double out_val = (double)h_O[i];
        const double abs_err = fabs(out_val - gold_val);
        const double rel_err = abs_err / fmax(fabs(gold_val), denom_floor);

        if (rel_err > eps)
        {
            if (total_err_count < max_print_count)
                printf("Error! Matrix[%05u]=%.8f, ref=%.8f error term %E is > %E\n",
                       i, out_val, gold_val, rel_err, eps);
            total_err_count++;
        }
        total_count++;
    }
    const double error_ratio = total_count > 0 ? (double)total_err_count / (double)total_count : 1.0;
    const bool correct = total_count > 0 && error_ratio < max_error_ratio;
    printf("total count %u total error count %u (%.8f %%)\n", total_count, total_err_count, error_ratio * 100);
    printf("%s\n", correct ? "Result = PASS" : "Result = FAIL");

    // Clean up memory, events and the stream
#ifndef RUN_FLASH_ATTN
    checkCudaErrors(cudaFree(buffer));
#endif // RUN_FLASH_ATTN
    checkCudaErrors(cudaFree(d_Q));
    checkCudaErrors(cudaFree(d_K));
    checkCudaErrors(cudaFree(d_V));
    checkCudaErrors(cudaFree(d_O));
    checkCudaErrors(cudaFreeHost(h_O));
    checkCudaErrors(cudaEventDestroy(start));
    checkCudaErrors(cudaEventDestroy(stop));
    checkCudaErrors(cudaStreamDestroy(stream));

    return correct ? EXIT_SUCCESS : EXIT_FAILURE;
}

// Returns true when stat() succeeds for `name`, i.e. some filesystem entry
// with that path exists and is reachable.
inline bool file_exists(const std::string &name)
{
    struct stat info;
    const int rc = stat(name.c_str(), &info);
    return rc == 0;
}

// Appends the contents of `bin_file`, read as a raw array of fp16 values, to
// `matrix`. A trailing partial element (file size not a multiple of
// sizeof(half)) is ignored. Exits with EXIT_FAILURE if the file cannot be
// opened or read.
void load_data(std::vector<half> &matrix, const std::string bin_file)
{
    printf("loading %s\n", bin_file.c_str());

    // Checked explicitly rather than with assert(): assert() is compiled out
    // under NDEBUG, which would silently leave `matrix` without data.
    std::ifstream fin(bin_file, std::ios::binary | std::ios::ate);
    if (!fin.is_open())
    {
        fprintf(stderr, "Error! binary file %s doesn't exist or cannot be opened\n", bin_file.c_str());
        exit(EXIT_FAILURE);
    }

    // Opened at the end (ios::ate), so tellg() is the file size in bytes
    const std::streamoff num_bytes = fin.tellg();
    if (num_bytes < 0)
    {
        fprintf(stderr, "Error! cannot determine the size of %s\n", bin_file.c_str());
        exit(EXIT_FAILURE);
    }
    const size_t count = static_cast<size_t>(num_bytes) / sizeof(half);
    fin.seekg(0, std::ios::beg);

    // Read all elements in one call, appending after any existing contents
    const size_t old_size = matrix.size();
    matrix.resize(old_size + count);
    if (count > 0 &&
        !fin.read(reinterpret_cast<char *>(matrix.data() + old_size),
                  static_cast<std::streamsize>(count * sizeof(half))))
    {
        fprintf(stderr, "Error! failed to read %s\n", bin_file.c_str());
        exit(EXIT_FAILURE);
    }
}

// Returns the value of the integer command-line flag `name`, or `fallback`
// when the flag is absent. Exits with EXIT_FAILURE if the value is not positive.
static unsigned int get_positive_flag(int argc, char **argv, const char *name, unsigned int fallback)
{
    if (!checkCmdLineFlag(argc, (const char **)argv, name))
    {
        return fallback;
    }
    const int value = getCmdLineArgumentInt(argc, (const char **)argv, name);
    if (value <= 0)
    {
        fprintf(stderr, "Error! -%s must be a positive integer (got %d)\n", name, value);
        exit(EXIT_FAILURE);
    }
    return static_cast<unsigned int>(value);
}

// Entry point: parses -BATCH/-HEAD/-BLK/-THD, loads the inputs and the
// reference output written by flash_attention_example.py, validates their
// sizes and runs flash_attention_test(), exiting with its result.
int main(int argc, char **argv)
{
    printf("[Flash Attention Using CUDA] - Starting...\n");

    if (checkCmdLineFlag(argc, (const char **)argv, "help") ||
        checkCmdLineFlag(argc, (const char **)argv, "?"))
    {

        printf("Usage -device=n (n >= 0 for deviceID)\n");
        printf("      -BATCH=number of Batch\n");
        printf("      -HEAD=number of Head\n");
        printf("      -BLK=block size\n");
        printf("      -THD=thread size\n");
        exit(EXIT_SUCCESS);
    }

    // This will pick the best possible CUDA capable device, otherwise
    // override the device ID based on input provided at the command line.
    // The returned device id is not needed afterwards.
    findCudaDevice(argc, (const char **)argv);

    // Zero or negative values would break the per-slice arithmetic
    // (size / (batch * head)) or the kernel launch, so reject them here.
    const unsigned int batch = get_positive_flag(argc, argv, "BATCH", 1);
    const unsigned int head = get_positive_flag(argc, argv, "HEAD", 1);
    const unsigned int block_size = get_positive_flag(argc, argv, "BLK", 1);
    const unsigned int thread_size = get_positive_flag(argc, argv, "THD", 1);

    // load Q
    std::vector<half> mat_Q;
    load_data(mat_Q, "./mat_Q.bin");

    // load K
    std::vector<half> mat_K;
    load_data(mat_K, "./mat_K.bin");

    // load V
    std::vector<half> mat_V;
    load_data(mat_V, "./mat_V.bin");

    // load golden data O
    std::vector<half> gold_mat_O;
    load_data(gold_mat_O, "./gold_mat_O.bin");

    printf("batch %u head %u block_size %u thread_size %u\n", batch, head, block_size, thread_size);

    printf("Q size %zu K size %zu V size %zu O size %zu\n", mat_Q.size(), mat_K.size(), mat_V.size(), gold_mat_O.size());

    // The Python script writes Q, K, V and O with the same [B, H, N, d] shape,
    // so all four must be non-empty, equally sized and split evenly into
    // batch * head slices.
    const size_t slices = static_cast<size_t>(batch) * head;
    if (mat_Q.empty() ||
        mat_K.size() != mat_Q.size() ||
        mat_V.size() != mat_Q.size() ||
        gold_mat_O.size() != mat_Q.size() ||
        mat_Q.size() % slices != 0)
    {
        fprintf(stderr, "Error! inconsistent input sizes for batch %u head %u\n", batch, head);
        exit(EXIT_FAILURE);
    }

    checkCudaErrors(cudaProfilerStart());
    int result = flash_attention_test(
        batch, head, block_size, thread_size,
        mat_Q.data(), mat_Q.size(),
        mat_K.data(), mat_K.size(),
        mat_V.data(), mat_V.size(),
        gold_mat_O.data(), gold_mat_O.size());
    checkCudaErrors(cudaProfilerStop());

    exit(result);
}

Here are flash_attention_kernel_func.h and normal_transformer_kernel_func.h, respectively.

// flash_attention_kernel_func.h
// One thread block computes 32 rows of O = softmax(Q K^T) V, streaming over K/V
// in 16 tiles of 32 rows (online softmax, FlashAttention Algorithm 1).
// The indexing below implies: head dimension 128, K/V have 16 * 32 = 512 rows,
// Q/O have gridDim.x * 32 rows. No 1/sqrt(d) scaling is applied to Q K^T.
// Assumes blockDim.x == 1024 (32 rows x 32 lanes) -- TODO confirm at the launch site.
//
// Changes relative to the Hidet-generated version:
//  * the P*V term of the output update is now divided by l_new too (previously
//    only the rescaled old output was normalized, which makes O incorrect);
//  * dot products accumulate in registers instead of shared-memory atomicAdd
//    (each accumulator was owned by a single thread, so atomics were unneeded);
//  * the running output lives in registers and the redundant zero-fill of O is gone.
// NOTE(review): if this header is regenerated from the Hidet Python definition,
// the same normalization fix must be applied there.
__global__ void __launch_bounds__(1024) flash_attention_kernel(half *__restrict__ Q, half *__restrict__ K, half *__restrict__ V, half *__restrict__ O)
{
    const int tid = (int)threadIdx.x;
    const int row = tid / 32;                        // row within the 32-row query tile
    const int lane = tid % 32;                       // lane within that row
    const int col = lane * 4;                        // first of the 4 head-dim columns owned by this thread
    const int q_row = ((int)blockIdx.x * 32) + row;  // global query row

    __shared__ half smem_q[4096];      // 32 x 128 query tile
    __shared__ half smem_k[4096];      // current key tile, stored transposed: [128][32]
    __shared__ half smem_v[4096];      // current value tile: [32][128]
    __shared__ half smem_sp[1024];     // 32 x 32 scores, then exp(score - tile row max)
    __shared__ half smem_temp[1024];   // scratch for the per-row tree reductions
    __shared__ half smem_m[32];        // running row max
    __shared__ half smem_l[32];        // running row sum of exponentials
    __shared__ half smem_m_local[32];  // row max of the current tile
    __shared__ half smem_l_local[32];  // row sum of the current tile
    __shared__ half smem_m_new[32];
    __shared__ half smem_l_new[32];

    // Each thread is the only reader/writer of O[q_row][col .. col+3],
    // so the running output is kept in registers.
    half o_acc[4];
#pragma unroll
    for (int32_t i = 0; i < 4; ++i)
    {
        smem_q[(row * 128) + col + i] = Q[(q_row * 128) + col + i];
        o_acc[i] = ((half)(0.0f));
    }
    smem_l[row] = ((half)(0.0f));
    smem_m[row] = ((half)(-65504.0f));  // lowest finite half value
    __syncthreads();

    for (int32_t j = 0; j < 16; ++j)
    {
        // Load the j-th 32-row tile of K (transposed) and V.
        const int kv_row = (j * 32) + row;
#pragma unroll
        for (int32_t i = 0; i < 4; ++i)
        {
            smem_k[((col + i) * 32) + row] = K[(kv_row * 128) + col + i];
            smem_v[(row * 128) + col + i] = V[(kv_row * 128) + col + i];
        }
        __syncthreads();

        // Score tile: thread (row, lane) computes Q_tile[row] . K_tile[lane].
        half score = ((half)(0.0f));
        for (int32_t d = 0; d < 128; ++d)
        {
            score = score + (smem_q[(row * 128) + d] * smem_k[(d * 32) + lane]);
        }
        smem_sp[tid] = score;  // tid == row * 32 + lane
        __syncthreads();

        // Row max of the score tile (tree reduction over the 32 lanes of each row).
        smem_temp[tid] = smem_sp[tid];
        __syncthreads();
        for (int32_t k = 1; k < 32; k *= 2)
        {
            if ((lane % (k * 2)) == 0)
            {
                smem_temp[tid] = __hmax(smem_temp[tid], smem_temp[tid + k]);
            }
            __syncthreads();
        }
        if (tid < 32)
        {
            smem_m_local[tid] = smem_temp[tid * 32];
        }
        __syncthreads();

        // P = exp(S - tile row max), element-wise.
        smem_sp[tid] = hexp(smem_sp[tid] - smem_m_local[row]);
        __syncthreads();

        // Row sum of P (same tree reduction, with addition).
        smem_temp[tid] = smem_sp[tid];
        __syncthreads();
        for (int32_t k = 1; k < 32; k *= 2)
        {
            if ((lane % (k * 2)) == 0)
            {
                smem_temp[tid] = smem_temp[tid] + smem_temp[tid + k];
            }
            __syncthreads();
        }
        if (tid < 32)
        {
            smem_l_local[tid] = smem_temp[tid * 32];
        }
        __syncthreads();

        // Merge running statistics with the tile's statistics.
        if (tid < 32)
        {
            const half merged_max = __hmax(smem_m[tid], smem_m_local[tid]);
            smem_m_new[tid] = merged_max;
            smem_l_new[tid] = (hexp(smem_m[tid] - merged_max) * smem_l[tid]) + (hexp(smem_m_local[tid] - merged_max) * smem_l_local[tid]);
        }
        __syncthreads();

        // This thread's 4 columns of P_tile[row] * V_tile.
        half pv[4];
#pragma unroll
        for (int32_t i = 0; i < 4; ++i)
        {
            pv[i] = ((half)(0.0f));
        }
        for (int32_t t = 0; t < 32; ++t)
        {
            const half p = smem_sp[(row * 32) + t];
#pragma unroll
            for (int32_t i = 0; i < 4; ++i)
            {
                pv[i] = pv[i] + (p * smem_v[(t * 128) + col + i]);
            }
        }

        // O = (l * exp(m - m_new) * O + exp(m_local - m_new) * PV) / l_new.
        // BOTH terms are divided by l_new (Algorithm 1, line 12).
        const half row_m_new = smem_m_new[row];
        const half inv_l_new = (half)(1.0f / (float)smem_l_new[row]);
        const half scale_old = smem_l[row] * hexp(smem_m[row] - row_m_new);
        const half scale_new = hexp(smem_m_local[row] - row_m_new);
#pragma unroll
        for (int32_t i = 0; i < 4; ++i)
        {
            o_acc[i] = inv_l_new * ((scale_old * o_acc[i]) + (scale_new * pv[i]));
        }
        // All reads of smem_m / smem_l / smem_sp / smem_v must finish before they are overwritten.
        __syncthreads();

        if (tid < 32)
        {
            smem_m[tid] = smem_m_new[tid];
            smem_l[tid] = smem_l_new[tid];
        }
        __syncthreads();
    }

#pragma unroll
    for (int32_t i = 0; i < 4; ++i)
    {
        O[(q_row * 128) + col + i] = o_acc[i];
    }
}
// normal_transformer_func.h
// Naive Q K^T. From the indexing: blockIdx.x selects a query row, threadIdx.x
// selects a key row (assumes blockDim.x == 512 -- TODO confirm at launch site).
// GLOBAL_QK[q][k] = sum_{d < 128} GLOBAL_Q[q][d] * GLOBAL_K[k][d], accumulated in half.
__global__ void __launch_bounds__(512) hidet_compute_GLOBAL_QK(half * __restrict__ GLOBAL_Q, half * __restrict__ GLOBAL_K, half * __restrict__ GLOBAL_QK) {
  const int q_row = (int)blockIdx.x;
  const int k_row = (int)threadIdx.x;
  const half *q = GLOBAL_Q + (q_row * 128);
  const half *k = GLOBAL_K + (k_row * 128);

  half dot = half(0.0f);
  for (int32_t d = 0; d < 128; ++d) {
    dot = dot + (q[d] * k[d]);
  }
  GLOBAL_QK[(q_row * 512) + k_row] = dot;
}

// Subtracts the row max from each score: S[r][c] = QK[r][c] - max_c' QK[r][c'].
// One block per row (blockIdx.x), one thread per column (rows are 512 wide).
// GLOBAL_Q and GLOBAL_K are unused; they are part of the generated signature.
// NOTE(review): every thread re-reduces the whole row (512 reads each); a shared-memory
// reduction would avoid the redundancy.
__global__ void __launch_bounds__(512) hidet_compute_S(half * __restrict__ GLOBAL_QK, half * __restrict__ GLOBAL_Q, half * __restrict__ GLOBAL_K, half * __restrict__ S) {
  const int base = (int)blockIdx.x * 512;
  const half *qk_row = GLOBAL_QK + base;

  half row_max = half(-65504.0f);  // lowest finite half value
  for (int32_t c = 0; c < 512; ++c) {
    row_max = __hmax(row_max, qk_row[c]);
  }

  const int col = (int)threadIdx.x;
  S[base + col] = qk_row[col] - row_max;
}

// Element-wise exp: exp_s[i] = exp(S[i]), one element per thread, 512 elements per block.
__global__ void __launch_bounds__(512) hidet_compute_exp_s(half * __restrict__ S, half * __restrict__ exp_s) {
  const int idx = ((int)blockIdx.x * 512) + (int)threadIdx.x;
  exp_s[idx] = hexp(S[idx]);
}

// Row softmax from the max-shifted scores: softmax[r][c] = exp(S[r][c]) / sum_c' exp(S[r][c']).
// One block per row (blockIdx.x), one thread per column (rows are 512 wide).
// GLOBAL_Q, GLOBAL_K, GLOBAL_QK and exp_s are unused; they are part of the generated signature.
// Each thread recomputes the full row sum.
__global__ void __launch_bounds__(512) hidet_compute_softmax(half * __restrict__ S, half * __restrict__ GLOBAL_Q, half * __restrict__ GLOBAL_K, half * __restrict__ GLOBAL_QK, half * __restrict__ exp_s, half * __restrict__ softmax) {
  const int base = (int)blockIdx.x * 512;

  half denom = half(0.0f);
  for (int32_t c = 0; c < 512; ++c) {
    denom = denom + hexp(S[base + c]);
  }

  const int idx = base + (int)threadIdx.x;
  softmax[idx] = hexp(S[idx]) / denom;
}

// O = softmax * V. Each thread computes one output element; the flat index
// blockIdx.x * 512 + threadIdx.x is split into (row, col) of a 128-wide output.
// GLOBAL_O[row][col] = sum_{t < 512} softmax[row][t] * GLOBAL_V[t][col], accumulated in half.
// GLOBAL_Q, GLOBAL_K, GLOBAL_QK, S and exp_s are unused; they are part of the generated signature.
__global__ void __launch_bounds__(512) hidet_compute_GLOBAL_O(half * __restrict__ softmax, half * __restrict__ GLOBAL_V, half * __restrict__ GLOBAL_Q, half * __restrict__ GLOBAL_K, half * __restrict__ GLOBAL_QK, half * __restrict__ S, half * __restrict__ exp_s, half * __restrict__ GLOBAL_O) {
  const int flat = ((int)blockIdx.x * 512) + (int)threadIdx.x;
  const int row = flat / 128;
  const int col = flat % 128;
  const half *p_row = softmax + (row * 512);

  half acc = half(0.0f);
  for (int32_t t = 0; t < 512; ++t) {
    acc = acc + (p_row[t] * GLOBAL_V[(t * 128) + col]);
  }
  GLOBAL_O[(row * 128) + col] = acc;
}

Again, really wonderful work on Hidet! Any help would be greatly appreciated 🙏. If any further info is needed, please let me know.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions