kingsleykimm/tiny_infer

Tiny Infer

A small inference library for compute capability 9.0 NVIDIA GPUs, with a focus on minimal dependencies beyond CuTe/CUTLASS on the kernel side.

Today's Agenda:

  • test all2all kernels

📝 Roadmap

  • Implement GDN forward pass

    • Then implement the Gated Delta Rule kernel

    • Improve TMA staging, maybe make two different staging barriers for chunk_fwd_O

    • Implement kv / state caching for the gdn forward pass, using DynamicCache

    • Implement MoE + expert parallelism

    • Implement more general reduction kernels, which scale up to cluster-level reduction based on size (see quack for reference)

      • Add L2 qk-norm in chunked_forward()
      • Add the post attention norm
      • For any max reduction kernels, it is probably faster to use the redux.max.sync.u32 instruction instead of the butterfly reduction pattern
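The butterfly pattern mentioned above can be modeled in plain Python as a sketch (the real kernel would use `__shfl_xor_sync` or `redux.max.sync.u32`; lane count and data here are illustrative only):

```python
# Simulated warp-level butterfly max reduction: a software model of the
# shfl_xor pattern that redux.max.sync.u32 replaces in hardware.
# Each "lane" exchanges its value with lane (lane ^ offset) and keeps the
# max; after log2(32) rounds every lane holds the warp-wide maximum.

WARP_SIZE = 32

def butterfly_max(vals):
    assert len(vals) == WARP_SIZE
    vals = list(vals)
    offset = WARP_SIZE // 2
    while offset >= 1:
        # All lanes exchange simultaneously, so read from a snapshot.
        snapshot = list(vals)
        for lane in range(WARP_SIZE):
            vals[lane] = max(snapshot[lane], snapshot[lane ^ offset])
        offset //= 2
    return vals

lanes = [(7 * i + 3) % 31 for i in range(32)]
result = butterfly_max(lanes)
```

The hardware `redux.max.sync.u32` collapses these five exchange rounds into a single instruction, which is why the roadmap item prefers it.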
  • Implement MoE

    • Worklog through VLLM implementation
      • The larger structure of the MoE kernels is roughly understood; the focus is mainly on the DeepGEMM kernel backend, since it is more specialized
      • Current step: Getting the MoE weights set up for the Grouped GEMM, we need to transpose scales to MN major, TMA-aligned
      • Finish a2a kernel setup and integrate it into the API
      • testing for moe:
        • link topk with a2a dispatch and test
        • test full moe kernels
          • first part is just fp8 kernel on FC1, with concat
          • we then need to write a silu + mul + quant kernel (look at flashinfer's contiguous to fuse it in)
          • The fp8 inputs then go into the down_proj, producing bf16 activations, so test this full loop
        • Note: when using manual virtual memory management from CUDA, Hazy Research recommends allocating all device memory that will be used in peer all-to-all before the model forward pass and sharing it across all devices, since the host-side memory share costs many cycles per layer
      • attention kernels (with optional gating):
        • rework fmha attention - current kernel can be used for prefill
        • write flash decode kernel
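The "link topk with a2a dispatch" step above can be sketched in plain Python; names like `num_experts` and `top_k` are illustrative, not tiny_infer's API:

```python
# Hypothetical sketch: softmax top-k gating, then the per-expert token
# counts an all-to-all dispatch would use to size its send buffers.
import math

def topk_gate(logits, top_k):
    """Return (expert_ids, weights) for one token's router logits."""
    order = sorted(range(len(logits)), key=lambda e: logits[e], reverse=True)
    ids = order[:top_k]
    # Softmax over only the selected logits (a common MoE convention).
    mx = max(logits[e] for e in ids)
    exps = [math.exp(logits[e] - mx) for e in ids]
    z = sum(exps)
    return ids, [v / z for v in exps]

def dispatch_counts(all_logits, num_experts, top_k):
    """Tokens-per-expert histogram that drives the a2a dispatch."""
    counts = [0] * num_experts
    for logits in all_logits:
        ids, _ = topk_gate(logits, top_k)
        for e in ids:
            counts[e] += 1
    return counts

tokens = [[0.1, 2.0, -1.0, 0.5], [1.5, 0.2, 0.3, 2.5]]
counts = dispatch_counts(tokens, num_experts=4, top_k=2)
```

A test along these lines (known logits in, known counts out) is a cheap way to validate the topk-to-dispatch link before running the full MoE loop.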
  • Gated Attention

    • The only difference from standard attention appears to be an output gate, which can be fused into the existing kernel.
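The fused gate amounts to an elementwise epilogue on the attention output; a minimal sketch (the gate projection itself is omitted, and sigmoid gating is an assumption):

```python
# Minimal sketch of the output gate: the standard attention output is
# scaled elementwise by sigmoid(gate), where gate would come from a
# separate projection of the hidden states (projection not shown).
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def gated_output(attn_out, gate):
    """Epilogue fusion candidate: o = attn * sigmoid(g)."""
    return [a * sigmoid(g) for a, g in zip(attn_out, gate)]

out = gated_output([1.0, -2.0, 0.5], [0.0, 10.0, -10.0])
```

Because this is purely elementwise, it can live in the existing kernel's epilogue with no extra global-memory round trip.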
  • Multi-Token Prediction

    • This is just a recursive MTP module, essentially an unrolled RNN
    • Allow the user to specify the spec-token count, then perform K passes of the MTP module followed by one large forward-pass verification
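The draft-then-verify loop above can be sketched with stand-in callables (`draft_step` and `verify_model` are toy placeholders, not real modules; the accept rule shown — keep the longest agreeing prefix, then take the verifier's token — is one common choice):

```python
# Toy sketch of K-step multi-token prediction with a single verification
# pass. The accept rule keeps the longest prefix where draft == verify.

def speculate(prompt, draft_step, verify_model, k):
    # K sequential passes of the (unrolled-RNN-like) MTP module.
    drafts, ctx = [], list(prompt)
    for _ in range(k):
        tok = draft_step(ctx)
        drafts.append(tok)
        ctx.append(tok)
    # One large forward pass scores all draft positions at once.
    verified = verify_model(prompt, drafts)
    accepted = []
    for d, v in zip(drafts, verified):
        if d != v:
            accepted.append(v)  # take the verifier's token, then stop
            break
        accepted.append(d)
    return accepted

# Toy "models": the draft predicts last+1; the verifier agrees until the
# running value reaches 12, then emits 99.
draft = lambda ctx: ctx[-1] + 1
def verify(prompt, drafts):
    ctx, out = list(prompt), []
    for _ in drafts:
        nxt = ctx[-1] + 1 if ctx[-1] < 12 else 99
        out.append(nxt)
        ctx.append(nxt)
    return out

tokens = speculate([10], draft, verify, k=4)
```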
  • Split-KV / Flash Decoding

    • Implement an optimized bf16 attention kernel for decoding scenarios with small batch size
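The split-KV math behind flash decoding can be sketched for one query vector: each KV partition produces partial softmax statistics `(max_logit, sum_exp, weighted_acc)`, and a log-sum-exp rescale merges them to exactly match a single full pass (all names here are illustrative):

```python
# Sketch of split-KV decoding for one query: per-partition partials are
# merged with a log-sum-exp rescale, matching a full-softmax pass.
import math

def partial_attend(q, keys, vals):
    logits = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(logits)
    w = [math.exp(s - m) for s in logits]
    l = sum(w)
    acc = [sum(wi * v[d] for wi, v in zip(w, vals)) for d in range(len(vals[0]))]
    return m, l, acc

def merge(parts):
    m = max(p[0] for p in parts)
    l = sum(p[1] * math.exp(p[0] - m) for p in parts)
    out = [0.0] * len(parts[0][2])
    for pm, pl, pacc in parts:
        scale = math.exp(pm - m)  # rescale partials to the global max
        for d in range(len(out)):
            out[d] += pacc[d] * scale
    return [o / l for o in out]

q = [0.3, -0.1]
keys = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [-1.0, 2.0]]
vals = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
split = merge([partial_attend(q, keys[:2], vals[:2]),
               partial_attend(q, keys[2:], vals[2:])])
full = merge([partial_attend(q, keys, vals)])
```

At small batch sizes this lets otherwise-idle CTAs each take a KV partition, with one cheap merge at the end.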
  • Implement cluster reduction for all reduction kernels

  • Batch inference

    • Augment sm90_attention to take in an optional mask for padding, should be straightforward
    • Augment attention and other kernels to have var len properties
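The varlen property usually means the flash-attn-style `cu_seqlens` convention: ragged sequences packed into one buffer plus cumulative offsets, so no padding tokens are materialized. A sketch (names follow flash-attn style but are illustrative here):

```python
# Sketch of varlen packing: sequences of different lengths go into one
# flat buffer, and cu_seqlens (cumulative lengths, starting at 0) lets a
# kernel recover each sequence's slice without padding.
from itertools import accumulate

def pack_varlen(seqs):
    packed = [tok for s in seqs for tok in s]
    cu_seqlens = [0] + list(accumulate(len(s) for s in seqs))
    return packed, cu_seqlens

def unpack(packed, cu_seqlens, i):
    """Sequence i occupies packed[cu_seqlens[i]:cu_seqlens[i+1]]."""
    return packed[cu_seqlens[i]:cu_seqlens[i + 1]]

seqs = [[1, 2, 3], [4], [5, 6]]
packed, cu = pack_varlen(seqs)
```

The optional padding mask for sm90_attention and the varlen path are alternatives: the mask keeps rectangular tensors, while cu_seqlens avoids padded work entirely.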
  • Insert CUDA graphs into inference

    • Set up padded tensors for most of inference to avoid allocations
  • Fix multicast SwiGLU and Hopper Attention for the m = seq_len edge case

  • Implement the varlen Causal Conv1d Update kernel
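The per-token conv1d update step is small: the decode-time state holds the last `kernel_size - 1` inputs, each new token produces one output, and the state shifts. A single-sequence sketch (a varlen version would keep one such state per sequence, indexed by sequence id):

```python
# Sketch of a single-token causal conv1d update: state holds the last
# (kernel_size - 1) inputs, oldest first; each step emits one output and
# drops the oldest input from the state.

def conv1d_update(state, x, weights):
    """state: last k-1 inputs (oldest first); weights: k taps (oldest first)."""
    window = state + [x]
    y = sum(w * v for w, v in zip(weights, window))
    return y, window[1:]  # new state drops the oldest input

state = [0.0, 0.0, 0.0]        # zero-initialized, kernel_size = 4
weights = [1.0, 2.0, 3.0, 4.0]
outs = []
for x in [1.0, 1.0, 1.0]:
    y, state = conv1d_update(state, x, weights)
    outs.append(y)
```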

  • Quantized inference

  • Implement Gated Delta Net -> Kimi Linear Attention

  • KV Cache quantization
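One common starting point for KV cache quantization is symmetric int8 with one scale per head vector; a sketch (the per-head granularity is an assumption, not something tiny_infer has fixed — per-token or per-channel scales are the usual alternatives):

```python
# Sketch of symmetric int8 KV-cache quantization: one scale per vector
# (scale = max|x| / 127), round-to-nearest on write, dequantize on read.

def quantize_int8(vec):
    scale = max(abs(v) for v in vec) / 127.0 or 1.0  # guard all-zero vecs
    q = [max(-128, min(127, round(v / scale))) for v in vec]
    return q, scale

def dequantize(q, scale):
    return [qi * scale for qi in q]

k = [0.5, -1.27, 0.01, 1.0]
q, scale = quantize_int8(k)
k_hat = dequantize(q, scale)
```

Round-to-nearest bounds the per-element error by half a scale step, which is the property a cache-quantization test should check.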

Running list of dependencies:

About

CUDA inference engine
