<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" ><generator uri="https://jekyllrb.com/" version="3.10.0">Jekyll</generator><link href="https://muqi1029.github.io/feed.xml" rel="self" type="application/atom+xml" /><link href="https://muqi1029.github.io/" rel="alternate" type="text/html" /><updated>2025-08-10T23:31:21+08:00</updated><id>https://muqi1029.github.io/feed.xml</id><title type="html">Muqi Li(李琦)</title><subtitle>personal description</subtitle><author><name>Muqi Li</name></author><entry><title type="html">SGLang Memory Management &amp;amp; Cache</title><link href="https://muqi1029.github.io/posts/2025/05/mem_cache/" rel="alternate" type="text/html" title="SGLang Memory Management &amp;amp; Cache" /><published>2025-05-30T00:00:00+08:00</published><updated>2025-05-30T00:00:00+08:00</updated><id>https://muqi1029.github.io/posts/2025/05/sglang_mem_cache</id><content type="html" xml:base="https://muqi1029.github.io/posts/2025/05/mem_cache/"><![CDATA[<blockquote>
  <p>Note: Complex systems often include numerous corner cases and technical implementations that can make the source code challenging to understand for newcomers.</p>

  <p>To make the core concepts more accessible, this blog post uses pseudocode that focuses on the main ideas while omitting implementation details (such as <code class="language-plaintext highlighter-rouge">self</code> references and other technical specifics). While simplified, the pseudocode maintains the essential logic and workflow of the system.</p>

  <p>Of course, if you want to know all the details, the best way is to look directly at the source code, which is available <a href="https://github.com/sgl-project/sglang">here</a>.</p>
</blockquote>

<p>Main call path:</p>

<p><code class="language-plaintext highlighter-rouge">launch_server</code> ⇒ <code class="language-plaintext highlighter-rouge">_launch_subprocesses</code> ⇒ <code class="language-plaintext highlighter-rouge">Init Scheduler</code> ⇒ <code class="language-plaintext highlighter-rouge">Init TpWorker</code> ⇒ <code class="language-plaintext highlighter-rouge">Init ModelConfig &amp; ModelRunner</code> ⇒ <code class="language-plaintext highlighter-rouge">ModelRunner init KV Cache Pool &amp; Allocator</code></p>

<p>Main points in this blog:</p>

<ul>
  <li>How <code class="language-plaintext highlighter-rouge">mem-fraction-static</code> works in KV Cache initialization</li>
  <li>How each token’s <code class="language-plaintext highlighter-rouge">KV Cache</code> size is computed</li>
  <li>How the <code class="language-plaintext highlighter-rouge">KV Cache Pool</code> is managed (allocation, freeing, and usage)</li>
  <li>How the <code class="language-plaintext highlighter-rouge">Radix Cache</code> reuses KV Cache</li>
</ul>

<p>This blog mainly contains two sections:</p>

<ul>
  <li>In the KV Cache Management section, we will explore how the <code class="language-plaintext highlighter-rouge">KV Cache</code> is managed (creation, allocation, freeing, and usage)</li>
  <li>In the Radix Tree Cache section, we will explore how the <code class="language-plaintext highlighter-rouge">radix tree</code> data structure enables KV Cache reuse</li>
</ul>

<h1 id="kv-cache-management">KV Cache Management</h1>

<blockquote>
  <p><strong>Background</strong>
The <code class="language-plaintext highlighter-rouge">ModelRunner</code> owns the real model and runs its <strong>forward</strong> pass</p>
</blockquote>

<p>Below is the initialization of <code class="language-plaintext highlighter-rouge">ModelRunner</code>, which also initializes the <code class="language-plaintext highlighter-rouge">KV Cache Pool</code>.</p>

<p>When initializing the <code class="language-plaintext highlighter-rouge">memory pool</code>, SGLang provides 3 abstract managers:</p>

<ol>
  <li><code class="language-plaintext highlighter-rouge">req_to_token_pool</code>: A memory pool that maps a request’s tokens to <code class="language-plaintext highlighter-rouge">out_cache_loc</code></li>
  <li><code class="language-plaintext highlighter-rouge">token_to_kv_pool</code>: A pool that maps <code class="language-plaintext highlighter-rouge">out_cache_loc</code> from <code class="language-plaintext highlighter-rouge">req_to_token_pool</code> to its real KV Cache data</li>
  <li><code class="language-plaintext highlighter-rouge">token_to_kv_pool_allocator</code>: Allocates and frees real KV Cache data</li>
</ol>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">ModelRunner</span><span class="p">:</span>
  <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model_config</span><span class="p">,</span> <span class="p">....):</span>

    <span class="c1"># adjust `AttentionBackend`, `mem_fraction_static`
</span>    <span class="n">model_specific_adjustment</span><span class="p">()</span>

    <span class="c1"># since SGLang adjusts the settings depending on Model Arch
</span>    <span class="c1"># then update that info globally
</span>    <span class="n">global_server_args_dict</span><span class="p">.</span><span class="n">update</span><span class="p">({...})</span>

    <span class="c1"># build WORLD_GROUP, TP_GROUP, PP_GROUP for later communication
</span>    <span class="c1"># after init the distributed settings, get the minimum GPU memory across the world
</span>    <span class="n">min_per_gpu_memory</span> <span class="o">=</span> <span class="n">init_torch_distributed</span><span class="p">()</span>

    <span class="n">initialize</span><span class="p">(</span><span class="n">min_per_gpu_memory</span><span class="p">)</span>

  <span class="k">def</span> <span class="nf">initialize</span><span class="p">(</span><span class="n">min_per_gpu_memory</span><span class="p">):</span>

    <span class="c1"># load sampler and model
</span>    <span class="n">sampler</span> <span class="o">=</span> <span class="n">Sampler</span><span class="p">()</span>
    <span class="n">load_model</span><span class="p">()</span>

    <span class="c1">######
</span>    <span class="c1"># Until now, Model Weights &amp; Distributed Initialization occupy some GPU memory
</span>    <span class="c1"># Note: but `min_per_gpu_memory` doesn't change
</span>    <span class="c1">######
</span>
    <span class="c1"># Core in this blog!!!
</span>    <span class="n">init_memory_pool</span><span class="p">(</span>
      <span class="n">min_per_gpu_memory</span><span class="p">,</span>
      <span class="n">server_args</span><span class="p">.</span><span class="n">max_running_requests</span><span class="p">,</span> <span class="c1"># these 2 args are set by users
</span>      <span class="n">server_args</span><span class="p">.</span><span class="n">max_total_tokens</span><span class="p">)</span>

    <span class="c1"># ...
</span>    <span class="n">init_cublas</span><span class="p">()</span>
    <span class="n">init_attention_backend</span><span class="p">()</span>
    <span class="n">init_cuda_graphs</span><span class="p">()</span>

  <span class="k">def</span> <span class="nf">init_memory_pool</span><span class="p">(</span>
       <span class="n">total_gpu_memory</span><span class="p">,</span>
       <span class="n">max_num_reqs</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span>
       <span class="n">max_total_tokens</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
    <span class="c1"># compute how many tokens' KV Cache entries can be stored on each GPU
</span>    <span class="n">max_total_num_tokens</span> <span class="o">=</span> <span class="n">profile_max_num_token</span><span class="p">(</span><span class="n">total_gpu_memory</span><span class="p">)</span>

    <span class="c1"># adjust max_num_requests
</span>    <span class="k">if</span> <span class="n">max_num_reqs</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">:</span>
      <span class="n">max_num_reqs</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span>
       <span class="nb">max</span><span class="p">(</span><span class="n">max_total_num_tokens</span> <span class="o">/</span> <span class="n">model_config</span><span class="p">.</span><span class="n">context_len</span> <span class="o">*</span> <span class="mi">512</span><span class="p">,</span> <span class="mi">2048</span><span class="p">),</span>
       <span class="mi">4096</span>
    <span class="p">)</span>

    <span class="c1"># adjust max_total_tokens
</span>    <span class="k">if</span> <span class="n">max_total_tokens</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">:</span>
      <span class="k">if</span> <span class="n">max_total_tokens</span> <span class="o">&gt;</span> <span class="n">max_total_num_tokens</span><span class="p">:</span> <span class="n">logger</span><span class="p">.</span><span class="n">warning</span><span class="p">...</span>
      <span class="n">max_total_num_tokens</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">max_total_tokens</span><span class="p">,</span> <span class="n">max_total_num_tokens</span><span class="p">)</span>

    <span class="c1"># align page size
</span>    <span class="n">max_total_num_tokens</span> <span class="o">=</span> <span class="p">(</span><span class="n">max_total_num_tokens</span> <span class="o">//</span> <span class="n">page_size</span><span class="p">)</span> <span class="o">*</span> <span class="n">page_size</span>

    <span class="c1"># init req_to_token_pool
</span>    <span class="n">req_to_token_pool</span> <span class="o">=</span> <span class="n">ReqToTokenPool</span><span class="p">(</span>
           <span class="n">max_num_reqs</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
           <span class="n">model_config</span><span class="p">.</span><span class="n">context_len</span> <span class="o">+</span> <span class="mi">4</span><span class="p">,</span>
           <span class="p">...)</span>

    <span class="c1"># init token_to_kv_pool
</span>    <span class="n">token_to_kv_pool</span> <span class="o">=</span> <span class="n">MHATokenToKVPool</span><span class="p">(</span>
           <span class="n">max_total_num_tokens</span><span class="p">,</span>
           <span class="n">page_size</span><span class="p">,</span>
           <span class="n">kv_cache_dtype</span><span class="p">,</span>
           <span class="n">head_num</span><span class="p">,</span>
           <span class="n">head_dim</span><span class="p">,</span>
           <span class="n">layer_num</span><span class="p">,</span>
           <span class="p">...)</span>

    <span class="c1"># init token_to_kv_pool_allocator
</span>    <span class="n">token_to_kv_pool_allocator</span> <span class="o">=</span> <span class="n">TokenToKVPoolAllocator</span><span class="p">(</span>
        <span class="n">max_total_num_tokens</span><span class="p">,</span>
        <span class="n">kv_cache_dtype</span><span class="p">,</span>
        <span class="n">device</span><span class="p">,</span>
        <span class="n">token_to_kv_pool</span><span class="p">)</span>

    <span class="c1"># ... end of init_memory_pool</span>

  <span class="k">def</span> <span class="nf">profile_max_num_token</span><span class="p">(</span><span class="n">total_gpu_memory</span><span class="p">):</span>
    <span class="c1"># get min_per_gpu_memory in the world
</span>    <span class="c1"># Note: model has been loaded before
</span>    <span class="n">available_gpu_memory</span> <span class="o">=</span> <span class="n">get_available_gpu_memory</span><span class="p">(</span><span class="n">distributed</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>

    <span class="c1"># Compute how much GPU memory **a token's KV Cache** occupies
</span>    <span class="c1"># Note: In TP settings, each GPU only handles part of `attention head` when computing attention scores
</span>    <span class="n">cell_size</span> <span class="o">=</span> <span class="p">(</span>
      <span class="n">model_config</span><span class="p">.</span><span class="n">get_num_kv_heads</span><span class="p">(</span><span class="n">get_attention_tp_size</span><span class="p">())</span> <span class="c1"># get how many num_kv_heads in TP setting
</span>     <span class="o">*</span> <span class="n">model_config</span><span class="p">.</span><span class="n">head_dim</span>
     <span class="o">*</span> <span class="n">num_layers</span>
     <span class="o">*</span> <span class="mi">2</span> <span class="c1"># since K and V
</span>     <span class="o">*</span> <span class="n">element_size</span><span class="p">(</span><span class="n">kv_cache_dtype</span><span class="p">)</span> <span class="c1"># bytes for each element of KV Cache Type
</span>    <span class="p">)</span>

    <span class="c1"># This is the **role** of `mem_fraction_static` here
</span>    <span class="c1"># Note:
</span>    <span class="c1"># - `total_gpu_memory`: min_per_gpu_memory, measured after initializing the distributed environment
</span>    <span class="c1"># - `available_gpu_memory`: min_per_gpu_memory, measured after initializing the distributed environment and loading the model
</span>    <span class="c1"># - `total_gpu_memory * (1 - mem_fraction_static)`: the other potential GPU memory usage (like `activation` in the forward pass)
</span>    <span class="c1"># - `rest_memory`: free GPU memory (after loading the model) minus that other usage; the rest is for `KV Cache`
</span>    <span class="n">rest_memory</span> <span class="o">=</span> <span class="n">available_gpu_memory</span> <span class="o">-</span> <span class="p">(</span>
        <span class="n">total_gpu_memory</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">mem_fraction_static</span><span class="p">)</span>
    <span class="p">)</span>

    <span class="c1"># convert rest_memory from gigabytes back to bytes
</span>    <span class="c1"># compute how many tokens' KV cache can be saved
</span>    <span class="n">max_num_tokens</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">rest_memory</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">&lt;&lt;</span> <span class="mi">30</span><span class="p">)</span> <span class="o">//</span> <span class="n">cell_size</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">max_num_tokens</span>
</code></pre></div></div>

<p>From the simplified code above, we can see:</p>

<p>💡: <strong>How <code class="language-plaintext highlighter-rouge">mem-fraction-static</code> works in KV Cache initialization</strong></p>

<p>The <code class="language-plaintext highlighter-rouge">mem_fraction_static</code> fraction of <code class="language-plaintext highlighter-rouge">GPU memory</code> is reserved for <code class="language-plaintext highlighter-rouge">model weights</code> and the <code class="language-plaintext highlighter-rouge">KV Cache Pool</code>; use a smaller value if you see out-of-memory errors. But how does the process go?</p>

<ol>
  <li>Get Free GPU Memory  (<code class="language-plaintext highlighter-rouge">M1</code>: total GPU free memory)</li>
  <li>Load model (this occupies some GPU memory)</li>
  <li>Get Free GPU Memory again (<code class="language-plaintext highlighter-rouge">M2</code>: After Loading Model)</li>
  <li>Compute non-static GPU memory: (<code class="language-plaintext highlighter-rouge">M3 = M1 * (1 - mem_fraction_static)</code> )</li>
  <li>The memory for KV cache Pool: <code class="language-plaintext highlighter-rouge">M2 - M3</code></li>
</ol>
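<p>To make the arithmetic concrete, here is a toy calculation with made-up numbers (an 80 GB GPU is assumed purely for illustration; none of these figures come from a real deployment):</p>

```python
# Hypothetical numbers: an 80 GB GPU, mem_fraction_static = 0.9 (illustrative only)
mem_fraction_static = 0.9

M1 = 78.0  # free GPU memory before loading the model (GB)
M2 = 64.0  # free GPU memory after loading the model (GB)

# Memory set aside for non-static usage such as activations
M3 = M1 * (1 - mem_fraction_static)

# Remaining budget for the KV Cache Pool
kv_cache_budget = M2 - M3
print(round(kv_cache_budget, 2))  # 56.2 (GB)
```

<p>So lowering <code class="language-plaintext highlighter-rouge">mem_fraction_static</code> enlarges <code class="language-plaintext highlighter-rouge">M3</code> and shrinks the KV Cache budget, leaving more headroom for activations.</p>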

<p>💡: <strong>How each token’s <code class="language-plaintext highlighter-rouge">KV Cache</code> size is computed</strong></p>

<p><code class="language-plaintext highlighter-rouge">tp_num_head * head_dim * num_layers * 2 * element_size (torch._utils._element_size(kv_cache_dtype))</code></p>
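<p>Plugging in numbers for a hypothetical 7B-class model (32 KV heads, <code class="language-plaintext highlighter-rouge">head_dim</code> 128, 32 layers, fp16, TP=1; these values are illustrative assumptions, not taken from the post):</p>

```python
# Hypothetical 7B-class config (fp16, TP=1); all numbers are illustrative
num_kv_heads = 32   # KV heads per GPU, after dividing by the attention TP size
head_dim = 128
num_layers = 32
element_size = 2    # bytes per fp16 element

# factor 2: K and V are stored separately
cell_size = num_kv_heads * head_dim * num_layers * 2 * element_size
print(cell_size)  # 524288 bytes, i.e. 512 KiB per token
```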

<h2 id="managers">Managers</h2>

<h3 id="req_to_token_pool">req_to_token_pool</h3>

<p>A memory pool that maps a request to its token locations.</p>

<p>Shape: <code class="language-plaintext highlighter-rouge">(max_num_reqs + 1, model_config.context_len + 4)</code></p>

<p>Dtype: <code class="language-plaintext highlighter-rouge">torch.int32</code></p>

<p>Access:</p>

<ul>
  <li>dim0: the concrete <code class="language-plaintext highlighter-rouge">req_idx</code></li>
  <li>dim1: token positions in req (starting from 0, 1, 2…), identify the specific token in the request</li>
  <li>value: <code class="language-plaintext highlighter-rouge">out_cache_loc</code>, which points to the KV cache indices of the token identified by dim0 and dim1</li>
</ul>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">ReqToTokenPool</span><span class="p">:</span>
  <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">max_context_len</span><span class="p">):</span>
    <span class="n">req_to_token</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">max_context_len</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">int32</span><span class="p">)</span>
    <span class="c1"># record free slots
</span>    <span class="n">free_slots</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">size</span><span class="p">))</span>

  <span class="k">def</span> <span class="nf">write</span><span class="p">(</span><span class="n">indices</span><span class="p">,</span> <span class="n">values</span><span class="p">):</span>
    <span class="n">req_to_token</span><span class="p">[</span><span class="n">indices</span><span class="p">]</span> <span class="o">=</span> <span class="n">values</span>

  <span class="k">def</span> <span class="nf">available_size</span><span class="p">():</span>
    <span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="n">free_slots</span><span class="p">)</span>

  <span class="k">def</span> <span class="nf">alloc</span><span class="p">(</span><span class="n">need_size</span><span class="p">):</span>
    <span class="k">if</span> <span class="n">need_size</span> <span class="o">&gt;</span> <span class="nb">len</span><span class="p">(</span><span class="n">free_slots</span><span class="p">):</span> <span class="k">return</span> <span class="bp">None</span>
    <span class="c1"># directly remove `need_size` slots
</span>    <span class="n">select_index</span> <span class="o">=</span> <span class="n">free_slots</span><span class="p">[:</span><span class="n">need_size</span><span class="p">]</span>
    <span class="n">free_slots</span> <span class="o">=</span> <span class="n">free_slots</span><span class="p">[</span><span class="n">need_size</span><span class="p">:]</span>
    <span class="k">return</span> <span class="n">select_index</span>

  <span class="k">def</span> <span class="nf">free</span><span class="p">(</span><span class="n">free_index</span><span class="p">):</span>
    <span class="n">free_slots</span><span class="p">.</span><span class="n">extend</span><span class="p">(</span><span class="n">free_index</span><span class="p">)</span>

  <span class="k">def</span> <span class="nf">clear</span><span class="p">():</span>
    <span class="n">free_slots</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">size</span><span class="p">))</span>
</code></pre></div></div>
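<p>To see the free-slot bookkeeping in action, here is a minimal pure-Python stand-in (plain lists instead of torch tensors; <code class="language-plaintext highlighter-rouge">MiniReqToTokenPool</code> is a sketch for illustration, not the real class):</p>

```python
class MiniReqToTokenPool:
    """Pure-Python sketch of ReqToTokenPool: maps (req slot, token position) -> out_cache_loc."""

    def __init__(self, size, max_context_len):
        # dim0: request slot, dim1: token position
        self.req_to_token = [[0] * max_context_len for _ in range(size)]
        self.free_slots = list(range(size))

    def alloc(self, need_size):
        if need_size > len(self.free_slots):
            return None  # not enough request slots
        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return select_index

    def write(self, req_idx, positions, out_cache_locs):
        for pos, loc in zip(positions, out_cache_locs):
            self.req_to_token[req_idx][pos] = loc

    def free(self, free_index):
        self.free_slots.extend(free_index)


pool = MiniReqToTokenPool(size=4, max_context_len=8)
reqs = pool.alloc(2)                          # two new requests arrive
pool.write(reqs[0], [0, 1, 2], [10, 11, 12])  # first request's KV slot indices
print(reqs)                                   # [0, 1]
print(pool.req_to_token[0][:3])               # [10, 11, 12]
pool.free(reqs)                               # requests finish; slots are recycled
print(len(pool.free_slots))                   # 4
```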

<h3 id="token_to_kv_pool">token_to_kv_pool</h3>

<p>A pool that maps <code class="language-plaintext highlighter-rouge">out_cache_loc</code> from <code class="language-plaintext highlighter-rouge">req_to_token_pool</code> to its real KV Cache data.</p>

<p>It mainly maintains the <code class="language-plaintext highlighter-rouge">k_buffer</code> and <code class="language-plaintext highlighter-rouge">v_buffer</code>, which have the same shape.</p>

<p>Shape: a list of <code class="language-plaintext highlighter-rouge">layer_num</code> tensors, where each <code class="language-plaintext highlighter-rouge">Tensor</code> has shape <code class="language-plaintext highlighter-rouge">(max_total_num_tokens + page_size, head_num, head_dim)</code></p>

<p>Access:</p>

<ul>
  <li>dim0: <code class="language-plaintext highlighter-rouge">layer_id</code> identify the specific layer</li>
  <li>dim1: <code class="language-plaintext highlighter-rouge">out_cache_loc</code> identify the specific KV cache indices</li>
  <li>dim2: <code class="language-plaintext highlighter-rouge">head</code></li>
  <li>dim3: <code class="language-plaintext highlighter-rouge">head_dim</code></li>
  <li>value: real KV Cache data</li>
</ul>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">MHATokenToKVPool</span><span class="p">(</span><span class="n">KVCache</span><span class="p">):</span>
  <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">page_size</span><span class="p">,</span> <span class="n">dtype</span><span class="p">,</span> <span class="n">head_num</span><span class="p">,</span> <span class="n">head_dim</span><span class="p">,</span> <span class="n">layer_num</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">start_layer</span><span class="p">...):</span>
    <span class="c1"># create real KV Cache buffers
</span>    <span class="n">_create_buffers</span><span class="p">()</span>
    <span class="c1">############
</span>    <span class="c1"># Now most of each GPU's memory has been allocated
</span>    <span class="c1">###########
</span>
  <span class="k">def</span> <span class="nf">_create_buffers</span><span class="p">():</span>
    <span class="n">k_buffer</span> <span class="o">=</span> <span class="p">[</span>
      <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span>
        <span class="p">(</span><span class="n">size</span> <span class="o">+</span> <span class="n">page_size</span><span class="p">,</span> <span class="n">head_num</span><span class="p">,</span> <span class="n">head_dim</span><span class="p">),</span>
        <span class="n">kv_cache_dtype</span><span class="p">,</span>
        <span class="n">device</span><span class="p">,</span>
      <span class="p">)</span>
      <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">layer_num</span><span class="p">)</span>
    <span class="p">]</span>
    <span class="n">v_buffer</span> <span class="o">=</span> <span class="p">[</span>
      <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span>
        <span class="p">(</span><span class="n">size</span> <span class="o">+</span> <span class="n">page_size</span><span class="p">,</span> <span class="n">head_num</span><span class="p">,</span> <span class="n">head_dim</span><span class="p">),</span>
        <span class="n">kv_cache_dtype</span><span class="p">,</span>
        <span class="n">device</span><span class="p">,</span>
      <span class="p">)</span>
      <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">layer_num</span><span class="p">)</span>
    <span class="p">]</span>

  <span class="k">def</span> <span class="nf">_clear_buffers</span><span class="p">():</span>
    <span class="k">del</span> <span class="n">k_buffer</span><span class="p">,</span> <span class="n">v_buffer</span>

  <span class="c1">################
</span>  <span class="c1">## READ API
</span>  <span class="c1">################
</span>  <span class="k">def</span> <span class="nf">get_key_buffer</span><span class="p">(</span><span class="n">layer_id</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">k_buffer</span><span class="p">[</span><span class="n">layer_id</span> <span class="o">-</span> <span class="n">start_layer</span><span class="p">]</span>

  <span class="k">def</span> <span class="nf">get_value_buffer</span><span class="p">(</span><span class="n">layer_id</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">v_buffer</span><span class="p">[</span><span class="n">layer_id</span> <span class="o">-</span> <span class="n">start_layer</span><span class="p">]</span>

  <span class="k">def</span> <span class="nf">get_kv_buffer</span><span class="p">(</span><span class="n">layer_id</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">get_key_buffer</span><span class="p">(</span><span class="n">layer_id</span><span class="p">),</span> <span class="n">get_value_buffer</span><span class="p">(</span><span class="n">layer_id</span><span class="p">)</span>

  <span class="c1">############
</span>  <span class="c1">## WRITE API
</span>  <span class="c1">############
</span>  <span class="k">def</span> <span class="nf">set_kv_buffer</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">loc</span><span class="p">,</span> <span class="n">cache_k</span><span class="p">,</span> <span class="n">cache_v</span><span class="p">,</span> <span class="p">...):</span>
    <span class="n">layer_id</span> <span class="o">=</span> <span class="n">layer</span><span class="p">.</span><span class="n">layer_id</span>
    <span class="n">k_buffer</span><span class="p">[</span><span class="n">layer_id</span> <span class="o">-</span> <span class="n">start_layer</span><span class="p">][</span><span class="n">loc</span><span class="p">]</span> <span class="o">=</span> <span class="n">cache_k</span>
    <span class="n">v_buffer</span><span class="p">[</span><span class="n">layer_id</span> <span class="o">-</span> <span class="n">start_layer</span><span class="p">][</span><span class="n">loc</span><span class="p">]</span> <span class="o">=</span> <span class="n">cache_v</span>
</code></pre></div></div>
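<p>Putting the two pools together, the lookup chain is: <code class="language-plaintext highlighter-rouge">req_to_token_pool[req_idx][pos]</code> yields <code class="language-plaintext highlighter-rouge">out_cache_loc</code>, and <code class="language-plaintext highlighter-rouge">k_buffer[layer][out_cache_loc]</code> yields the actual K vector. A toy sketch with plain Python lists (all shapes and values are illustrative only):</p>

```python
# Toy dimensions: 2 layers, 8 KV slots, 1 head, head_dim 4 (illustrative)
layer_num, pool_size, head_num, head_dim = 2, 8, 1, 4

# token_to_kv_pool: one buffer per layer, indexed by out_cache_loc
k_buffer = [
    [[[0.0] * head_dim for _ in range(head_num)] for _ in range(pool_size)]
    for _ in range(layer_num)
]

# req_to_token_pool row for request 0: token position -> out_cache_loc
req_to_token = {0: 5}  # the token at position 0 lives at KV slot 5

# analog of set_kv_buffer: write this token's K vector for layer 0
loc = req_to_token[0]
k_buffer[0][loc][0] = [1.0, 2.0, 3.0, 4.0]

# analog of get_key_buffer + indexing: read it back during attention
print(k_buffer[0][loc][0])  # [1.0, 2.0, 3.0, 4.0]
```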

<h3 id="token_to_kv_pool_allocator">token_to_kv_pool_allocator</h3>

<p>Used to allocate and free real KV Cache slots; it hands out <code class="language-plaintext highlighter-rouge">out_cache_loc</code> indices.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">TokenToKVPoolAllocator</span><span class="p">:</span>
  <span class="c1"># size = max_total_num_tokens, kvcache = token_to_kv_pool
</span>  <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">dtype</span><span class="p">,</span> <span class="n">page_size</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">kvcache</span><span class="p">):</span>
    <span class="n">page_size</span> <span class="o">=</span> <span class="mi">1</span>
    <span class="n">clear</span><span class="p">()</span>

  <span class="k">def</span> <span class="nf">clear</span><span class="p">():</span>
    <span class="c1"># allocation starts from index 1, so slot 0 is never handed out
</span>    <span class="n">free_slots</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">size</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">int64</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>

  <span class="k">def</span> <span class="nf">available_size</span><span class="p">():</span>
    <span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="n">free_slots</span><span class="p">)</span>

  <span class="c1">##########################
</span>  <span class="c1"># ALLOCATE API
</span>  <span class="c1">##########################
</span>  <span class="k">def</span> <span class="nf">alloc</span><span class="p">(</span><span class="n">need_size</span><span class="p">):</span>
    <span class="k">if</span> <span class="n">need_size</span> <span class="o">&gt;</span> <span class="nb">len</span><span class="p">(</span><span class="n">free_slots</span><span class="p">):</span> <span class="k">return</span> <span class="bp">None</span>
    <span class="n">select_index</span> <span class="o">=</span> <span class="n">free_slots</span><span class="p">[:</span><span class="n">need_size</span><span class="p">]</span>
    <span class="n">free_slots</span> <span class="o">=</span> <span class="n">free_slots</span><span class="p">[</span><span class="n">need_size</span><span class="p">:]</span>
    <span class="k">return</span> <span class="n">select_index</span>

  <span class="c1">##########################
</span>  <span class="c1">## FREE API
</span>  <span class="c1">##########################
</span>  <span class="k">def</span> <span class="nf">free</span><span class="p">(</span><span class="n">free_index</span><span class="p">):</span>
    <span class="n">free_slots</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">((</span><span class="n">free_slots</span><span class="p">,</span> <span class="n">free_index</span><span class="p">))</span>
</code></pre></div></div>
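<p>One detail worth noting: <code class="language-plaintext highlighter-rouge">clear()</code> initializes <code class="language-plaintext highlighter-rouge">free_slots</code> with <code class="language-plaintext highlighter-rouge">torch.arange(1, size + 1)</code>, so index 0 is never handed out (presumably reserved as a null/padding location; that interpretation is my assumption). A pure-Python sketch of the same behavior:</p>

```python
size = 6
# mirrors torch.arange(1, size + 1): slot 0 is never handed out
free_slots = list(range(1, size + 1))

def alloc(need_size):
    global free_slots
    if need_size > len(free_slots):
        return None
    select_index = free_slots[:need_size]
    free_slots = free_slots[need_size:]
    return select_index

def free(free_index):
    global free_slots
    free_slots = free_slots + list(free_index)  # analog of torch.cat

locs = alloc(3)
print(locs)             # [1, 2, 3] -- allocation starts at 1, not 0
free(locs)
print(len(free_slots))  # 6
```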

<h2 id="allocate-slots-to-reqs--out_cache_loc">Allocate Slots to Reqs &amp; out_cache_loc</h2>

<p>This raises the question: how does <code class="language-plaintext highlighter-rouge">SGLang</code> use the above managers to efficiently allocate slots for each token in the requests and free them in a timely manner?</p>

<p>LLM inference consists of two main stages. We start by identifying the allocation requirements for each stage.</p>

<ol>
  <li>prefill:
    <ol>
      <li><code class="language-plaintext highlighter-rouge">req_to_token_pool.alloc</code> : since we have new reqs</li>
      <li><code class="language-plaintext highlighter-rouge">token_to_kv_pool_allocator.alloc</code>: maybe needed, depending on whether the KV cache already exists:
        <ol>
          <li>if the tokens’ <code class="language-plaintext highlighter-rouge">kv cache</code> already exists, we can just use <code class="language-plaintext highlighter-rouge">req_to_token_pool.write</code> to reuse it</li>
          <li>if we don’t have the <code class="language-plaintext highlighter-rouge">kv cache</code>, then get <code class="language-plaintext highlighter-rouge">out_cache_loc</code> by calling <code class="language-plaintext highlighter-rouge">token_to_kv_pool_allocator.alloc</code>, then write <code class="language-plaintext highlighter-rouge">out_cache_loc</code> into <code class="language-plaintext highlighter-rouge">req_to_token_pool</code></li>
        </ol>
      </li>
    </ol>
  </li>
  <li>decode:
    <ol>
      <li><code class="language-plaintext highlighter-rouge">req_to_token_pool.alloc</code> : not needed, since each req already has its slot</li>
      <li><code class="language-plaintext highlighter-rouge">token_to_kv_pool_allocator.alloc</code> : needed, since we decode one new token at a time</li>
    </ol>
  </li>
</ol>
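<p>To make the two stages concrete, here is a small sketch with made-up lengths of how many new KV slots each stage must allocate:</p>

```python
# prefill: one KV slot per uncached input token
seq_lens = [10, 7]     # len(fill_ids) per request
prefix_lens = [4, 0]   # tokens whose KV cache is already in the pool
extend_num_tokens = sum(s - p for s, p in zip(seq_lens, prefix_lens))
assert extend_num_tokens == 13   # 6 + 7 uncached tokens

# decode: exactly one new KV slot per request in the batch
bs = len(seq_lens)
decode_num_tokens = bs
assert decode_num_tokens == 2
```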

<p>So in <code class="language-plaintext highlighter-rouge">scheduler.get_next_batch_to_run</code>, which produces the <code class="language-plaintext highlighter-rouge">ScheduleBatch</code>, each stage has its own preparation logic, and that is where slots are allocated and freed.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">ScheduleBatch</span><span class="p">:</span>
    <span class="s">"""Store all information of a batch on the scheduler."""</span>

  <span class="k">def</span> <span class="nf">prepare_for_extend</span><span class="p">():</span>
    <span class="n">bs</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">reqs</span><span class="p">)</span>
    <span class="n">req_pool_indices</span> <span class="o">=</span> <span class="n">alloc_req_slots</span><span class="p">(</span><span class="n">bs</span><span class="p">)</span>

    <span class="c1"># fill_ids = origin_input_ids + output_ids
</span>    <span class="c1"># input_ids are those token_ids whose KV Cache needs computing
</span>    <span class="n">input_ids</span> <span class="o">=</span> <span class="p">[</span><span class="n">r</span><span class="p">.</span><span class="n">fill_ids</span><span class="p">[</span><span class="nb">len</span><span class="p">(</span><span class="n">r</span><span class="p">.</span><span class="n">prefix_indices</span><span class="p">):</span> <span class="p">]</span> <span class="k">for</span> <span class="n">r</span> <span class="ow">in</span> <span class="n">reqs</span><span class="p">]</span>

    <span class="c1"># this is the num tokens we need allocate slots to accommodate
</span>    <span class="n">extend_num_tokens</span> <span class="o">=</span> <span class="nb">sum</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">ids</span><span class="p">)</span> <span class="k">for</span> <span class="n">ids</span> <span class="ow">in</span> <span class="n">input_ids</span><span class="p">)</span>

    <span class="n">seq_lens</span> <span class="o">=</span> <span class="p">[</span><span class="nb">len</span><span class="p">(</span><span class="n">r</span><span class="p">.</span><span class="n">fill_ids</span><span class="p">)</span> <span class="k">for</span> <span class="n">r</span> <span class="ow">in</span> <span class="n">reqs</span><span class="p">]</span>
    <span class="n">prefix_lens</span> <span class="o">=</span> <span class="p">[</span><span class="nb">len</span><span class="p">(</span><span class="n">r</span><span class="p">.</span><span class="n">prefix_indices</span><span class="p">)</span> <span class="k">for</span> <span class="n">r</span> <span class="ow">in</span> <span class="n">reqs</span><span class="p">]</span>

    <span class="c1"># extend_lens is actually equal to `seq_lens - prefix_lens`
</span>    <span class="n">extend_lens</span> <span class="o">=</span> <span class="p">[</span><span class="n">r</span><span class="p">.</span><span class="n">extend_input_len</span> <span class="k">for</span> <span class="n">r</span> <span class="ow">in</span> <span class="n">reqs</span><span class="p">]</span>

    <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="p">(</span><span class="n">req</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">pre_len</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">reqs</span><span class="p">,</span> <span class="n">seq_lens</span><span class="p">,</span> <span class="n">prefix_lens</span><span class="p">)):</span>
      <span class="n">req</span><span class="p">.</span><span class="n">req_pool_idx</span> <span class="o">=</span> <span class="n">req_pool_indices</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>

      <span class="c1"># here assert again
</span>      <span class="k">assert</span> <span class="n">seq_len</span> <span class="o">-</span> <span class="n">pre_len</span> <span class="o">==</span> <span class="n">req</span><span class="p">.</span><span class="n">extend_input_len</span>

      <span class="k">if</span> <span class="n">pre_len</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
        <span class="c1"># write cached `out_cache_loc` into `req_to_token_pool`
</span>        <span class="n">req_to_token_pool</span><span class="p">.</span><span class="n">write</span><span class="p">(</span>
                    <span class="p">(</span><span class="n">req</span><span class="p">.</span><span class="n">req_pool_idx</span><span class="p">,</span> <span class="nb">slice</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">pre_len</span><span class="p">)),</span> <span class="n">req</span><span class="p">.</span><span class="n">prefix_indices</span>
                <span class="p">)</span>

       <span class="n">out_cache_loc</span> <span class="o">=</span> <span class="n">alloc_token_slots</span><span class="p">(</span><span class="n">extend_num_tokens</span><span class="p">)</span>

       <span class="n">pt</span> <span class="o">=</span> <span class="mi">0</span>
       <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">bs</span><span class="p">):</span>
         <span class="c1"># write uncached `out_cache_loc` into `req_to_token_pool`
</span>         <span class="n">req_to_token_pool</span><span class="p">.</span><span class="n">write</span><span class="p">(</span>
             <span class="p">(</span><span class="n">req_pool_indices</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="nb">slice</span><span class="p">(</span><span class="n">prefix_lens</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">seq_lens</span><span class="p">[</span><span class="n">i</span><span class="p">])),</span>
             <span class="n">out_cache_loc</span><span class="p">[</span><span class="n">pt</span> <span class="p">:</span> <span class="n">pt</span> <span class="o">+</span> <span class="n">extend_lens</span><span class="p">[</span><span class="n">i</span><span class="p">]],</span>
         <span class="p">)</span>
         <span class="n">pt</span> <span class="o">+=</span> <span class="n">extend_lens</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
       <span class="p">...</span> <span class="n">END</span> <span class="err">!!!</span>

  <span class="k">def</span> <span class="nf">prepare_for_decode</span><span class="p">():</span>
    <span class="n">bs</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">reqs</span><span class="p">)</span>

    <span class="c1"># allocate `bs` tokens
</span>    <span class="n">out_cache_loc</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">alloc_token_slots</span><span class="p">(</span><span class="n">bs</span><span class="p">)</span>

    <span class="c1"># compute `req_to_token_pool` locs
</span>    <span class="n">locs</span> <span class="o">=</span> <span class="n">seq_lens</span> <span class="o">+</span> <span class="mi">1</span>

    <span class="c1"># write
</span>    <span class="n">req_to_token_pool</span><span class="p">.</span><span class="n">write</span><span class="p">(</span>
            <span class="p">(</span><span class="n">req_pool_indices</span><span class="p">,</span> <span class="n">locs</span><span class="p">),</span> <span class="n">out_cache_loc</span><span class="p">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">int32</span><span class="p">)</span>
        <span class="p">)</span>
       <span class="p">...</span> <span class="n">END</span> <span class="err">!!!</span>

  <span class="k">def</span> <span class="nf">alloc_req_slots</span><span class="p">(</span><span class="n">num_reqs</span><span class="p">):</span>
    <span class="n">req_pool_indices</span> <span class="o">=</span> <span class="n">req_to_token_pool</span><span class="p">.</span><span class="n">alloc</span><span class="p">(</span><span class="n">num_reqs</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">req_pool_indices</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">:</span> <span class="k">raise</span> <span class="nb">RuntimeError</span><span class="p">(</span><span class="s">""</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">req_pool_indices</span>

  <span class="k">def</span> <span class="nf">alloc_token_slots</span><span class="p">(</span><span class="n">num_tokens</span><span class="p">):</span>
    <span class="n">out_cache_loc</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">token_to_kv_pool_allocator</span><span class="p">.</span><span class="n">alloc</span><span class="p">(</span><span class="n">num_tokens</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">out_cache_loc</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">:</span> <span class="k">raise</span> <span class="nb">RuntimeError</span><span class="p">()</span>
    <span class="k">return</span> <span class="n">out_cache_loc</span>

</code></pre></div></div>
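<p>The <code class="language-plaintext highlighter-rouge">pt</code> pointer in <code class="language-plaintext highlighter-rouge">prepare_for_extend</code> simply walks the flat <code class="language-plaintext highlighter-rouge">out_cache_loc</code> and hands each request one contiguous slice. A sketch of that write pattern with plain lists (hypothetical slot values):</p>

```python
out_cache_loc = [100, 101, 102, 103, 104]  # 5 freshly allocated KV slots
extend_lens = [3, 2]                       # uncached tokens per request

page_table_rows = {}  # stands in for the req_to_token_pool rows
pt = 0
for i, ext_len in enumerate(extend_lens):
    # request i gets the next `ext_len` slots of the flat allocation
    page_table_rows[i] = out_cache_loc[pt : pt + ext_len]
    pt += ext_len

assert page_table_rows == {0: [100, 101, 102], 1: [103, 104]}
```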

<h2 id="read--save-real-kv-cache-data-when-computing-attention-scores">Read &amp; Save Real KV Cache Data when computing Attention Scores</h2>

<p>In the model forward, <code class="language-plaintext highlighter-rouge">model_runner</code> will call <code class="language-plaintext highlighter-rouge">attention_backend.init_forward_metadata</code> to initialize the metadata for the attention backend, and then call the actual <code class="language-plaintext highlighter-rouge">forward_extend</code> or <code class="language-plaintext highlighter-rouge">forward_decode</code>.</p>

<p>During <code class="language-plaintext highlighter-rouge">init_forward_metadata</code>, by using <code class="language-plaintext highlighter-rouge">req_to_token_pool.req_to_token</code>, we get the <code class="language-plaintext highlighter-rouge">page table</code>, which is then used in each layer’s attention score computation.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">FlashAttentionBackend</span><span class="p">(</span><span class="n">AttentionBackend</span><span class="p">):</span>
  <span class="k">def</span> <span class="nf">init_forward_metadata</span><span class="p">(</span><span class="n">forward_batch</span><span class="p">):</span>
    <span class="n">metadata</span> <span class="o">=</span> <span class="n">FlashAttentionMetadata</span><span class="p">()</span>
    <span class="k">if</span> <span class="n">forward_batch</span><span class="p">.</span><span class="n">is_decode</span><span class="p">():</span>
      <span class="n">metadata</span><span class="p">.</span><span class="n">max_seq_len_k</span> <span class="o">=</span> <span class="n">forward_batch</span><span class="p">.</span><span class="n">seq_lens_cpu</span><span class="p">.</span><span class="nb">max</span><span class="p">().</span><span class="n">item</span><span class="p">()</span>
      <span class="c1"># get the page table!
</span>      <span class="n">metadata</span><span class="p">.</span><span class="n">page_table</span> <span class="o">=</span> <span class="n">forward_batch</span><span class="p">.</span><span class="n">req_to_token_pool</span><span class="p">.</span><span class="n">req_to_token</span><span class="p">[</span>
                 <span class="n">forward_batch</span><span class="p">.</span><span class="n">req_pool_indices</span><span class="p">,</span> <span class="p">:</span> <span class="n">metadata</span><span class="p">.</span><span class="n">max_seq_len_k</span>
             <span class="p">]</span>
     <span class="k">elif</span> <span class="n">forward_batch</span><span class="p">.</span><span class="n">is_extend</span><span class="p">():</span>
       <span class="c1"># ... nearly same ...
</span></code></pre></div></div>
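<p>The page-table lookup above is just a row-and-column slice of the <code class="language-plaintext highlighter-rouge">req_to_token</code> matrix. A sketch of the same indexing with nested lists (hypothetical contents):</p>

```python
# req_to_token: one row per request slot, one column per token position
req_to_token = [
    [100, 101, 102, 0, 0],    # request slot 0: 3 cached tokens
    [200, 201, 0, 0, 0],      # request slot 1: 2 cached tokens
    [300, 301, 302, 303, 0],  # request slot 2: 4 cached tokens
]
req_pool_indices = [2, 0]  # the requests in this batch
max_seq_len_k = 4

# equivalent of req_to_token[req_pool_indices, :max_seq_len_k]
page_table = [req_to_token[r][:max_seq_len_k] for r in req_pool_indices]
assert page_table == [[300, 301, 302, 303], [100, 101, 102, 0]]
```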

<p>The <code class="language-plaintext highlighter-rouge">save &amp; retrieve</code> process takes place during the model forward, where <code class="language-plaintext highlighter-rouge">attention_backend.forward_extend</code> or <code class="language-plaintext highlighter-rouge">attention_backend.forward_decode</code> is called.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">FlashAttention</span><span class="p">(</span><span class="n">AttentionBackend</span><span class="p">):</span>
  <span class="k">def</span> <span class="nf">forward_extend</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">layer</span><span class="p">,</span> <span class="n">forward_batch</span><span class="p">,</span> <span class="n">save_kv_cache</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="p">...):</span>
    <span class="k">if</span> <span class="n">k</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">:</span>
      <span class="k">if</span> <span class="n">v</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">:</span>
        <span class="n">cache_loc</span> <span class="o">=</span> <span class="n">forward_batch</span><span class="p">.</span><span class="n">out_cache_loc</span>

        <span class="c1"># !!! Save the KV Cache into token_to_kv_pool !!!
</span>        <span class="n">forward_batch</span><span class="p">.</span><span class="n">token_to_kv_pool</span><span class="p">.</span><span class="n">set_kv_buffer</span><span class="p">(</span>
                        <span class="n">layer</span><span class="p">,</span> <span class="n">cache_loc</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="p">...</span>
                    <span class="p">)</span>
       <span class="c1"># Use precomputed metadata across all layers
</span>        <span class="c1"># prepare metadata for the FlashAttention operator
</span>        <span class="n">metadata</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">forward_metadata</span>
        <span class="n">page_table</span> <span class="o">=</span> <span class="n">metadata</span><span class="p">.</span><span class="n">page_table</span>
        <span class="n">cu_seqlens_q</span> <span class="o">=</span> <span class="n">metadata</span><span class="p">.</span><span class="n">cu_seqlens_q</span>
        <span class="n">cache_seqlens</span> <span class="o">=</span> <span class="n">metadata</span><span class="p">.</span><span class="n">cache_seqlens_int32</span>
        <span class="n">max_seqlen_q</span> <span class="o">=</span> <span class="n">metadata</span><span class="p">.</span><span class="n">max_seq_len_q</span>
        <span class="n">max_seqlen_k</span> <span class="o">=</span> <span class="n">metadata</span><span class="p">.</span><span class="n">max_seq_len_k</span>
        <span class="n">cu_seqlens_k</span> <span class="o">=</span> <span class="n">metadata</span><span class="p">.</span><span class="n">cu_seqlens_k</span>

        <span class="c1"># !!! Retrieve the KV Cache from token_to_kv_pool !!!
</span>        <span class="n">key_cache</span><span class="p">,</span> <span class="n">value_cache</span> <span class="o">=</span> <span class="n">forward_batch</span><span class="p">.</span><span class="n">token_to_kv_pool</span><span class="p">.</span><span class="n">get_kv_buffer</span><span class="p">(</span>
                <span class="n">layer</span><span class="p">.</span><span class="n">layer_id</span>
            <span class="p">)</span>
        <span class="c1"># review the format
</span>        <span class="n">key_cache</span> <span class="o">=</span> <span class="n">key_cache</span><span class="p">.</span><span class="n">view</span><span class="p">(</span>
                <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">page_size</span><span class="p">,</span> <span class="n">layer</span><span class="p">.</span><span class="n">tp_k_head_num</span><span class="p">,</span> <span class="n">layer</span><span class="p">.</span><span class="n">head_dim</span>
            <span class="p">)</span>
        <span class="n">value_cache</span> <span class="o">=</span> <span class="n">value_cache</span><span class="p">.</span><span class="n">view</span><span class="p">(</span>
                <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">page_size</span><span class="p">,</span> <span class="n">layer</span><span class="p">.</span><span class="n">tp_v_head_num</span><span class="p">,</span> <span class="n">layer</span><span class="p">.</span><span class="n">head_dim</span>
            <span class="p">)</span>

        <span class="n">o</span> <span class="o">=</span> <span class="n">flash_attn_with_kvcache</span><span class="p">(</span>
          <span class="n">q</span><span class="o">=</span><span class="n">q</span><span class="p">.</span><span class="n">contiguous</span><span class="p">().</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">layer</span><span class="p">.</span><span class="n">tp_q_head_num</span><span class="p">,</span> <span class="n">layer</span><span class="p">.</span><span class="n">head_dim</span><span class="p">),</span>
          <span class="n">key_cache</span><span class="p">,</span>
          <span class="n">value_cache</span><span class="p">,</span>
          <span class="n">page_table</span><span class="p">,</span>
          <span class="p">...</span>
       <span class="p">)</span>

       <span class="k">return</span> <span class="n">o</span><span class="p">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">layer</span><span class="p">.</span><span class="n">tp_q_head_num</span> <span class="o">*</span> <span class="n">layer</span><span class="p">.</span><span class="n">v_head_dim</span><span class="p">)</span>

  <span class="k">def</span> <span class="nf">forward_decode</span><span class="p">(</span><span class="n">forward_batch</span><span class="p">):</span>
    <span class="c1"># ... nearly same to forward_extend ...
</span></code></pre></div></div>

<p>That concludes the first section, <code class="language-plaintext highlighter-rouge">KV Cache Management</code>. We covered:</p>

<ol>
  <li>How the <code class="language-plaintext highlighter-rouge">KV Cache</code> is initialized: just create a list of huge tensors</li>
  <li>How the <code class="language-plaintext highlighter-rouge">KV Cache</code> is managed (allocating <code class="language-plaintext highlighter-rouge">slots</code> for the <code class="language-plaintext highlighter-rouge">tokens</code> of reqs)</li>
  <li>How the real <code class="language-plaintext highlighter-rouge">KV Cache data</code> is saved and retrieved when computing attention scores</li>
</ol>

<h1 id="radix-tree-cache">Radix Tree Cache</h1>

<p>One novel idea of <code class="language-plaintext highlighter-rouge">SGLang</code> is <code class="language-plaintext highlighter-rouge">Radix Attention</code> , which uses <code class="language-plaintext highlighter-rouge">radix tree</code> to reuse <code class="language-plaintext highlighter-rouge">KV Cache</code> as much as possible.</p>

<p>So, what is <code class="language-plaintext highlighter-rouge">Radix Tree</code>?</p>

<p>Its core idea is to obtain reusable <code class="language-plaintext highlighter-rouge">out_cache_loc</code> based on <code class="language-plaintext highlighter-rouge">token_ids</code>, which serve as the keys of the tree nodes. For requests sharing the same prefix <code class="language-plaintext highlighter-rouge">token_ids</code>, we search the tree to get their <code class="language-plaintext highlighter-rouge">out_cache_loc</code>; by doing so, <strong>one</strong> <code class="language-plaintext highlighter-rouge">out_cache_loc</code> can serve <strong>two or more</strong> requests.</p>
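<p>A toy illustration of the payoff (hypothetical token ids and slot numbers), where two requests sharing a prefix resolve to the same KV slots:</p>

```python
# one cached entry: prefix token_ids -> out_cache_loc
cached_key = [1, 5, 7, 9]      # token_ids of a cached prompt
cached_loc = [40, 41, 42, 43]  # KV slots already holding that prefix

def match_prefix(key):
    # length of the common prefix with the cached entry
    n = 0
    while n < min(len(key), len(cached_key)) and key[n] == cached_key[n]:
        n += 1
    return cached_loc[:n]

req_a = [1, 5, 7, 9, 2]  # shares the full 4-token prefix
req_b = [1, 5, 8]        # shares only the first 2 tokens

assert match_prefix(req_a) == [40, 41, 42, 43]
assert match_prefix(req_b) == [40, 41]
```

<p>Both requests then only need new KV slots for their uncached suffix tokens.</p>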

<h2 id="radix-tree">Radix Tree</h2>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">TreeNode</span><span class="p">:</span>

    <span class="n">counter</span> <span class="o">=</span> <span class="mi">0</span>

    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="nb">id</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="bp">None</span><span class="p">):</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">children</span> <span class="o">=</span> <span class="n">defaultdict</span><span class="p">(</span><span class="n">TreeNode</span><span class="p">)</span> <span class="c1"># use 1 page-size key as the dict_key
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">parent</span> <span class="o">=</span> <span class="bp">None</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">key</span> <span class="o">=</span> <span class="bp">None</span> <span class="c1"># Key is the `token_ids`
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">value</span> <span class="o">=</span> <span class="bp">None</span> <span class="c1"># Value is the `out_cache_loc`, which records the location of real KV Cache data
</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">lock_ref</span> <span class="o">=</span> <span class="mi">0</span> <span class="c1"># how many reqs reference this node
</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">last_access_time</span> <span class="o">=</span> <span class="n">time</span><span class="p">.</span><span class="n">monotonic</span><span class="p">()</span>

        <span class="bp">self</span><span class="p">.</span><span class="n">hit_count</span> <span class="o">=</span> <span class="mi">0</span>

        <span class="c1"># indicating the node is loading KV cache from host
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">loading</span> <span class="o">=</span> <span class="bp">False</span>

        <span class="c1"># store the host indices of KV cache
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">host_value</span> <span class="o">=</span> <span class="bp">None</span>

        <span class="bp">self</span><span class="p">.</span><span class="nb">id</span> <span class="o">=</span> <span class="n">TreeNode</span><span class="p">.</span><span class="n">counter</span> <span class="k">if</span> <span class="nb">id</span> <span class="ow">is</span> <span class="bp">None</span> <span class="k">else</span> <span class="nb">id</span>
        <span class="n">TreeNode</span><span class="p">.</span><span class="n">counter</span> <span class="o">+=</span> <span class="mi">1</span>
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">RadixTree</span><span class="p">(</span><span class="n">BasePrefixCache</span><span class="p">):</span>
  <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="n">req_to_token_pool</span><span class="p">,</span> <span class="n">token_to_kv_pool_allocator</span><span class="p">,</span> <span class="n">page_size</span><span class="p">,</span> <span class="p">...):</span>
    <span class="k">if</span> <span class="n">page_size</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
      <span class="c1"># key_match_fn: given 2 keys, return the length of their common prefix
</span>            <span class="n">key_match_fn</span> <span class="o">=</span> <span class="n">_key_match_page_size1</span>

            <span class="c1"># get_child_key_fn: get 1-page-size key
</span>            <span class="n">get_child_key_fn</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">key</span><span class="p">:</span> <span class="n">key</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">key_match_fn</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">_key_match_paged</span><span class="p">,</span> <span class="n">page_size</span><span class="o">=</span><span class="n">page_size</span><span class="p">)</span>
            <span class="n">get_child_key_fn</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">key</span><span class="p">:</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">key</span><span class="p">[:</span><span class="n">page_size</span><span class="p">])</span>
    <span class="n">reset</span><span class="p">()</span>

  <span class="k">def</span> <span class="nf">reset</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">root_node</span> <span class="o">=</span> <span class="n">TreeNode</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">root_node</span><span class="p">.</span><span class="n">key</span> <span class="o">=</span> <span class="p">[]</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">root_node</span><span class="p">.</span><span class="n">value</span> <span class="o">=</span> <span class="p">[]</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">root_node</span><span class="p">.</span><span class="n">lock_ref</span> <span class="o">=</span> <span class="mi">1</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">evictable_size_</span> <span class="o">=</span> <span class="mi">0</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">protected_size_</span> <span class="o">=</span> <span class="mi">0</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">_record_all_cleared_event</span><span class="p">()</span>
</code></pre></div></div>
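<p>When <code class="language-plaintext highlighter-rouge">page_size &gt; 1</code>, keys are compared one page at a time and matches are truncated to page boundaries. A sketch of that matching logic (a hypothetical helper mirroring the role of <code class="language-plaintext highlighter-rouge">_key_match_paged</code>):</p>

```python
def key_match_paged(key0, key1, page_size):
    # only compare up to a page-aligned length
    min_len = min(len(key0), len(key1)) // page_size * page_size
    i = 0
    # advance page by page; stop at the first page that differs
    while i < min_len and key0[i : i + page_size] == key1[i : i + page_size]:
        i += page_size
    return i

# child key for the children dict: the first full page of token ids
get_child_key = lambda key, page_size=2: tuple(key[:page_size])

assert key_match_paged([1, 2, 3, 4], [1, 2, 3, 5], page_size=2) == 2
assert get_child_key([1, 2, 3, 4]) == (1, 2)
```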

<h3 id="match">Match</h3>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>  <span class="c1">########################
</span>   <span class="c1"># Match Prefix
</span>   <span class="c1">########################
</span>   <span class="k">def</span> <span class="nf">match_prefix</span><span class="p">(</span><span class="n">key</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]):</span>
     <span class="n">page_aligned_len</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">key</span><span class="p">)</span> <span class="o">//</span> <span class="n">page_size</span> <span class="o">*</span> <span class="n">page_size</span>
       <span class="n">key</span> <span class="o">=</span> <span class="n">key</span><span class="p">[:</span><span class="n">page_aligned_len</span><span class="p">]</span>

       <span class="n">value</span><span class="p">,</span> <span class="n">last_node</span> <span class="o">=</span> <span class="n">_match_prefix_helper</span><span class="p">(</span><span class="n">root_node</span><span class="p">,</span> <span class="n">key</span><span class="p">)</span>
       <span class="k">if</span> <span class="n">value</span><span class="p">:</span> <span class="n">value</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">(</span><span class="n">value</span><span class="p">)</span>
       <span class="k">else</span><span class="p">:</span> <span class="n">value</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">empty</span><span class="p">((</span><span class="mi">0</span><span class="p">,),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">int64</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>

       <span class="c1"># 1. prefix `out_cache_loc` in the radix tree
</span>       <span class="c1"># 2. last_node
</span>      <span class="k">return</span> <span class="n">value</span><span class="p">,</span> <span class="n">last_node</span>

  <span class="k">def</span> <span class="nf">_match_prefix_helper</span><span class="p">(</span><span class="n">node</span><span class="p">,</span> <span class="n">key</span><span class="p">):</span>
    <span class="c1"># update time
</span>    <span class="n">node</span><span class="p">.</span><span class="n">last_access_time</span> <span class="o">=</span> <span class="n">time</span><span class="p">.</span><span class="n">monotonic</span><span class="p">()</span>

    <span class="c1"># get child key first
</span>    <span class="n">child_key</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">get_child_key_fn</span><span class="p">(</span><span class="n">key</span><span class="p">)</span>

    <span class="n">value</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="k">while</span> <span class="nb">len</span><span class="p">(</span><span class="n">key</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">child_key</span> <span class="ow">in</span> <span class="n">node</span><span class="p">.</span><span class="n">children</span><span class="p">.</span><span class="n">keys</span><span class="p">():</span>

      <span class="n">child</span> <span class="o">=</span> <span class="n">node</span><span class="p">.</span><span class="n">children</span><span class="p">[</span><span class="n">child_key</span><span class="p">]</span>

      <span class="c1"># update time
</span>      <span class="n">child</span><span class="p">.</span><span class="n">last_access_time</span> <span class="o">=</span> <span class="n">time</span><span class="p">.</span><span class="n">monotonic</span><span class="p">()</span>

      <span class="c1"># get the number of matched prefix ids (a multiple of page_size)
</span>      <span class="n">prefix_len</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">key_match_fn</span><span class="p">(</span><span class="n">child</span><span class="p">.</span><span class="n">key</span><span class="p">,</span> <span class="n">key</span><span class="p">)</span>

      <span class="k">if</span> <span class="n">prefix_len</span> <span class="o">&lt;</span> <span class="nb">len</span><span class="p">(</span><span class="n">child</span><span class="p">.</span><span class="n">key</span><span class="p">):</span>
        <span class="c1"># not a full match: split the child into a fully matched but shorter new_node
</span>
        <span class="c1"># NOTE: prefix_len is at least 1-page-size since `child_key in node.children.keys()`
</span>        <span class="n">new_node</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">_split_node</span><span class="p">(</span><span class="n">child</span><span class="p">.</span><span class="n">key</span><span class="p">,</span> <span class="n">child</span><span class="p">,</span> <span class="n">prefix_len</span><span class="p">)</span>

        <span class="c1"># append the matched value
</span>        <span class="n">value</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">new_node</span><span class="p">.</span><span class="n">value</span><span class="p">)</span>
        <span class="n">node</span> <span class="o">=</span> <span class="n">new_node</span>
        <span class="k">break</span>
      <span class="k">else</span><span class="p">:</span>
        <span class="c1"># full match, try to get next child
</span>
        <span class="c1"># save the value
</span>        <span class="n">value</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">child</span><span class="p">.</span><span class="n">value</span><span class="p">)</span>

        <span class="c1"># update the node
</span>        <span class="n">node</span> <span class="o">=</span> <span class="n">child</span>

        <span class="c1"># truncate the prefix-matched keys
</span>        <span class="n">key</span> <span class="o">=</span> <span class="n">key</span><span class="p">[</span><span class="n">prefix_len</span><span class="p">:]</span>

        <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">key</span><span class="p">):</span>
          <span class="n">child_key</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">get_child_key_fn</span><span class="p">(</span><span class="n">key</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">value</span><span class="p">,</span> <span class="n">node</span>
</code></pre></div></div>
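<p>The <code class="language-plaintext highlighter-rouge">key_match_fn</code> above is what makes the matching page-granular. A minimal sketch of such a helper (illustrative only; the real function lives in SGLang's source and may differ):</p>

```python
# Hypothetical sketch of a paged key-match helper: count the shared
# leading tokens of two keys, rounded down to a whole number of pages.
def paged_key_match(key0, key1, page_size=4):
    min_len = min(len(key0), len(key1))
    i = 0
    while i < min_len and key0[i] == key1[i]:
        i += 1
    # only fully matched pages count toward the prefix
    return i // page_size * page_size

# 6 shared tokens but page_size=4 -> only 4 tokens form a usable prefix
print(paged_key_match(list(range(10)), [0, 1, 2, 3, 4, 5, 99, 99]))  # -> 4
```

<p>This is why the NOTE above holds: once <code class="language-plaintext highlighter-rouge">child_key in node.children.keys()</code>, at least one full page matches, so <code class="language-plaintext highlighter-rouge">prefix_len</code> is at least one page size.</p>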

<h3 id="split-node">Split Node</h3>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>  <span class="c1">#############
</span>   <span class="c1"># Split Node
</span>   <span class="c1">#############
</span>  <span class="k">def</span> <span class="nf">_split_node</span><span class="p">(</span><span class="n">key</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="n">child</span><span class="p">,</span> <span class="n">split_len</span><span class="p">):</span>
    <span class="c1"># here, key is actually child's key
</span>    <span class="c1"># key and value will be split into two parts
</span>    <span class="c1"># key and value: [......................... | ..........................]
</span>    <span class="c1">#                                       prefix_len
</span>    <span class="c1">#                  left: a new node's kv        right: truncated child
</span>    <span class="c1"># after this split process, `child(node)` will be
</span>    <span class="c1"># `parent &lt;-&gt; child`    =&gt;
</span>    <span class="c1"># `parent &lt;-&gt; new_node &lt;-&gt; truncated child`
</span>
    <span class="c1"># create a new node
</span>    <span class="n">new_node</span> <span class="o">=</span> <span class="n">TreeNode</span><span class="p">()</span>

    <span class="c1"># make `new_node ---truncated child's 1-page-size key---&gt; child`
</span>    <span class="n">new_node</span><span class="p">.</span><span class="n">children</span> <span class="o">=</span> <span class="p">{</span><span class="bp">self</span><span class="p">.</span><span class="n">get_child_key_fn</span><span class="p">(</span><span class="n">key</span><span class="p">[</span><span class="n">split_len</span><span class="p">:]):</span> <span class="n">child</span><span class="p">}</span>

    <span class="c1"># make `parent -&gt; new_node`
</span>    <span class="n">new_node</span><span class="p">.</span><span class="n">parent</span> <span class="o">=</span> <span class="n">child</span><span class="p">.</span><span class="n">parent</span>

    <span class="c1"># new_node inherits child's ref count
</span>    <span class="n">new_node</span><span class="p">.</span><span class="n">lock_ref</span> <span class="o">=</span> <span class="n">child</span><span class="p">.</span><span class="n">lock_ref</span>

    <span class="c1"># give the left-side kv to new_node
</span>    <span class="n">new_node</span><span class="p">.</span><span class="n">key</span> <span class="o">=</span> <span class="n">child</span><span class="p">.</span><span class="n">key</span><span class="p">[:</span><span class="n">split_len</span><span class="p">]</span>
    <span class="n">new_node</span><span class="p">.</span><span class="n">value</span> <span class="o">=</span> <span class="n">child</span><span class="p">.</span><span class="n">value</span><span class="p">[:</span><span class="n">split_len</span><span class="p">]</span>

    <span class="c1"># make `new_node &lt;- child`
</span>    <span class="n">child</span><span class="p">.</span><span class="n">parent</span> <span class="o">=</span> <span class="n">new_node</span>

    <span class="c1"># make `child` become `truncated child`: drop the first split_len keys and values
</span>    <span class="n">child</span><span class="p">.</span><span class="n">key</span> <span class="o">=</span> <span class="n">child</span><span class="p">.</span><span class="n">key</span><span class="p">[</span><span class="n">split_len</span><span class="p">:]</span>
    <span class="n">child</span><span class="p">.</span><span class="n">value</span> <span class="o">=</span> <span class="n">child</span><span class="p">.</span><span class="n">value</span><span class="p">[</span><span class="n">split_len</span><span class="p">:]</span>

    <span class="c1"># make `parent ----new_node's 1-page-size key---&gt; new_node`
</span>    <span class="n">new_node</span><span class="p">.</span><span class="n">parent</span><span class="p">.</span><span class="n">children</span><span class="p">[</span><span class="bp">self</span><span class="p">.</span><span class="n">get_child_key_fn</span><span class="p">(</span><span class="n">key</span><span class="p">)]</span> <span class="o">=</span> <span class="n">new_node</span>

    <span class="k">return</span> <span class="n">new_node</span>
</code></pre></div></div>
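<p>To see the re-linking concretely, here is a toy split on a hand-rolled node class (not SGLang's real <code class="language-plaintext highlighter-rouge">TreeNode</code>; the names are illustrative):</p>

```python
# Toy node + split: `parent <-> child` becomes
# `parent <-> new_node <-> truncated child`.
class Node:
    def __init__(self, key=None, value=None):
        self.key = key or []
        self.value = value or []
        self.children = {}   # first token of a child's key -> child
        self.parent = None
        self.lock_ref = 0

def split_node(child, split_len):
    new_node = Node(child.key[:split_len], child.value[:split_len])
    new_node.parent = child.parent
    new_node.lock_ref = child.lock_ref              # inherit the ref count
    new_node.children = {child.key[split_len]: child}
    new_node.parent.children[new_node.key[0]] = new_node
    child.parent = new_node
    child.key = child.key[split_len:]               # keep only the right half
    child.value = child.value[split_len:]
    return new_node

root = Node()
child = Node([1, 2, 3, 4], [10, 11, 12, 13])
child.parent = root
root.children[1] = child

new = split_node(child, 2)
print(new.key, child.key)  # -> [1, 2] [3, 4]
```

<p>Note the ordering: the child links in <code class="language-plaintext highlighter-rouge">new_node</code> are built from the child's key <em>before</em> the child is truncated.</p>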

<h3 id="insert-node">Insert Node</h3>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code> <span class="c1">################
</span> <span class="c1"># Insert Node
</span> <span class="c1">################
</span> <span class="k">def</span> <span class="nf">insert</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">key</span><span class="p">:</span> <span class="n">List</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
     <span class="k">if</span> <span class="bp">self</span><span class="p">.</span><span class="n">disable</span><span class="p">:</span> <span class="k">return</span> <span class="mi">0</span>

     <span class="k">if</span> <span class="n">value</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">:</span> <span class="n">value</span> <span class="o">=</span> <span class="p">[</span><span class="n">x</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">key</span><span class="p">]</span>

     <span class="k">return</span> <span class="n">_insert_helper</span><span class="p">(</span><span class="n">root_node</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span>

  <span class="k">def</span> <span class="nf">_insert_helper</span><span class="p">(</span><span class="n">node</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">):</span>
    <span class="c1"># update node's time for LRU eviction
</span>    <span class="n">node</span><span class="p">.</span><span class="n">last_access_time</span> <span class="o">=</span> <span class="n">time</span><span class="p">.</span><span class="n">monotonic</span><span class="p">()</span>

    <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">key</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span> <span class="k">return</span> <span class="mi">0</span>

    <span class="c1"># get the 1-page-size key used for searching the prefix
</span>    <span class="n">child_key</span> <span class="o">=</span> <span class="n">get_child_key_fn</span><span class="p">(</span><span class="n">key</span><span class="p">)</span>

    <span class="n">total_prefix_length</span> <span class="o">=</span> <span class="mi">0</span>

    <span class="k">while</span> <span class="nb">len</span><span class="p">(</span><span class="n">key</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">child_key</span> <span class="ow">in</span> <span class="n">node</span><span class="p">.</span><span class="n">children</span><span class="p">.</span><span class="n">keys</span><span class="p">():</span>
      <span class="c1"># get the next node
</span>      <span class="n">node</span> <span class="o">=</span> <span class="n">node</span><span class="p">.</span><span class="n">children</span><span class="p">[</span><span class="n">child_key</span><span class="p">]</span>
      <span class="c1"># update the next node's time
</span>      <span class="n">node</span><span class="p">.</span><span class="n">last_access_time</span> <span class="o">=</span> <span class="n">time</span><span class="p">.</span><span class="n">monotonic</span><span class="p">()</span>

      <span class="c1"># get the prefix_len of the next node and the query key
</span>      <span class="n">prefix_len</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">key_match_fn</span><span class="p">(</span><span class="n">node</span><span class="p">.</span><span class="n">key</span><span class="p">,</span> <span class="n">key</span><span class="p">)</span>

      <span class="n">total_prefix_length</span> <span class="o">+=</span> <span class="n">prefix_len</span>

      <span class="c1"># update key and value
</span>      <span class="n">key</span> <span class="o">=</span> <span class="n">key</span><span class="p">[</span><span class="n">prefix_len</span><span class="p">:]</span>
      <span class="n">value</span> <span class="o">=</span> <span class="n">value</span><span class="p">[</span><span class="n">prefix_len</span><span class="p">:]</span>

      <span class="k">if</span> <span class="n">prefix_len</span> <span class="o">&lt;</span> <span class="nb">len</span><span class="p">(</span><span class="n">node</span><span class="p">.</span><span class="n">key</span><span class="p">):</span>
        <span class="c1"># not a full match, split the node
</span>        <span class="n">new_node</span> <span class="o">=</span> <span class="n">_split_node</span><span class="p">(</span><span class="n">node</span><span class="p">.</span><span class="n">key</span><span class="p">,</span> <span class="n">node</span><span class="p">,</span> <span class="n">prefix_len</span><span class="p">)</span>
        <span class="n">node</span> <span class="o">=</span> <span class="n">new_node</span>

      <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">key</span><span class="p">):</span>
        <span class="c1"># some keys still haven't been matched; try to continue to the next node
</span>        <span class="n">child_key</span> <span class="o">=</span> <span class="n">get_child_key_fn</span><span class="p">(</span><span class="n">key</span><span class="p">)</span>

        <span class="c1"># NOTE: if prefix_len &lt; len(node.key),
</span>        <span class="c1"># this while loop cannot continue,
</span>        <span class="c1"># because the split new node has only one child: the unmatched remainder,
</span>        <span class="c1"># so this new `child_key` doesn't exist in `node.children.keys()`.
</span>        <span class="c1"># The loop continues only on a full match whose query key still has a remaining part.
</span>
    <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">key</span><span class="p">):</span>
      <span class="c1"># there is still a remaining key that doesn't match anything in this radix tree,
</span>      <span class="c1"># so create a new node for it
</span>      <span class="c1"># NOTE: this new node's lock_ref is 0, so it is deemed evictable
</span>      <span class="n">new_node</span> <span class="o">=</span> <span class="n">TreeNode</span><span class="p">()</span>
      <span class="n">new_node</span><span class="p">.</span><span class="n">parent</span> <span class="o">=</span> <span class="n">node</span>
      <span class="n">new_node</span><span class="p">.</span><span class="n">key</span> <span class="o">=</span> <span class="n">key</span>
      <span class="n">new_node</span><span class="p">.</span><span class="n">value</span> <span class="o">=</span> <span class="n">value</span>

      <span class="c1"># make `node` point to this `new_node`
</span>      <span class="n">node</span><span class="p">.</span><span class="n">children</span><span class="p">[</span><span class="n">child_key</span><span class="p">]</span> <span class="o">=</span> <span class="n">new_node</span>

      <span class="c1"># this is evictable since it is a leaf node
</span>      <span class="n">evictable_size_</span> <span class="o">+=</span> <span class="nb">len</span><span class="p">(</span><span class="n">value</span><span class="p">)</span>

    <span class="k">return</span> <span class="n">total_prefix_length</span>
</code></pre></div></div>
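<p>Putting match and split together, a minimal self-contained radix insert (page size 1, values omitted; a sketch for intuition, not SGLang's real API) behaves like this:</p>

```python
# Minimal radix-tree insert: returns the length of the prefix that was
# already cached, splitting a node on a partial match, and appending a
# new leaf for the uncached remainder.
class Node:
    def __init__(self, key=None):
        self.key = key or []
        self.children = {}   # first token of a child's key -> child
        self.parent = None

def _match_len(a, b):
    n = 0
    while n < min(len(a), len(b)) and a[n] == b[n]:
        n += 1
    return n

def _split(child, split_len):
    new = Node(child.key[:split_len])
    new.parent = child.parent
    new.parent.children[new.key[0]] = new
    new.children = {child.key[split_len]: child}
    child.parent = new
    child.key = child.key[split_len:]
    return new

def insert(root, key):
    node, matched = root, 0
    while key and key[0] in node.children:
        node = node.children[key[0]]
        prefix_len = _match_len(node.key, key)
        matched += prefix_len
        key = key[prefix_len:]
        if prefix_len < len(node.key):
            node = _split(node, prefix_len)
    if key:   # uncached remainder becomes a new leaf
        new = Node(key)
        new.parent = node
        node.children[key[0]] = new
    return matched

root = Node()
print(insert(root, [1, 2, 3, 4]))  # -> 0 (nothing cached yet)
print(insert(root, [1, 2, 9]))     # -> 2 (shares [1, 2], splits the node)
```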

<h3 id="lock-ref">Lock Ref</h3>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
 <span class="c1">##################
</span> <span class="c1"># Handle Lock Ref
</span> <span class="c1">##################
</span>  <span class="k">def</span> <span class="nf">dec_lock_ref</span><span class="p">(</span><span class="n">node</span><span class="p">):</span>
   <span class="k">if</span> <span class="n">disable</span><span class="p">:</span> <span class="k">return</span> <span class="mi">0</span> <span class="c1"># if disable radix tree
</span>   <span class="n">delta</span> <span class="o">=</span> <span class="mi">0</span>

   <span class="c1"># bottom to up
</span>   <span class="k">while</span> <span class="n">node</span> <span class="o">!=</span> <span class="n">root_node</span><span class="p">:</span>
     <span class="k">if</span> <span class="n">node</span><span class="p">.</span><span class="n">lock_ref</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
       <span class="c1"># if there is only 1 ref to this node, this node deems evictable
</span>           <span class="n">evictable_size_</span> <span class="o">+=</span> <span class="nb">len</span><span class="p">(</span><span class="n">node</span><span class="p">.</span><span class="n">value</span><span class="p">)</span>
             <span class="n">protected_size_</span> <span class="o">-=</span> <span class="nb">len</span><span class="p">(</span><span class="n">node</span><span class="p">.</span><span class="n">value</span><span class="p">)</span>
             <span class="n">delta</span> <span class="o">+=</span> <span class="nb">len</span><span class="p">(</span><span class="n">node</span><span class="p">.</span><span class="n">value</span><span class="p">)</span>
         <span class="n">lock_ref</span> <span class="o">-=</span> <span class="mi">1</span>
         <span class="n">node</span> <span class="o">=</span> <span class="n">node</span><span class="p">.</span><span class="n">parent</span>
    <span class="k">return</span> <span class="n">delta</span>

 <span class="k">def</span> <span class="nf">inc_lock_ref</span><span class="p">(</span><span class="n">node</span><span class="p">):</span>
   <span class="k">if</span> <span class="n">disable</span><span class="p">:</span> <span class="k">return</span> <span class="mi">0</span>
   <span class="n">delta</span> <span class="o">=</span> <span class="mi">0</span>

   <span class="c1"># bottom to up
</span>   <span class="k">while</span> <span class="n">node</span> <span class="o">!=</span> <span class="n">root_node</span><span class="p">:</span>
     <span class="k">if</span> <span class="n">node</span><span class="p">.</span><span class="n">lock_ref</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
       <span class="c1"># if no other req ref this node, this node turns evictable to protectable
</span>       <span class="n">evictable_size_</span> <span class="o">-=</span> <span class="nb">len</span><span class="p">(</span><span class="n">node</span><span class="p">.</span><span class="n">value</span><span class="p">)</span>
             <span class="bp">self</span><span class="p">.</span><span class="n">protected_size_</span> <span class="o">+=</span> <span class="nb">len</span><span class="p">(</span><span class="n">node</span><span class="p">.</span><span class="n">value</span><span class="p">)</span>
             <span class="n">delta</span> <span class="o">-=</span> <span class="nb">len</span><span class="p">(</span><span class="n">node</span><span class="p">.</span><span class="n">value</span><span class="p">)</span>
     <span class="n">node</span><span class="p">.</span><span class="n">lock_ref</span> <span class="o">+=</span> <span class="mi">1</span>
   <span class="k">return</span> <span class="n">delta</span>
</code></pre></div></div>
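<p>The key invariant: walking bottom-up, only the 0 ↔ 1 transitions of <code class="language-plaintext highlighter-rouge">lock_ref</code> move a node's tokens between the evictable and protected pools. A toy sketch (illustrative fields, not SGLang's real ones):</p>

```python
# Lock-ref bookkeeping on a cached path root -> a -> b.
class Node:
    def __init__(self, value, parent=None):
        self.value = value
        self.parent = parent
        self.lock_ref = 0

class Pools:
    evictable = 0
    protected = 0

def inc_lock_ref(node, root, pools):
    while node is not root:
        if node.lock_ref == 0:            # evictable -> protected
            pools.evictable -= len(node.value)
            pools.protected += len(node.value)
        node.lock_ref += 1
        node = node.parent

def dec_lock_ref(node, root, pools):
    while node is not root:
        if node.lock_ref == 1:            # last ref released -> evictable
            pools.evictable += len(node.value)
            pools.protected -= len(node.value)
        node.lock_ref -= 1
        node = node.parent

pools = Pools()
root = Node([])
a = Node([1, 2], root)
b = Node([3], a)
pools.evictable = 3                       # all 3 cached tokens start evictable

inc_lock_ref(b, root, pools)              # first request locks the path
print(pools.evictable, pools.protected)   # -> 0 3
inc_lock_ref(b, root, pools)              # second request: sizes unchanged
dec_lock_ref(b, root, pools)
dec_lock_ref(b, root, pools)              # last release: evictable again
print(pools.evictable, pools.protected)   # -> 3 0
```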

<h3 id="api">API</h3>

<ul>
  <li>Cache when a request is finished or unfinished</li>
  <li>Evict</li>
</ul>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code> <span class="c1">#######################
</span> <span class="c1"># Cache Unfinished Req
</span>  <span class="c1">#######################
</span>  <span class="k">def</span> <span class="nf">cache_unfinished_req</span><span class="p">(</span><span class="n">req</span><span class="p">):</span>
    <span class="n">token_ids</span> <span class="o">=</span> <span class="n">req</span><span class="p">.</span><span class="n">fill_ids</span>

    <span class="c1"># get `out_cache_loc`, which is actually Value
</span>    <span class="n">kv_indices</span> <span class="o">=</span> <span class="n">req_to_token_pool</span><span class="p">.</span><span class="n">req_to_token</span><span class="p">[</span>
            <span class="n">req</span><span class="p">.</span><span class="n">req_pool_idx</span><span class="p">,</span> <span class="p">:</span> <span class="nb">len</span><span class="p">(</span><span class="n">token_ids</span><span class="p">)</span>
      <span class="p">]</span>

    <span class="k">if</span> <span class="n">page_size</span> <span class="o">!=</span> <span class="mi">1</span><span class="p">:</span>
      <span class="n">page_aligned_len</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">kv_indices</span><span class="p">)</span> <span class="o">//</span> <span class="n">page_size</span> <span class="o">*</span> <span class="n">page_size</span>
      <span class="c1"># V align
</span>      <span class="n">page_aligned_kv_indices</span> <span class="o">=</span> <span class="n">kv_indices</span><span class="p">[:</span><span class="n">page_aligned_len</span><span class="p">].</span><span class="n">clone</span><span class="p">()</span>
    <span class="k">else</span><span class="p">:</span>
      <span class="n">page_aligned_len</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">kv_indices</span><span class="p">)</span>
      <span class="n">page_aligned_kv_indices</span> <span class="o">=</span> <span class="n">kv_indices</span><span class="p">.</span><span class="n">clone</span><span class="p">()</span>

    <span class="c1"># K align
</span>    <span class="n">page_aligned_token_ids</span> <span class="o">=</span> <span class="n">token_ids</span><span class="p">[:</span><span class="n">page_aligned_len</span><span class="p">]</span>

    <span class="c1"># insert K,V
</span>    <span class="n">new_prefix_len</span> <span class="o">=</span> <span class="n">insert</span><span class="p">(</span><span class="n">page_aligned_token_ids</span><span class="p">,</span> <span class="n">page_aligned_kv_indices</span><span class="p">)</span>

    <span class="c1"># free the repetitive part
</span>    <span class="n">token_to_kv_pool_allocator</span><span class="p">.</span><span class="n">free</span><span class="p">(</span>
        <span class="n">kv_indices</span><span class="p">[</span><span class="nb">len</span><span class="p">(</span><span class="n">req</span><span class="p">.</span><span class="n">prefix_indices</span><span class="p">)</span> <span class="p">:</span> <span class="n">new_prefix_len</span><span class="p">]</span>
    <span class="p">)</span>

    <span class="c1"># get the prefix `out_cache_loc` and `new_last_node`
</span>    <span class="n">new_indices</span><span class="p">,</span> <span class="n">new_last_node</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">match_prefix</span><span class="p">(</span><span class="n">page_aligned_token_ids</span><span class="p">)</span>

    <span class="c1"># only write the new `out_cache_loc`
</span>    <span class="n">req_to_token_pool</span><span class="p">.</span><span class="n">write</span><span class="p">(</span>
        <span class="p">(</span><span class="n">req</span><span class="p">.</span><span class="n">req_pool_idx</span><span class="p">,</span> <span class="nb">slice</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">req</span><span class="p">.</span><span class="n">prefix_indices</span><span class="p">),</span> <span class="nb">len</span><span class="p">(</span><span class="n">new_indices</span><span class="p">))),</span>
        <span class="n">new_indices</span><span class="p">[</span><span class="nb">len</span><span class="p">(</span><span class="n">req</span><span class="p">.</span><span class="n">prefix_indices</span><span class="p">)</span> <span class="p">:],</span>
    <span class="p">)</span>

    <span class="c1"># root -&gt; ... -&gt; last_node -&gt; ... -&gt; new_last_node
</span>    <span class="c1"># |-- lock_ref - 1 --|
</span>    <span class="n">dec_lock_ref</span><span class="p">(</span><span class="n">req</span><span class="p">.</span><span class="n">last_node</span><span class="p">)</span>

    <span class="c1"># root -&gt; ... -&gt; last_node -&gt; ... -&gt; new_last_node
</span>    <span class="c1"># |------------- lock_ref + 1 -----------------|
</span>    <span class="n">inc_lock_ref</span><span class="p">(</span><span class="n">new_last_node</span><span class="p">)</span>


 <span class="c1">#####################
</span> <span class="c1"># Cache Finished Req
</span> <span class="c1">#####################
</span>  <span class="k">def</span> <span class="nf">cache_finished_req</span><span class="p">(</span><span class="n">req</span><span class="p">):</span>
   <span class="k">if</span> <span class="bp">self</span><span class="p">.</span><span class="n">disable</span><span class="p">:</span>
     <span class="c1"># if disable radix tree, free the KV Cache of this finished req directly
</span>
     <span class="c1"># get `out_cache_loc`
</span>     <span class="n">kv_indices</span> <span class="o">=</span> <span class="n">req_to_token_pool</span><span class="p">.</span><span class="n">req_to_token</span><span class="p">[</span>
              <span class="n">req</span><span class="p">.</span><span class="n">req_pool_idx</span><span class="p">,</span> <span class="p">:</span> <span class="nb">len</span><span class="p">(</span><span class="n">req</span><span class="p">.</span><span class="n">origin_input_ids</span><span class="p">)</span> <span class="o">+</span> <span class="nb">len</span><span class="p">(</span><span class="n">req</span><span class="p">.</span><span class="n">output_ids</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span>
          <span class="p">]</span>

          <span class="c1"># free `req slots` and `token_to_kv_pool slots`
</span>          <span class="n">token_to_kv_pool_allocator</span><span class="p">.</span><span class="n">free</span><span class="p">(</span><span class="n">kv_indices</span><span class="p">)</span>
          <span class="n">req_to_token_pool</span><span class="p">.</span><span class="n">free</span><span class="p">(</span><span class="n">req</span><span class="p">.</span><span class="n">req_pool_idx</span><span class="p">)</span>
          <span class="k">return</span>

    <span class="c1"># when using the radix tree, don't free the KV cache immediately, so it can be reused
</span>
    <span class="c1"># get token_ids, which is actually the key
</span>    <span class="n">token_ids</span> <span class="o">=</span> <span class="p">(</span><span class="n">req</span><span class="p">.</span><span class="n">origin_input_ids</span> <span class="o">+</span> <span class="n">req</span><span class="p">.</span><span class="n">output_ids</span><span class="p">)[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>

    <span class="c1"># get `out_cache_loc`, which is actually the value
</span>    <span class="n">kv_indices</span> <span class="o">=</span> <span class="n">req_to_token_pool</span><span class="p">.</span><span class="n">req_to_token</span><span class="p">[</span>
        <span class="n">req</span><span class="p">.</span><span class="n">req_pool_idx</span><span class="p">,</span> <span class="p">:</span> <span class="nb">len</span><span class="p">(</span><span class="n">token_ids</span><span class="p">)</span>
    <span class="p">]</span>

    <span class="c1"># assuming the page size is 1, so it is automatically aligned
</span>    <span class="n">page_aligned_len</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">kv_indices</span><span class="p">)</span>
    <span class="n">page_aligned_kv_indices</span> <span class="o">=</span> <span class="n">kv_indices</span><span class="p">.</span><span class="n">clone</span><span class="p">()</span>

    <span class="c1"># insert [token_ids, out_cache_loc] into the radix tree for reuse
</span>    <span class="n">new_prefix_len</span> <span class="o">=</span> <span class="n">insert</span><span class="p">(</span>
        <span class="n">token_ids</span><span class="p">[:</span><span class="n">page_aligned_len</span><span class="p">],</span> <span class="n">page_aligned_kv_indices</span>
    <span class="p">)</span>

    <span class="c1"># only free the [len(prefix_indices): new_prefix_len] part of the kv pool. Why?
</span>    <span class="c1"># because this part of `out_cache_loc` is REPETITIVE (REDUNDANT)!
</span>
    <span class="c1"># The whole process is as follows:
</span>    <span class="c1"># `req.prefix_indices` is computed when the req is first scheduled
</span>    <span class="c1"># `new_prefix_len` is the prefix length when the req is finished
</span>    <span class="c1"># [len(req.prefix_indices): new_prefix_len] is the part that was computed twice in between
</span>    <span class="n">token_to_kv_pool_allocator</span><span class="p">.</span><span class="n">free</span><span class="p">(</span>
        <span class="n">kv_indices</span><span class="p">[</span><span class="nb">len</span><span class="p">(</span><span class="n">req</span><span class="p">.</span><span class="n">prefix_indices</span><span class="p">)</span> <span class="p">:</span> <span class="n">new_prefix_len</span><span class="p">]</span>
    <span class="p">)</span>

    <span class="c1"># free the `req slot` for sure:
</span>    <span class="c1"># since the req has finished, its req_pool_idx can be used by other reqs
</span>    <span class="n">req_to_token_pool</span><span class="p">.</span><span class="n">free</span><span class="p">(</span><span class="n">req</span><span class="p">.</span><span class="n">req_pool_idx</span><span class="p">)</span>

    <span class="c1"># dec the lock_ref of the nodes that own out_cache_loc[:len(prefix_indices)];
</span>    <span class="c1"># this part possibly becomes evictable,
</span>    <span class="c1"># but NOTE: these `out_cache_loc` have not been evicted yet
</span>    <span class="n">dec_lock_ref</span><span class="p">(</span><span class="n">req</span><span class="p">.</span><span class="n">last_node</span><span class="p">)</span>
</code></pre></div></div>
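<p>A tiny numeric example (hypothetical numbers) of why that middle slice is freed: the prefix cached at schedule time and the prefix present in the tree at finish time can differ, and the gap between them was computed twice.</p>

```python
# Hypothetical numbers: 4 tokens were cached when the req was scheduled,
# but 7 tokens of its prefix are in the tree by the time it finishes
# (e.g. inserted meanwhile by another request), so slots 4..6 are duplicates.
prefix_indices_len = 4               # stands in for len(req.prefix_indices)
new_prefix_len = 7                   # prefix length returned by insert()
kv_indices = list(range(100, 110))   # this req's 10 allocated KV slots

redundant = kv_indices[prefix_indices_len:new_prefix_len]
print(redundant)  # -> [104, 105, 106]
```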

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>  <span class="k">def</span> <span class="nf">evict</span><span class="p">(</span><span class="n">num_tokens</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
    <span class="k">if</span> <span class="n">disable</span><span class="p">:</span> <span class="k">return</span>

    <span class="n">leaves</span> <span class="o">=</span> <span class="n">_collect_leaves</span><span class="p">()</span>

    <span class="c1"># build a min-heap ordered by `last_access_time` (LRU)
</span>    <span class="n">heapq</span><span class="p">.</span><span class="n">heapify</span><span class="p">(</span><span class="n">leaves</span><span class="p">)</span>

    <span class="n">num_evicted</span> <span class="o">=</span> <span class="mi">0</span>
    <span class="k">while</span> <span class="n">num_evicted</span> <span class="o">&lt;</span> <span class="n">num_tokens</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">leaves</span><span class="p">):</span>
      <span class="n">x</span> <span class="o">=</span> <span class="n">heapq</span><span class="p">.</span><span class="n">heappop</span><span class="p">(</span><span class="n">leaves</span><span class="p">)</span>
      <span class="k">if</span> <span class="n">x</span> <span class="o">==</span> <span class="n">root_node</span><span class="p">:</span> <span class="k">break</span>

      <span class="c1"># if some reqs are still pointing to this node, skip it
</span>      <span class="k">if</span> <span class="n">x</span><span class="p">.</span><span class="n">lock_ref</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span> <span class="k">continue</span>

      <span class="c1"># free this node's `out_cache_loc`
</span>      <span class="n">token_to_kv_pool_allocator</span><span class="p">.</span><span class="n">free</span><span class="p">(</span><span class="n">x</span><span class="p">.</span><span class="n">value</span><span class="p">)</span>

      <span class="n">num_evicted</span> <span class="o">+=</span> <span class="nb">len</span><span class="p">(</span><span class="n">x</span><span class="p">.</span><span class="n">value</span><span class="p">)</span>
      <span class="n">_delete_leaf</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>

      <span class="c1"># if the parent just became a leaf, it is evictable in a later round
</span>      <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">x</span><span class="p">.</span><span class="n">parent</span><span class="p">.</span><span class="n">children</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
        <span class="n">heapq</span><span class="p">.</span><span class="n">heappush</span><span class="p">(</span><span class="n">leaves</span><span class="p">,</span> <span class="n">x</span><span class="p">.</span><span class="n">parent</span><span class="p">)</span>

  <span class="k">def</span> <span class="nf">_delete_leaf</span><span class="p">(</span><span class="n">node</span><span class="p">):</span>

    <span class="c1"># remove this node from its parent's children
</span>    <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">node</span><span class="p">.</span><span class="n">parent</span><span class="p">.</span><span class="n">children</span><span class="p">.</span><span class="n">items</span><span class="p">():</span>
      <span class="k">if</span> <span class="n">v</span> <span class="o">==</span> <span class="n">node</span><span class="p">:</span>
        <span class="k">break</span>
    <span class="k">del</span> <span class="n">node</span><span class="p">.</span><span class="n">parent</span><span class="p">.</span><span class="n">children</span><span class="p">[</span><span class="n">k</span><span class="p">]</span>

    <span class="c1"># update evictable_size
</span>    <span class="n">evictable_size_</span> <span class="o">-=</span> <span class="nb">len</span><span class="p">(</span><span class="n">node</span><span class="p">.</span><span class="n">key</span><span class="p">)</span>

</code></pre></div></div>
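<p>To make the eviction loop concrete, here is a self-contained toy version you can run. The <code class="language-plaintext highlighter-rouge">Node</code> class and its string keys are simplified stand-ins (the real radix tree keys are token ids and the node values are KV-cache indices), and collecting the freed keys stands in for calling <code class="language-plaintext highlighter-rouge">token_to_kv_pool_allocator.free(x.value)</code>:</p>

```python
import heapq
import time
from dataclasses import dataclass, field

@dataclass
class Node:
    key: str                      # stand-in for a sequence of token ids
    parent: "Node | None" = None
    children: dict = field(default_factory=dict)
    lock_ref: int = 0             # > 0 means a running request still uses it
    last_access_time: float = field(default_factory=time.monotonic)

    def __lt__(self, other):      # heapq pops the least-recently-used leaf first
        return self.last_access_time < other.last_access_time

def collect_leaves(root):
    leaves, stack = [], [root]
    while stack:
        node = stack.pop()
        if node.children:
            stack.extend(node.children.values())
        elif node is not root:
            leaves.append(node)
    return leaves

def evict(root, num_tokens):
    freed = []
    leaves = collect_leaves(root)
    heapq.heapify(leaves)         # LRU order via __lt__
    num_evicted = 0
    while num_evicted < num_tokens and leaves:
        x = heapq.heappop(leaves)
        if x.lock_ref > 0:        # still referenced by a running request: skip
            continue
        freed.append(x.key)       # stands in for freeing x.value in the KV pool
        num_evicted += len(x.key)
        del x.parent.children[x.key]
        # if the parent just became a leaf, it is evictable in a later round
        if not x.parent.children and x.parent is not root:
            heapq.heappush(leaves, x.parent)
    return freed
```

<p>Building a small tree and locking one leaf shows the skip behavior: the locked leaf survives eviction, and because it survives, its ancestors are never reclaimed either.</p>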

<h2 id="usage">Usage</h2>

<p>How do we use the APIs provided by the radix tree cache described above?</p>

<h3 id="cache">Cache</h3>

<p>When <code class="language-plaintext highlighter-rouge">prefill</code> is over,</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">process_batch_result_prefill</span><span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">result</span><span class="p">):</span>
    <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="p">(</span><span class="n">req</span><span class="p">,</span> <span class="n">next_token_id</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">batch</span><span class="p">.</span><span class="n">reqs</span><span class="p">,</span> <span class="n">result</span><span class="p">.</span><span class="n">next_token_ids</span><span class="p">)):</span>
        <span class="n">req</span><span class="p">.</span><span class="n">output_ids</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">next_token_id</span><span class="p">)</span>
        <span class="n">req</span><span class="p">.</span><span class="n">check_finished</span><span class="p">()</span>

        <span class="k">if</span> <span class="n">req</span><span class="p">.</span><span class="n">finished</span><span class="p">():</span>
            <span class="n">tree_cache</span><span class="p">.</span><span class="n">cache_finished_req</span><span class="p">(</span><span class="n">req</span><span class="p">)</span>

        <span class="k">elif</span> <span class="ow">not</span> <span class="n">batch</span><span class="p">.</span><span class="n">decoding_reqs</span> <span class="ow">or</span> <span class="n">req</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">batch</span><span class="p">.</span><span class="n">decoding_reqs</span><span class="p">:</span>
            <span class="c1"># This updates radix so others can match
</span>            <span class="n">tree_cache</span><span class="p">.</span><span class="n">cache_unfinished_req</span><span class="p">(</span><span class="n">req</span><span class="p">)</span>
</code></pre></div></div>

<p>When <code class="language-plaintext highlighter-rouge">decode</code>  is over,</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">process_batch_result_decode</span><span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">result</span><span class="p">):</span>
    <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="p">(</span><span class="n">req</span><span class="p">,</span> <span class="n">next_token_id</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">batch</span><span class="p">.</span><span class="n">reqs</span><span class="p">,</span> <span class="n">result</span><span class="p">.</span><span class="n">next_token_ids</span><span class="p">)):</span>
        <span class="n">req</span><span class="p">.</span><span class="n">check_finished</span><span class="p">()</span>

        <span class="k">if</span> <span class="n">req</span><span class="p">.</span><span class="n">finished</span><span class="p">():</span>
            <span class="n">tree_cache</span><span class="p">.</span><span class="n">cache_finished_req</span><span class="p">(</span><span class="n">req</span><span class="p">)</span>
</code></pre></div></div>

<p>💡 Only when a request finishes during <code class="language-plaintext highlighter-rouge">decode</code> does the tree cache store its (<code class="language-plaintext highlighter-rouge">token_ids</code>, <code class="language-plaintext highlighter-rouge">out_cache_loc</code>) pair.</p>
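<p>A minimal sketch of this dispatch, where <code class="language-plaintext highlighter-rouge">TreeCacheStub</code>, <code class="language-plaintext highlighter-rouge">FakeReq</code>, and <code class="language-plaintext highlighter-rouge">process_prefill_batch</code> are hypothetical stand-ins for the real scheduler objects: a finished request goes through <code class="language-plaintext highlighter-rouge">cache_finished_req</code>, while an unfinished prefill request is published via <code class="language-plaintext highlighter-rouge">cache_unfinished_req</code> so later requests can match its prefix:</p>

```python
class TreeCacheStub:
    """Records which cache API each request ends up calling."""
    def __init__(self):
        self.finished, self.unfinished = [], []

    def cache_finished_req(self, req):
        self.finished.append(req.rid)

    def cache_unfinished_req(self, req):
        self.unfinished.append(req.rid)

class FakeReq:
    def __init__(self, rid, done):
        self.rid, self.done = rid, done
        self.output_ids = []

    def check_finished(self):
        pass  # the real Req checks EOS / length limits here

    def finished(self):
        return self.done

def process_prefill_batch(tree_cache, reqs, next_token_ids):
    for req, next_token_id in zip(reqs, next_token_ids):
        req.output_ids.append(next_token_id)
        req.check_finished()
        if req.finished():
            tree_cache.cache_finished_req(req)    # release lock, keep prefix cached
        else:
            tree_cache.cache_unfinished_req(req)  # publish prefix so others can match
```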

<h3 id="evict">Evict</h3>

<p>Eviction, which frees <code class="language-plaintext highlighter-rouge">out_cache_loc</code> slots back to the pool, happens when the available size of the <code class="language-plaintext highlighter-rouge">token_to_kv_pool</code> cannot accommodate the incoming request:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">alloc_token_slots</span><span class="p">(</span><span class="n">num_tokens</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">backup_state</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="bp">False</span><span class="p">):</span>
    <span class="k">if</span> <span class="n">token_to_kv_pool_allocator</span><span class="p">.</span><span class="n">available_size</span><span class="p">()</span> <span class="o">&lt;</span> <span class="n">num_tokens</span><span class="p">:</span>
        <span class="k">if</span> <span class="n">tree_cache</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">:</span>
            <span class="n">tree_cache</span><span class="p">.</span><span class="n">evict</span><span class="p">(</span><span class="n">num_tokens</span><span class="p">)</span>

    <span class="n">out_cache_loc</span> <span class="o">=</span> <span class="n">token_to_kv_pool_allocator</span><span class="p">.</span><span class="n">alloc</span><span class="p">(</span><span class="n">num_tokens</span><span class="p">)</span>
</code></pre></div></div>
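<p>The trigger above can be sketched as a runnable toy. <code class="language-plaintext highlighter-rouge">ToyAllocator</code> and the eviction callback are illustrative stand-ins (not SGLang's actual classes); the point is simply that eviction is attempted only when the free pool is too small for the request:</p>

```python
class ToyAllocator:
    """A fixed pool of KV slots; free slot indices are kept in a list.
    Illustrative stand-in, not SGLang's token_to_kv_pool_allocator."""

    def __init__(self, size):
        self.free_slots = list(range(size))

    def available_size(self):
        return len(self.free_slots)

    def alloc(self, n):
        if len(self.free_slots) < n:
            return None           # allocation failed even after eviction
        out, self.free_slots = self.free_slots[:n], self.free_slots[n:]
        return out

    def free(self, slots):
        self.free_slots.extend(slots)

def alloc_token_slots(allocator, evict_fn, num_tokens):
    # ask the cache to evict only when the pool cannot satisfy the request
    if allocator.available_size() < num_tokens:
        evict_fn(num_tokens)
    return allocator.alloc(num_tokens)
```

<p>For example, with a 4-slot pool, a second 3-slot allocation succeeds only because the eviction callback first returns the slots of the previous allocation to the pool.</p>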

<h1 id="reference">Reference</h1>

<ul>
  <li><a href="https://hebiao064.github.io/fa3-attn-backend-basic">https://hebiao064.github.io/fa3-attn-backend-basic</a></li>
</ul>]]></content><author><name>Muqi Li</name></author><category term="SGLang-Mem-Cache" /><summary type="html"><![CDATA[Note: Complex systems often include numerous corner cases and technical implementations that can make the source code challenging to understand for newcomers. To make the core concepts more accessible, this blog post uses pseudocode that focuses on the main ideas while omitting implementation details (such as self references and other technical specifics). While simplified, the pseudocode maintains the essential logic and workflow of the system. Of source, if you want to know all details, the best way is to look directly at the source code, which is available here]]></summary></entry></feed>