<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" ><generator uri="https://jekyllrb.com/" version="3.10.0">Jekyll</generator><link href="https://salykova.github.io/feed.xml" rel="self" type="application/atom+xml" /><link href="https://salykova.github.io/" rel="alternate" type="text/html" /><updated>2025-10-04T22:11:58+00:00</updated><id>https://salykova.github.io/feed.xml</id><title type="html">salykova</title><subtitle>making AI inference run really fast.</subtitle><author><name>Aman Salykov</name></author><entry><title type="html">Matrix Core Programming on AMD CDNA3 and CDNA4 architecture</title><link href="https://salykova.github.io/matrix-cores-cdna" rel="alternate" type="text/html" title="Matrix Core Programming on AMD CDNA3 and CDNA4 architecture" /><published>2025-09-30T23:35:01+00:00</published><updated>2025-09-30T23:35:01+00:00</updated><id>https://salykova.github.io/matrix-cores-cdna</id><content type="html" xml:base="https://salykova.github.io/matrix-cores-cdna"><![CDATA[<!---
Copyright (c) 2025 Advanced Micro Devices, Inc. (AMD)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--->

<p><strong>TL;DR</strong> In this blog post, we walk through how to use Matrix Cores in HIP kernels, with a focus on low-precision data types such as FP16, FP8, and FP4, as well as the new family of Matrix Core instructions with exponent block scaling introduced in the AMD CDNA™4 architecture. Through code examples and illustrations, we provide the necessary knowledge to start programming Matrix Cores, covering modern low-precision floating-point types, the Matrix Core compiler intrinsics, and the data layouts required by the Matrix Core instructions. The blog post is also available on <a href="https://rocm.blogs.amd.com/software-tools-optimization/matrix-cores-cdna/README.html">ROCm Blogs</a>.</p>

<h2 id="1-matrix-cores">1. Matrix Cores</h2>

<p>Matrix multiplication is an essential part of AI and HPC workloads. The AMD CDNA™ architecture features special-purpose hardware, the Matrix Cores, to accelerate matrix fused-multiply-add (MFMA) operations defined as <code class="language-plaintext highlighter-rouge">D:=A*B+C</code>. Note that MFMA instructions are often used to update a matrix in place (accumulation), in which case <code class="language-plaintext highlighter-rouge">D=C</code> and the operation becomes <code class="language-plaintext highlighter-rouge">C:=A*B+C</code>. The matrices <code class="language-plaintext highlighter-rouge">A</code> and <code class="language-plaintext highlighter-rouge">B</code> are called input matrices, while the matrix <code class="language-plaintext highlighter-rouge">D</code> is referred to as the output matrix or accumulator.</p>

<p>The performance gains from using Matrix Cores are especially significant in mixed-precision mode, where the input matrices use lower-precision data types instead of FP32. The output matrix, however, is stored in FP32 to minimize accuracy loss during accumulation. The tables below show the theoretical peak performance of Matrix Cores with different input data types on both AMD CDNA™3 and AMD CDNA™4 architectures. On the AMD Instinct™ MI325X, using FP16 input matrices delivers roughly an 8x performance increase compared to single-precision, with only minimal accuracy loss. Switching to FP8 doubles the performance again, providing a 16x increase compared to FP32. The AMD CDNA™4 architecture further improves Matrix Core performance, delivering up to 2x higher throughput for FP16 and FP8 compared to the AMD CDNA™3 architecture. In addition, AMD CDNA™4 introduces new low-precision data types such as FP6 and FP4, enabling up to a 64x performance gain relative to FP32. Please refer to the <a href="https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/white-papers/amd-cdna-3-white-paper.pdf">AMD CDNA™3</a> and <a href="https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/white-papers/amd-cdna-4-architecture-whitepaper.pdf">AMD CDNA™4</a> white papers for detailed architecture specifications.</p>


<table>
  <thead>
    <tr>
      <th>Type</th>
      <th>AMD Instinct™ MI325X (CDNA™3)</th>
      <th>Speedup vs. FP32</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>Matrix FP64</td>
      <td>163.4 TF</td>
      <td>1x</td>
    </tr>
    <tr>
      <td>Matrix FP32</td>
      <td>163.4 TF</td>
      <td>1x</td>
    </tr>
    <tr>
      <td>Matrix FP16</td>
      <td>1307.4 TF</td>
      <td>~8x</td>
    </tr>
    <tr>
      <td>Matrix FP8</td>
      <td>2614.9 TF</td>
      <td>~16x</td>
    </tr>
  </tbody>
</table>

<table>
  <thead>
    <tr>
      <th>Type</th>
      <th>AMD Instinct™ MI355X (CDNA™4)</th>
      <th>Speedup vs. FP32</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>Matrix FP64</td>
      <td>78.6 TF</td>
      <td>~0.5x</td>
    </tr>
    <tr>
      <td>Matrix FP32</td>
      <td>157.3 TF</td>
      <td>1x</td>
    </tr>
    <tr>
      <td>Matrix FP16</td>
      <td>2.5 PF</td>
      <td>~16x</td>
    </tr>
    <tr>
      <td>Matrix FP8</td>
      <td>5 PF</td>
      <td>~32x</td>
    </tr>
    <tr>
      <td>Matrix FP6</td>
      <td>10 PF</td>
      <td>~64x</td>
    </tr>
    <tr>
      <td>Matrix FP4</td>
      <td>10 PF</td>
      <td>~64x</td>
    </tr>
  </tbody>
</table>


<h2 id="2-low-precision-floating-point-types">2. Low-Precision Floating-Point Types</h2>

<p>A binary representation of a floating-point number consists of <code class="language-plaintext highlighter-rouge">n</code> bits, where <code class="language-plaintext highlighter-rouge">m</code> of <code class="language-plaintext highlighter-rouge">n</code> bits represent the mantissa, 1 bit determines the sign and <code class="language-plaintext highlighter-rouge">n-m-1</code> bits represent the exponent. The following image illustrates the binary format of a floating-point number and how the exponent and mantissa are calculated based on its binary representation.</p>

<p><img src="/assets/matrix_cores/binary_format_2.png" alt="binary_repr" width="80%" style="display:block; margin-left:auto; margin-right:auto" /></p>
<p style="text-align:center">
<em>Figure 1: Binary representation of a floating-point number.</em>
</p>

<p>Floating-point types are characterized by the number of bits used for the exponent and for the mantissa. Increasing the exponent width extends the range of representable values, while increasing the mantissa width improves precision. Since all floating-point types include the sign bit, a shorthand notation typically specifies only the exponent and mantissa widths. For example, the E4M3 type is an 8-bit floating-point type with a 4-bit exponent and a 3-bit mantissa. Additionally, a floating-point type is specified by its exponent bias, a number that is subtracted from the exponent during conversion from binary format to real value. Given the exponent width, mantissa width, and exponent bias, one can convert the binary representation of a floating-point type (except E8M0) into its real value using the following equation:</p>

<p><img src="/assets/matrix_cores/g81.png" alt="fp_cvt" width="70%" style="display:block; margin-left:auto; margin-right:auto" /></p>
<p style="text-align:center">
<em>Figure 2: Conversion to real value from binary representation for floating-point numbers.</em>
</p>

<p>Please note that the equation takes different forms depending on whether the exponent is zero or not. Often, certain exponent and mantissa values are reserved for special values (e.g. <code class="language-plaintext highlighter-rouge">NaN</code>, <code class="language-plaintext highlighter-rouge">Infinity</code>), which limits the range of representable real numbers. For example, the FP16 type has a 5-bit exponent with a nominal range of <code class="language-plaintext highlighter-rouge">[0, 1, ..., 2^5-1] = [0, 1, ..., 31]</code>. However, the exponent value <code class="language-plaintext highlighter-rouge">E = 31</code> is reserved for <code class="language-plaintext highlighter-rouge">NaN</code> (if the mantissa <code class="language-plaintext highlighter-rouge">M != 0</code>) and <code class="language-plaintext highlighter-rouge">Infinity</code> (if the mantissa <code class="language-plaintext highlighter-rouge">M = 0</code>). Therefore, the largest exponent value that can represent a real number is <code class="language-plaintext highlighter-rouge">E = 30</code>.</p>
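<p>As a rough illustration of the conversion described above, the following host-side C++ sketch decodes a raw bit pattern given the exponent width, mantissa width, and exponent bias. The function name <code class="language-plaintext highlighter-rouge">decode_float_bits</code> is a hypothetical helper introduced for this example and is not part of HIP or ROCm. The sketch assumes the usual normal/subnormal convention (implicit leading 1 for a non-zero exponent, none for a zero exponent) and deliberately ignores special encodings such as <code class="language-plaintext highlighter-rouge">NaN</code> and <code class="language-plaintext highlighter-rouge">Infinity</code>.</p>

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Hypothetical helper (not a HIP/ROCm API): converts the raw bits of a small
// floating-point type into its real value. Special encodings (NaN, Infinity)
// are NOT handled; every bit pattern is treated as a finite number.
double decode_float_bits(std::uint32_t bits, int exp_bits, int man_bits, int bias) {
    // The layout is assumed to be [sign | exponent | mantissa], sign in the top bit.
    assert(exp_bits >= 1 && man_bits >= 1 && exp_bits + man_bits < 32);

    const std::uint32_t man_mask = (1u << man_bits) - 1u;
    const std::uint32_t exp_mask = (1u << exp_bits) - 1u;

    const std::uint32_t m = bits & man_mask;                       // mantissa field
    const std::uint32_t e = (bits >> man_bits) & exp_mask;         // exponent field
    const std::uint32_t s = (bits >> (man_bits + exp_bits)) & 1u;  // sign bit

    const double sign = s ? -1.0 : 1.0;
    const double frac = static_cast<double>(m) / static_cast<double>(1u << man_bits);

    if (e == 0) {
        // Zero exponent (subnormal): no implicit leading 1, exponent fixed at 1 - bias.
        return sign * std::ldexp(frac, 1 - bias);
    }
    // Non-zero exponent (normal): implicit leading 1, exponent e - bias.
    return sign * std::ldexp(1.0 + frac, static_cast<int>(e) - bias);
}
```

<p>For instance, <code class="language-plaintext highlighter-rouge">decode_float_bits(0x7BFF, 5, 10, 15)</code> evaluates the FP16 pattern with <code class="language-plaintext highlighter-rouge">E = 30</code> and an all-ones mantissa, which gives 65504.</p>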

<p>The following table summarizes low-precision types commonly used in modern AI/ML workloads:</p>

<table>
  <thead>
    <tr>
      <th>Width</th>
      <th style="text-align: center">Shorthand</th>
      <th style="text-align: center">Exp. bias</th>
      <th>Range</th>
      <th style="text-align: center">Zero</th>
      <th style="text-align: center">NaN</th>
      <th style="text-align: center">Infinity</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>16-Bit</td>
      <td style="text-align: center"> </td>
      <td style="text-align: center"> </td>
      <td> </td>
      <td style="text-align: center"> </td>
      <td style="text-align: center"> </td>
      <td style="text-align: center"> </td>
    </tr>
    <tr>
      <td> </td>
      <td style="text-align: center">E5M10 (FP16)</td>
      <td style="text-align: center">15</td>
      <td>±65504</td>
      <td style="text-align: center">S 00000 0000000000</td>
      <td style="text-align: center">S 11111 xxxxxxxxxx</td>
      <td style="text-align: center">S 11111 0000000000</td>
    </tr>
    <tr>
      <td> </td>
      <td style="text-align: center">E8M7 (BF16)</td>
      <td style="text-align: center">127</td>
      <td>±3.3895 * 10^38</td>
      <td style="text-align: center">S 00000000 0000000</td>
      <td style="text-align: center">S 11111111 xxxxxxx</td>
      <td style="text-align: center">S 11111111 0000000</td>
    </tr>
    <tr>
      <td>8-Bit</td>
      <td style="text-align: center"> </td>
      <td style="text-align: center"> </td>
      <td> </td>
      <td style="text-align: center"> </td>
      <td style="text-align: center"> </td>
      <td style="text-align: center"> </td>
    </tr>
    <tr>
      <td> </td>
      <td style="text-align: center">E4M3FN (FP8, OCP)</td>
      <td style="text-align: center">7</td>
      <td>±448</td>
      <td style="text-align: center">S 0000 000</td>
      <td style="text-align: center">S 1111 111</td>
      <td style="text-align: center">n/a</td>
    </tr>
    <tr>
      <td> </td>
      <td style="text-align: center">E4M3FNUZ (FP8)</td>
      <td style="text-align: center">8</td>
      <td>±240</td>
      <td style="text-align: center">0 0000 000</td>
      <td style="text-align: center">1 0000 000</td>
      <td style="text-align: center">n/a</td>
    </tr>
    <tr>
      <td> </td>
      <td style="text-align: center">E5M2 (BF8, OCP)</td>
      <td style="text-align: center">15</td>
      <td>±57344</td>
      <td style="text-align: center">S 00000 00</td>
      <td style="text-align: center">S 11111 {01, 10, 11}</td>
      <td style="text-align: center">S 11111 00</td>
    </tr>
    <tr>
      <td> </td>
      <td style="text-align: center">E5M2FNUZ (BF8)</td>
      <td style="text-align: center">16</td>
      <td>±57344</td>
      <td style="text-align: center">0 00000 00</td>
      <td style="text-align: center">1 00000 00</td>
      <td style="text-align: center">n/a</td>
    </tr>
    <tr>
      <td> </td>
      <td style="text-align: center">E8M0</td>
      <td style="text-align: center">127</td>
      <td>2^(±127)</td>
      <td style="text-align: center">n/a</td>
      <td style="text-align: center">11111111</td>
      <td style="text-align: center">n/a</td>
    </tr>
    <tr>
      <td>6-Bit</td>
      <td style="text-align: center"> </td>
      <td style="text-align: center"> </td>
      <td> </td>
      <td style="text-align: center"> </td>
      <td style="text-align: center"> </td>
      <td style="text-align: center"> </td>
    </tr>
    <tr>
      <td> </td>
      <td style="text-align: center">E2M3</td>
      <td style="text-align: center">1</td>
      <td>±7.5</td>
      <td style="text-align: center">S 00 000</td>
      <td style="text-align: center">n/a</td>
      <td style="text-align: center">n/a</td>
    </tr>
    <tr>
      <td> </td>
      <td style="text-align: center">E3M2 (BF6)</td>
      <td style="text-align: center">3</td>
      <td>±28</td>
      <td style="text-align: center">S 000 00</td>
      <td style="text-align: center">n/a</td>
      <td style="text-align: center">n/a</td>
    </tr>
    <tr>
      <td>4-Bit</td>
      <td style="text-align: center"> </td>
      <td style="text-align: center"> </td>
      <td> </td>
      <td style="text-align: center"> </td>
      <td style="text-align: center"> </td>
      <td style="text-align: center"> </td>
    </tr>
    <tr>
      <td> </td>
      <td style="text-align: center">E2M1 (FP4)</td>
      <td style="text-align: center">1</td>
      <td>±6</td>
      <td style="text-align: center">S 00 0</td>
      <td style="text-align: center">n/a</td>
      <td style="text-align: center">n/a</td>
    </tr>
  </tbody>
</table>

<p>Please note that the E4M3 type has two variants: E4M3FN and E4M3FNUZ. Both E4M3FN and E4M3FNUZ use 4 bits for the exponent and 3 bits for the mantissa. They use different exponent biases and differ in the special values they can represent. Neither variant supports infinities, which is why their notations include FN (FiNite). However, E4M3FN supports <code class="language-plaintext highlighter-rouge">+0</code>, <code class="language-plaintext highlighter-rouge">-0</code>, <code class="language-plaintext highlighter-rouge">+NaN</code> and <code class="language-plaintext highlighter-rouge">-NaN</code>, while E4M3FNUZ supports only <code class="language-plaintext highlighter-rouge">+0</code> and a single <code class="language-plaintext highlighter-rouge">NaN</code>, hence <code class="language-plaintext highlighter-rouge">UZ</code> (Unsigned Zero). The image below demonstrates how to convert a binary sequence into a real value, using the E4M3FNUZ type as an example:</p>

<p><img src="/assets/matrix_cores/e4m3fnuz.png" alt="fp_cvt" width="90%" style="display:block; margin-left:auto; margin-right:auto" /></p>
<p style="text-align:center">
<em>Figure 3: E4M3FNUZ encoding details.</em>
</p>
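<p>To make the difference between the two E4M3 variants concrete, the host-side C++ sketch below decodes the same byte under both conventions. It relies only on the biases and special encodings listed in the table above; the function names are hypothetical helpers for illustration and are not part of the HIP API.</p>

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>

// Hypothetical helper: decodes the finite values of an E4M3 byte
// (1 sign bit, 4 exponent bits, 3 mantissa bits) for a given exponent bias.
static double decode_e4m3_finite(std::uint8_t bits, int bias) {
    assert(bias == 7 || bias == 8);  // only the two variants discussed here
    const int m = bits & 0x7;
    const int e = (bits >> 3) & 0xF;
    const double sign = (bits & 0x80) ? -1.0 : 1.0;
    if (e == 0) {
        return sign * std::ldexp(m / 8.0, 1 - bias);    // subnormal
    }
    return sign * std::ldexp(1.0 + m / 8.0, e - bias);  // normal
}

// E4M3FN (OCP): bias 7, NaN is S 1111 111, signed zeros exist.
double decode_e4m3fn(std::uint8_t bits) {
    if ((bits & 0x7F) == 0x7F) {
        return std::numeric_limits<double>::quiet_NaN();
    }
    return decode_e4m3_finite(bits, 7);
}

// E4M3FNUZ: bias 8, NaN is 1 0000 000, there is no negative zero.
double decode_e4m3fnuz(std::uint8_t bits) {
    if (bits == 0x80) {
        return std::numeric_limits<double>::quiet_NaN();
    }
    return decode_e4m3_finite(bits, 8);
}
```

<p>Under these assumptions, the byte <code class="language-plaintext highlighter-rouge">0x38</code> (<code class="language-plaintext highlighter-rouge">0 0111 000</code>) decodes to 1.0 as E4M3FN but to 0.5 as E4M3FNUZ, because the larger bias shifts every exponent down by one.</p>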

<p>FP8 types are divided into E4M3 and E5M2 formats. The E5M2 format is sometimes referred to as BF8 by analogy with BF16, which has a wider exponent than FP16, just as E5M2 has a wider exponent than E4M3. Similar to E4M3, E5M2 is further subdivided into two variants: E5M2 (OCP) and E5M2FNUZ. The AMD CDNA™3 architecture uses FNUZ variants for both E4M3 and E5M2, whereas the CDNA™4 architecture uses E4M3FN and E5M2 (OCP) variants. E4M3FN and E5M2 are standardized formats defined by the Open Compute Project (OCP). For detailed specifications, see the <a href="https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf">OCP Microscaling Formats (MX) Specification</a> and the <a href="https://onnx.ai/onnx/technical/float8.html">ONNX documentation</a>. For a visualization of FP8 values and their binary representations, please refer to the <a href="https://asawicki.info/articles/fp8_tables.php">FP8 Data table</a>. Additionally, see the chapter <a href="https://rocm.docs.amd.com/projects/HIP/en/latest/reference/low_fp_types.html">“Low-precision floating-point types”</a> in the AMD ROCm™ documentation for details on using low-precision types in HIP.</p>

<p>There is a special 8-bit format, E8M0, which is not used as a standard element data type but instead serves as a scale factor for microscaling types and block-scaled MFMA operations (discussed later in this article). Its value is calculated according to the equation below:</p>

<p><img src="/assets/matrix_cores/e8m0.png" alt="fp_cvt" width="55%" style="display:block; margin-left:auto; margin-right:auto" /></p>
<p style="text-align:center">
<em>Figure 4: E8M0 encoding details.</em>
</p>

<p>The exponent value <code class="language-plaintext highlighter-rouge">E = 255</code> is reserved for <code class="language-plaintext highlighter-rouge">NaN</code> values, limiting the range of representable real numbers to <code class="language-plaintext highlighter-rouge">[2^-127 ... 2^127]</code>.</p>
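<p>The E8M0 rule can be sketched in a few lines of host-side C++. This is a minimal illustration that assumes the value is <code class="language-plaintext highlighter-rouge">2^(E - 127)</code> with <code class="language-plaintext highlighter-rouge">E = 255</code> reserved for <code class="language-plaintext highlighter-rouge">NaN</code>, as stated above. Both function names are hypothetical, and <code class="language-plaintext highlighter-rouge">scale_element</code> simply multiplies an already decoded element by the scale, which is how a shared scale factor is applied in microscaling formats.</p>

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>

// Hypothetical helper: decodes an E8M0 byte (8 exponent bits, no sign, no mantissa).
double decode_e8m0(std::uint8_t bits) {
    if (bits == 0xFF) {
        return std::numeric_limits<double>::quiet_NaN();  // E = 255 is reserved for NaN
    }
    return std::ldexp(1.0, static_cast<int>(bits) - 127);  // 2^(E - 127)
}

// Hypothetical helper: applies a shared E8M0 scale to one decoded element.
// The caller is assumed to pass a valid (non-NaN) scale.
double scale_element(double element, std::uint8_t scale_bits) {
    assert(scale_bits != 0xFF);
    return element * decode_e8m0(scale_bits);
}
```

<p>With this sketch, the scale byte <code class="language-plaintext highlighter-rouge">127</code> corresponds to a factor of 1, and <code class="language-plaintext highlighter-rouge">128</code> doubles every element in the block.</p>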

<p>Similar to FP8, FP6 has two formats: E2M3 and E3M2. The latter, E3M2, is often referred to as BF6 due to its larger exponent width compared to E2M3.</p>

<h2 id="3-matrix-fused-multiply-add-mfma-instructions">3. Matrix fused-multiply-add (MFMA) Instructions</h2>

<p>The AMD CDNA™3 and CDNA™4 architectures support a variety of MFMA operations, which are characterized by the matrix dimensions <code class="language-plaintext highlighter-rouge">M</code>, <code class="language-plaintext highlighter-rouge">N</code>, <code class="language-plaintext highlighter-rouge">K</code> and the data type of input/output matrices. The following table lists all available floating-point MFMA instructions for the AMD CDNA™3 and CDNA™4 architectures. As can be seen from the table, the AMD CDNA™4 architecture extends the set of available MFMA instructions by adding new FP16/BF16 instructions with larger matrix dimensions. Furthermore, it introduces FP6/FP4 data types and provides a completely new set of FP8/FP6/FP4 instructions where the types can be independently used for the matrices <code class="language-plaintext highlighter-rouge">A</code> and <code class="language-plaintext highlighter-rouge">B</code>. Finally, the AMD CDNA™4 architecture enables MFMA with block exponent scaling.</p>

<table>
  <thead>
    <tr>
      <th style="text-align: center">Type (C,D) ← (A,B)</th>
      <th style="text-align: center">MxNxK (CDNA™3)</th>
      <th style="text-align: center">MxNxK (CDNA™4)</th>
      <th style="text-align: center">Cycles</th>
      <th>Note</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td style="text-align: center">FP64 ← FP64</td>
      <td style="text-align: center">16x16x4</td>
      <td style="text-align: center">16x16x4</td>
      <td style="text-align: center">64</td>
      <td> </td>
    </tr>
    <tr>
      <td style="text-align: center">FP32 ← FP32</td>
      <td style="text-align: center">32x32x2</td>
      <td style="text-align: center">32x32x2</td>
      <td style="text-align: center">64</td>
      <td> </td>
    </tr>
    <tr>
      <td style="text-align: center"> </td>
      <td style="text-align: center">16x16x4</td>
      <td style="text-align: center">16x16x4</td>
      <td style="text-align: center">32</td>
      <td> </td>
    </tr>
    <tr>
      <td style="text-align: center">FP32 ← FP16 (BF16)</td>
      <td style="text-align: center">32x32x8</td>
      <td style="text-align: center">32x32x8</td>
      <td style="text-align: center">32</td>
      <td>Both A and B are either FP16 or BF16</td>
    </tr>
    <tr>
      <td style="text-align: center"> </td>
      <td style="text-align: center">16x16x16</td>
      <td style="text-align: center">16x16x16</td>
      <td style="text-align: center">16</td>
      <td> </td>
    </tr>
    <tr>
      <td style="text-align: center"> </td>
      <td style="text-align: center">-</td>
      <td style="text-align: center">16x16x32</td>
      <td style="text-align: center">16</td>
      <td> </td>
    </tr>
    <tr>
      <td style="text-align: center"> </td>
      <td style="text-align: center">-</td>
      <td style="text-align: center">32x32x16</td>
      <td style="text-align: center">32</td>
      <td> </td>
    </tr>
    <tr>
      <td style="text-align: center">FP32 ← FP8</td>
      <td style="text-align: center">16x16x32</td>
      <td style="text-align: center">16x16x32</td>
      <td style="text-align: center">16</td>
      <td>FP8 (E4M3) or BF8 (E5M2) can be used independently for A and B</td>
    </tr>
    <tr>
      <td style="text-align: center"> </td>
      <td style="text-align: center">32x32x16</td>
      <td style="text-align: center">32x32x16</td>
      <td style="text-align: center">32</td>
      <td> </td>
    </tr>
    <tr>
      <td style="text-align: center">FP32 ← FP8/FP6/FP4</td>
      <td style="text-align: center">-</td>
      <td style="text-align: center">16x16x128</td>
      <td style="text-align: center">16 or 32</td>
      <td>FP4, FP6 or FP8 can be used independently for A and B. Larger cycle count if either matrix A or B is FP8</td>
    </tr>
    <tr>
      <td style="text-align: center"> </td>
      <td style="text-align: center">-</td>
      <td style="text-align: center">32x32x64</td>
      <td style="text-align: center">32 or 64</td>
      <td> </td>
    </tr>
    <tr>
      <td style="text-align: center">FP32 ← MXFP8/MXFP6/MXFP4</td>
      <td style="text-align: center">-</td>
      <td style="text-align: center">16x16x128</td>
      <td style="text-align: center">16 or 32</td>
      <td>FP4, FP6 or FP8 can be used independently for A and B. Larger cycle count if either matrix A or B is FP8</td>
    </tr>
    <tr>
      <td style="text-align: center"> </td>
      <td style="text-align: center">-</td>
      <td style="text-align: center">32x32x64</td>
      <td style="text-align: center">32 or 64</td>
      <td> </td>
    </tr>
  </tbody>
</table>

<p>Please note that the table lists only floating-point MFMA instructions with batch size = 1. In addition, the AMD CDNA™3 and CDNA™4 architectures support batched MFMA operations, where multiple output matrices are computed in parallel. These instructions are not covered in this article. See Chapter 7 “Matrix Arithmetic Instructions” in the <a href="https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/amd-instinct-mi300-cdna3-instruction-set-architecture.pdf">AMD CDNA™3</a> and <a href="https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/amd-instinct-cdna4-instruction-set-architecture.pdf">AMD CDNA™4</a> ISA reference guides for the full list of available MFMA instructions.</p>

<p>The table above specifies the cycle count for each MFMA operation. Given a known cycle count, one can estimate the theoretical peak performance in TFLOP/s of the corresponding MFMA operation using the formula below:</p>

<p><code class="language-plaintext highlighter-rouge">
2*M*N*K * num_matrix_cores * (max_engine_clock / cycle_count) / 10^6,
</code></p>

<p>where</p>
<ol>
  <li><code class="language-plaintext highlighter-rouge">num_matrix_cores</code> is the total number of Matrix Cores in a GPU (specified in the white paper)</li>
  <li><code class="language-plaintext highlighter-rouge">max_engine_clock</code> is the peak engine clock in MHz (specified in the white paper)</li>
  <li><code class="language-plaintext highlighter-rouge">cycle_count</code> is the cycle count of the corresponding MFMA instruction</li>
  <li><code class="language-plaintext highlighter-rouge">M, N, K</code> are the matrix dimensions</li>
</ol>

<p>Using this formula and the MFMA instruction <code class="language-plaintext highlighter-rouge">32x32x8 FP16</code> as an example, we can estimate the theoretical peak FP16 Matrix Core performance on the AMD Instinct™ MI325X:</p>

<p><code class="language-plaintext highlighter-rouge">2*32*32*8 * 1216 * (2100 / 32) / 10^6 = 1307.4 TFLOP/s</code>.</p>
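<p>The same estimate can be written as a small host-side C++ function. This is only a sketch of the formula above: the function name <code class="language-plaintext highlighter-rouge">mfma_peak_tflops</code> and its parameter names are introduced here for illustration, and the Matrix Core count and clock must be taken from the white paper of the GPU in question.</p>

```cpp
#include <cassert>

// Hypothetical helper: theoretical peak throughput of one MFMA shape in TFLOP/s.
// Each MFMA performs M*N*K multiply-adds, i.e. 2*M*N*K floating-point operations.
double mfma_peak_tflops(int m, int n, int k,
                        int num_matrix_cores,
                        double max_engine_clock_mhz,
                        int cycle_count) {
    assert(cycle_count > 0);
    const double flops_per_instruction = 2.0 * m * n * k;
    // Instructions per microsecond per core = clock [MHz] / cycles per instruction.
    const double instructions_per_us = max_engine_clock_mhz / cycle_count;
    // FLOP per microsecond over all cores, divided by 10^6, gives TFLOP/s.
    return flops_per_instruction * num_matrix_cores * instructions_per_us / 1e6;
}
```

<p>Calling <code class="language-plaintext highlighter-rouge">mfma_peak_tflops(32, 32, 8, 1216, 2100.0, 32)</code> reproduces the MI325X FP16 figure of about 1307.4 TFLOP/s, and <code class="language-plaintext highlighter-rouge">mfma_peak_tflops(16, 16, 32, 1216, 2100.0, 16)</code> gives about 2614.9 TFLOP/s for FP8.</p>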

<h2 id="4-compiler-intrinsics">4. Compiler Intrinsics</h2>

<p>To use Matrix Core instructions in HIP kernels, LLVM provides built-in compiler intrinsic functions. The list of all available compiler intrinsics can be found in the <a href="https://github.com/llvm/llvm-project/blob/main/clang/include/clang/Basic/BuiltinsAMDGPU.def">LLVM GitHub repository</a>. The MFMA intrinsics have the following format:</p>

<p><code class="language-plaintext highlighter-rouge">d_reg = __builtin_amdgcn_mfma_ODType_MxNxKInDType(a_reg, b_reg, c_reg, cbsz, abid, blgp)</code>,</p>

<p>where</p>
<ol>
  <li><code class="language-plaintext highlighter-rouge">MxNxK</code> specifies the shapes of the matrices <code class="language-plaintext highlighter-rouge">A</code>, <code class="language-plaintext highlighter-rouge">B</code>, <code class="language-plaintext highlighter-rouge">C</code>, <code class="language-plaintext highlighter-rouge">D</code>,</li>
  <li><code class="language-plaintext highlighter-rouge">ODType</code> is the data type of the matrices <code class="language-plaintext highlighter-rouge">C</code> and <code class="language-plaintext highlighter-rouge">D</code>,</li>
  <li><code class="language-plaintext highlighter-rouge">InDType</code> is the data type of the input matrices <code class="language-plaintext highlighter-rouge">A</code> and <code class="language-plaintext highlighter-rouge">B</code>,</li>
  <li><code class="language-plaintext highlighter-rouge">a_reg</code> is a scalar/vector containing a portion of the matrix <code class="language-plaintext highlighter-rouge">A</code>,</li>
  <li><code class="language-plaintext highlighter-rouge">b_reg</code> is a scalar/vector containing a portion of the matrix <code class="language-plaintext highlighter-rouge">B</code>,</li>
  <li><code class="language-plaintext highlighter-rouge">c_reg</code> is a vector containing a portion of the matrix <code class="language-plaintext highlighter-rouge">C</code>,</li>
  <li><code class="language-plaintext highlighter-rouge">d_reg</code> is a vector containing a portion of the matrix <code class="language-plaintext highlighter-rouge">D</code>,</li>
  <li><code class="language-plaintext highlighter-rouge">cbsz</code>, <code class="language-plaintext highlighter-rouge">abid</code>, <code class="language-plaintext highlighter-rouge">blgp</code> are broadcast flags. For the following discussion, these flags are irrelevant and are, therefore, set to 0 by default, unless specified otherwise. Please refer to the ISA reference guide for detailed information on the broadcast flags.</li>
</ol>

<p>For example,</p>

<ol>
  <li><code class="language-plaintext highlighter-rouge">__builtin_amdgcn_mfma_f32_16x16x16f16</code> performs <code class="language-plaintext highlighter-rouge">16x16x16</code> MFMA, where both input matrices <code class="language-plaintext highlighter-rouge">A</code> and <code class="language-plaintext highlighter-rouge">B</code> have type <code class="language-plaintext highlighter-rouge">FP16</code> and the output matrix has type <code class="language-plaintext highlighter-rouge">FP32</code></li>
  <li><code class="language-plaintext highlighter-rouge">__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8</code> performs <code class="language-plaintext highlighter-rouge">32x32x16</code> MFMA, where both input matrices <code class="language-plaintext highlighter-rouge">A</code> and <code class="language-plaintext highlighter-rouge">B</code> have type <code class="language-plaintext highlighter-rouge">FP8(E4M3)</code> and the output matrix is stored in <code class="language-plaintext highlighter-rouge">FP32</code></li>
  <li><code class="language-plaintext highlighter-rouge">__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8</code> performs <code class="language-plaintext highlighter-rouge">32x32x16</code> MFMA, where the matrix <code class="language-plaintext highlighter-rouge">A</code> has type <code class="language-plaintext highlighter-rouge">FP8(E4M3)</code> and the matrix <code class="language-plaintext highlighter-rouge">B</code> has type <code class="language-plaintext highlighter-rouge">BF8(E5M2)</code>.</li>
</ol>

<p>The MFMA instructions are wavefront-level (warp-level) instructions: all work-items (threads) within a wavefront collectively perform a single MFMA operation, and the operands <code class="language-plaintext highlighter-rouge">A</code>, <code class="language-plaintext highlighter-rouge">B</code>, <code class="language-plaintext highlighter-rouge">C</code>, <code class="language-plaintext highlighter-rouge">D</code> are distributed across work-items so that each work-item in the wavefront holds a portion of the operands. To use the MFMA instructions, you need to understand how the operands are distributed across threads within a wavefront. The ISA reference guide fully specifies the data layout for all available MFMA instructions. For illustrative purposes, the next chapter explains a subset of the MFMA instructions and the corresponding data layouts.</p>

<h2 id="5-examples">5. Examples</h2>

<blockquote>
  <p><strong>Important note:</strong> In the following discussion we assume the matrices are stored in row-major order. The wavefront size on the AMD CDNA™ architecture is 64. The shapes of the matrices <code class="language-plaintext highlighter-rouge">A</code>, <code class="language-plaintext highlighter-rouge">B</code>, <code class="language-plaintext highlighter-rouge">C</code>, <code class="language-plaintext highlighter-rouge">D</code> are <code class="language-plaintext highlighter-rouge">MxK</code>, <code class="language-plaintext highlighter-rouge">KxN</code>, <code class="language-plaintext highlighter-rouge">MxN</code>, and <code class="language-plaintext highlighter-rouge">MxN</code>, respectively. The first dimension denotes the number of rows and the second dimension the number of columns in a matrix. For example, the matrix <code class="language-plaintext highlighter-rouge">A</code> has <code class="language-plaintext highlighter-rouge">M</code> rows and <code class="language-plaintext highlighter-rouge">K</code> columns.</p>
</blockquote>

<h3 id="51-__builtin_amdgcn_mfma_f32_32x32x2f32">5.1. __builtin_amdgcn_mfma_f32_32x32x2f32</h3>

<p>In this example we multiply matrix <code class="language-plaintext highlighter-rouge">A</code> of size <code class="language-plaintext highlighter-rouge">32x2</code> with matrix <code class="language-plaintext highlighter-rouge">B</code> of size <code class="language-plaintext highlighter-rouge">2x32</code> using a single wavefront (64 threads) and a single MFMA instruction. The output matrix <code class="language-plaintext highlighter-rouge">C</code> has shape <code class="language-plaintext highlighter-rouge">32x32</code>. The input and output matrices are FP32. Since the threads within a wavefront collectively perform a single MFMA instruction, the operands are distributed across the threads. Each thread stores</p>

<ol>
  <li><code class="language-plaintext highlighter-rouge">M * K / wavefront_size = 32 * 2 / 64 = 1</code> entry of the matrix <code class="language-plaintext highlighter-rouge">A</code></li>
  <li><code class="language-plaintext highlighter-rouge">K * N / wavefront_size = 2 * 32 / 64 = 1</code> entry of the matrix <code class="language-plaintext highlighter-rouge">B</code></li>
  <li><code class="language-plaintext highlighter-rouge">M * N / wavefront_size = 32 * 32 / 64 = 16</code> entries of the matrix <code class="language-plaintext highlighter-rouge">C</code></li>
</ol>
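<p>The per-thread counts above follow from dividing each operand evenly across the 64 threads of a wavefront. The short host-side C++ sketch below spells out that arithmetic; the names <code class="language-plaintext highlighter-rouge">kWavefrontSize</code> and <code class="language-plaintext highlighter-rouge">elements_per_thread</code> are hypothetical and exist only for this illustration.</p>

```cpp
#include <cassert>

// Wavefront size on the AMD CDNA architecture, as stated in the note above.
const int kWavefrontSize = 64;

// Hypothetical helper: number of matrix entries each thread holds when a
// rows x cols operand is split evenly across one wavefront.
int elements_per_thread(int rows, int cols) {
    // The even split only makes sense if the element count is a multiple of 64.
    assert((rows * cols) % kWavefrontSize == 0);
    return rows * cols / kWavefrontSize;
}
```

<p>For the <code class="language-plaintext highlighter-rouge">32x32x2</code> FP32 case this gives 1 entry for <code class="language-plaintext highlighter-rouge">A</code> (<code class="language-plaintext highlighter-rouge">elements_per_thread(32, 2)</code>), 1 entry for <code class="language-plaintext highlighter-rouge">B</code> and 16 entries for <code class="language-plaintext highlighter-rouge">C</code>, which is why the kernel uses scalar registers for the inputs and a 16-element vector for the accumulator.</p>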

<p>The operands are distributed according to the scheme below. The matrix elements highlighted in light red are those stored by the thread with index <code class="language-plaintext highlighter-rouge">0</code> within the wavefront.</p>

<p><img src="/assets/matrix_cores/mfma_fp32_32x32x2_fp32.png" alt="fp_cvt" width="100%" style="display:block; margin-left:auto; margin-right:auto" /></p>
<p style="text-align:center">
<em>Figure 5: Data layout for __builtin_amdgcn_mfma_f32_32x32x2f32. The operands are stored in row-major order.</em>
</p>

<p>The code example below demonstrates how this operation can be implemented as a HIP kernel:</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="cp">#include</span> <span class="cpf">&lt;hip/hip_runtime.h&gt;</span><span class="cp">
</span>
<span class="k">using</span> <span class="n">fp32x16_t</span> <span class="o">=</span> <span class="n">__attribute__</span><span class="p">((</span><span class="n">vector_size</span><span class="p">(</span><span class="mi">16</span> <span class="o">*</span> <span class="k">sizeof</span><span class="p">(</span><span class="kt">float</span><span class="p">))))</span> <span class="kt">float</span><span class="p">;</span>

<span class="n">__global__</span> <span class="kt">void</span>
<span class="nf">mfma_fp32_32x32x2_fp32</span><span class="p">(</span><span class="k">const</span> <span class="kt">float</span><span class="o">*</span> <span class="n">A</span><span class="p">,</span> <span class="k">const</span> <span class="kt">float</span><span class="o">*</span> <span class="n">B</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">C</span><span class="p">)</span> <span class="p">{</span>
    <span class="kt">float</span> <span class="n">a_reg</span><span class="p">;</span>
    <span class="kt">float</span> <span class="n">b_reg</span><span class="p">;</span>
    <span class="n">fp32x16_t</span> <span class="n">c_reg</span> <span class="p">{};</span>

    <span class="k">const</span> <span class="kt">float</span><span class="o">*</span> <span class="n">ldg_a_ptr</span> <span class="o">=</span> <span class="n">A</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">/</span> <span class="mi">32</span> <span class="o">+</span> <span class="mi">2</span> <span class="o">*</span> <span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">%</span> <span class="mi">32</span><span class="p">);</span>
    <span class="k">const</span> <span class="kt">float</span><span class="o">*</span> <span class="n">ldg_b_ptr</span> <span class="o">=</span> <span class="n">B</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">%</span> <span class="mi">32</span> <span class="o">+</span> <span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">/</span> <span class="mi">32</span><span class="p">)</span> <span class="o">*</span> <span class="mi">32</span><span class="p">;</span>

    <span class="n">a_reg</span> <span class="o">=</span> <span class="o">*</span><span class="n">ldg_a_ptr</span><span class="p">;</span>
    <span class="n">b_reg</span> <span class="o">=</span> <span class="o">*</span><span class="n">ldg_b_ptr</span><span class="p">;</span>

    <span class="n">c_reg</span> <span class="o">=</span> <span class="n">__builtin_amdgcn_mfma_f32_32x32x2f32</span><span class="p">(</span><span class="n">a_reg</span><span class="p">,</span> <span class="n">b_reg</span><span class="p">,</span> <span class="n">c_reg</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">);</span>

    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="mi">4</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
        <span class="n">C</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">%</span> <span class="mi">32</span> <span class="o">+</span> <span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">/</span> <span class="mi">32</span><span class="p">)</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">+</span> <span class="n">i</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">*</span> <span class="mi">8</span><span class="p">]</span>          <span class="o">=</span> <span class="n">c_reg</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="mi">4</span><span class="p">];</span>
        <span class="n">C</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">%</span> <span class="mi">32</span> <span class="o">+</span> <span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">/</span> <span class="mi">32</span><span class="p">)</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">+</span> <span class="mi">32</span> <span class="o">*</span> <span class="mi">1</span> <span class="o">+</span> <span class="n">i</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">*</span> <span class="mi">8</span><span class="p">]</span> <span class="o">=</span> <span class="n">c_reg</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">+</span> <span class="mi">1</span><span class="p">];</span>
        <span class="n">C</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">%</span> <span class="mi">32</span> <span class="o">+</span> <span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">/</span> <span class="mi">32</span><span class="p">)</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">+</span> <span class="mi">32</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">+</span> <span class="n">i</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">*</span> <span class="mi">8</span><span class="p">]</span> <span class="o">=</span> <span class="n">c_reg</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">+</span> <span class="mi">2</span><span class="p">];</span>
        <span class="n">C</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">%</span> <span class="mi">32</span> <span class="o">+</span> <span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">/</span> <span class="mi">32</span><span class="p">)</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">+</span> <span class="mi">32</span> <span class="o">*</span> <span class="mi">3</span> <span class="o">+</span> <span class="n">i</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">*</span> <span class="mi">8</span><span class="p">]</span> <span class="o">=</span> <span class="n">c_reg</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">+</span> <span class="mi">3</span><span class="p">];</span>
    <span class="p">}</span>
<span class="p">}</span>
</code></pre></div></div>
<p>The GPU kernel can then be invoked on the host using a single wavefront:</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">mfma_fp32_32x32x2_fp32</span><span class="o">&lt;&lt;&lt;</span><span class="mi">1</span><span class="p">,</span> <span class="mi">64</span><span class="o">&gt;&gt;&gt;</span><span class="p">(</span><span class="n">A_device</span><span class="p">,</span> <span class="n">B_device</span><span class="p">,</span> <span class="n">C_device</span><span class="p">);</span>
</code></pre></div></div>

<p>Please note that we use the vector data type <code class="language-plaintext highlighter-rouge">fp32x16_t</code> to store the entries of the matrix <code class="language-plaintext highlighter-rouge">C</code> in registers. Additionally, we zero-initialize <code class="language-plaintext highlighter-rouge">c_reg</code>, since we compute <code class="language-plaintext highlighter-rouge">C = A * B</code> without accumulation.</p>
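
<p>The store indices in the kernel are easy to get wrong, so it can help to check them on the host before running on a GPU. The sketch below copies the index expression from the store loop of the kernel above and checks that the 64 threads together write every entry of the <code class="language-plaintext highlighter-rouge">32x32</code> matrix <code class="language-plaintext highlighter-rouge">C</code> exactly once. The function names are hypothetical, and the code only tests the index arithmetic, not the MFMA instruction itself.</p>

```cpp
#include <cassert>
#include <vector>

// Offset into row-major C (32x32) written by thread `tid` for register
// element `reg` (0..15). Mirrors the store loop of the kernel above, where
// reg = i * 4 + j with i in [0, 4) and j in [0, 4).
int c_offset_32x32(int tid, int reg) {
    assert(tid >= 0 && tid < 64 && reg >= 0 && reg < 16);
    const int i = reg / 4;
    const int j = reg % 4;
    return tid % 32 + (tid / 32) * 4 * 32 + 32 * j + i * 32 * 8;
}

// Returns true if the 64 threads x 16 registers cover each of the
// 32 * 32 entries of C exactly once.
bool c_layout_covers_matrix_once() {
    std::vector<int> hits(32 * 32, 0);
    for (int tid = 0; tid < 64; ++tid) {
        for (int reg = 0; reg < 16; ++reg) {
            ++hits[c_offset_32x32(tid, reg)];
        }
    }
    for (int h : hits) {
        if (h != 1) return false;
    }
    return true;
}
```

<p>Reading the expression as <code class="language-plaintext highlighter-rouge">row * 32 + col</code>, thread <code class="language-plaintext highlighter-rouge">tid</code> owns column <code class="language-plaintext highlighter-rouge">tid % 32</code> and the rows <code class="language-plaintext highlighter-rouge">8 * i + 4 * (tid / 32) + j</code>.</p>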

<h3 id="52-__builtin_amdgcn_mfma_f32_16x16x16f16">5.2. __builtin_amdgcn_mfma_f32_16x16x16f16</h3>

<p>This example demonstrates how to multiply matrix <code class="language-plaintext highlighter-rouge">A</code> of size <code class="language-plaintext highlighter-rouge">16x16</code> with matrix <code class="language-plaintext highlighter-rouge">B</code> of size <code class="language-plaintext highlighter-rouge">16x16</code> using a single wavefront (64 threads) and a single MFMA instruction. The output matrix <code class="language-plaintext highlighter-rouge">C</code> has shape <code class="language-plaintext highlighter-rouge">16x16</code>. The input matrices are stored in FP16 and the output matrix is stored in FP32. In this case, each thread stores <code class="language-plaintext highlighter-rouge">4</code> entries of the matrix <code class="language-plaintext highlighter-rouge">A</code>, <code class="language-plaintext highlighter-rouge">4</code> entries of the matrix <code class="language-plaintext highlighter-rouge">B</code> and <code class="language-plaintext highlighter-rouge">4</code> entries of the matrix <code class="language-plaintext highlighter-rouge">C</code>. The data layout for this instruction is shown below. For illustrative purposes, the elements stored by the first thread within the wavefront are highlighted in red.</p>

<p><img src="/assets/matrix_cores/mfma_fp32_16x16x16_fp16.png" alt="fp_cvt" width="100%" style="display:block; margin-left:auto; margin-right:auto" /></p>
<p style="text-align:center">
<em>Figure 6: Data layout for __builtin_amdgcn_mfma_f32_16x16x16f16. The operands are stored in row-major order.</em>
</p>

<p>The corresponding HIP kernel is implemented below:</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="cp">#include</span> <span class="cpf">&lt;hip/hip_runtime.h&gt;</span><span class="cp">
#include</span> <span class="cpf">&lt;hip/hip_fp16.h&gt;</span><span class="cp">
</span>
<span class="k">using</span> <span class="n">fp16_t</span> <span class="o">=</span> <span class="n">_Float16</span><span class="p">;</span>
<span class="k">using</span> <span class="n">fp16x4_t</span> <span class="o">=</span> <span class="n">__attribute__</span><span class="p">((</span><span class="n">vector_size</span><span class="p">(</span><span class="mi">4</span> <span class="o">*</span> <span class="k">sizeof</span><span class="p">(</span><span class="n">fp16_t</span><span class="p">))))</span> <span class="n">fp16_t</span><span class="p">;</span>
<span class="k">using</span> <span class="n">fp32x4_t</span> <span class="o">=</span> <span class="n">__attribute__</span><span class="p">((</span><span class="n">vector_size</span><span class="p">(</span><span class="mi">4</span> <span class="o">*</span> <span class="k">sizeof</span><span class="p">(</span><span class="kt">float</span><span class="p">))))</span> <span class="kt">float</span><span class="p">;</span>

<span class="n">__global__</span> <span class="kt">void</span>
<span class="nf">mfma_fp32_16x16x16_fp16</span><span class="p">(</span><span class="k">const</span> <span class="n">fp16_t</span><span class="o">*</span> <span class="n">A</span><span class="p">,</span> <span class="k">const</span> <span class="n">fp16_t</span><span class="o">*</span> <span class="n">B</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">C</span><span class="p">)</span> <span class="p">{</span>

    <span class="n">fp16x4_t</span> <span class="n">a_reg</span><span class="p">;</span>
    <span class="n">fp16x4_t</span> <span class="n">b_reg</span><span class="p">;</span>
    <span class="n">fp32x4_t</span> <span class="n">c_reg</span> <span class="p">{};</span>

    <span class="n">a_reg</span> <span class="o">=</span> <span class="o">*</span><span class="k">reinterpret_cast</span><span class="o">&lt;</span><span class="k">const</span> <span class="n">fp16x4_t</span><span class="o">*&gt;</span><span class="p">(</span><span class="n">A</span> <span class="o">+</span> <span class="mi">4</span> <span class="o">*</span> <span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">/</span> <span class="mi">16</span><span class="p">)</span> <span class="o">+</span> <span class="mi">16</span> <span class="o">*</span> <span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">%</span> <span class="mi">16</span><span class="p">));</span>

    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="mi">4</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
        <span class="n">b_reg</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="o">*</span><span class="p">(</span><span class="n">B</span> <span class="o">+</span> <span class="n">i</span> <span class="o">*</span> <span class="mi">16</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">%</span> <span class="mi">16</span> <span class="o">+</span> <span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">/</span> <span class="mi">16</span><span class="p">)</span> <span class="o">*</span> <span class="mi">64</span><span class="p">);</span>
    <span class="p">}</span>

    <span class="n">c_reg</span> <span class="o">=</span> <span class="n">__builtin_amdgcn_mfma_f32_16x16x16f16</span><span class="p">(</span><span class="n">a_reg</span><span class="p">,</span> <span class="n">b_reg</span><span class="p">,</span> <span class="n">c_reg</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">);</span>

    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="mi">4</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
        <span class="o">*</span><span class="p">(</span><span class="n">C</span> <span class="o">+</span> <span class="n">i</span> <span class="o">*</span> <span class="mi">16</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">%</span> <span class="mi">16</span> <span class="o">+</span> <span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">/</span> <span class="mi">16</span><span class="p">)</span> <span class="o">*</span> <span class="mi">64</span><span class="p">)</span> <span class="o">=</span> <span class="n">c_reg</span><span class="p">[</span><span class="n">i</span><span class="p">];</span>
    <span class="p">}</span>
<span class="p">}</span>
</code></pre></div></div>

<p>Please note that both the <code class="language-plaintext highlighter-rouge">__half</code> and <code class="language-plaintext highlighter-rouge">_Float16</code> types can be used in device code. However, the host supports only the <code class="language-plaintext highlighter-rouge">_Float16</code> type for arithmetic operations. As in the previous example, we use vector data types to store the matrix elements in registers.</p>
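
<p>To see how the per-thread registers fit together, the layout can be emulated on the host. The sketch below gathers <code class="language-plaintext highlighter-rouge">a_reg</code> and <code class="language-plaintext highlighter-rouge">b_reg</code> for all 64 threads with the same index expressions as the kernel above, then forms each output entry from the registers that hold the matching row of <code class="language-plaintext highlighter-rouge">A</code> and column of <code class="language-plaintext highlighter-rouge">B</code>. It assumes <code class="language-plaintext highlighter-rouge">float</code> as a stand-in for FP16, and the function name <code class="language-plaintext highlighter-rouge">emulate_mfma_16x16x16</code> is hypothetical. It reproduces only the data layout, not the hardware's rounding or accumulation order.</p>

```cpp
#include <cassert>
#include <vector>

// Host-side emulation of the register layout used by the 16x16x16 kernel.
// A and B are row-major 16x16 matrices; the result C is row-major 16x16.
std::vector<float> emulate_mfma_16x16x16(const std::vector<float>& A,
                                         const std::vector<float>& B) {
    assert(A.size() == 256 && B.size() == 256);
    float a_reg[64][4];
    float b_reg[64][4];
    // Gather the operands with the same offsets as the kernel.
    for (int tid = 0; tid < 64; ++tid) {
        for (int i = 0; i < 4; ++i) {
            a_reg[tid][i] = A[4 * (tid / 16) + 16 * (tid % 16) + i];
            b_reg[tid][i] = B[i * 16 + tid % 16 + (tid / 16) * 64];
        }
    }
    std::vector<float> C(256, 0.0f);
    // Thread `tid` owns C[row][col] with col = tid % 16, row = 4 * (tid / 16) + i.
    for (int tid = 0; tid < 64; ++tid) {
        for (int i = 0; i < 4; ++i) {
            const int row = 4 * (tid / 16) + i;
            const int col = tid % 16;
            float acc = 0.0f;
            // Index k = 4 * g + e along K is held in element e of thread
            // (row + 16 * g) for A and of thread (col + 16 * g) for B.
            for (int g = 0; g < 4; ++g) {
                for (int e = 0; e < 4; ++e) {
                    acc += a_reg[row + 16 * g][e] * b_reg[col + 16 * g][e];
                }
            }
            C[i * 16 + tid % 16 + (tid / 16) * 64] = acc;
        }
    }
    return C;
}
```

<p>A quick sanity check is to pass the identity matrix as <code class="language-plaintext highlighter-rouge">A</code>: the result should then equal <code class="language-plaintext highlighter-rouge">B</code>.</p>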

<h3 id="53-__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8">5.3. __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8</h3>

<p>In this example we will multiply matrix <code class="language-plaintext highlighter-rouge">A</code> of size <code class="language-plaintext highlighter-rouge">32x16</code> with matrix <code class="language-plaintext highlighter-rouge">B</code> of size <code class="language-plaintext highlighter-rouge">16x32</code> using a single wavefront (64 threads) and a single MFMA instruction. The output matrix <code class="language-plaintext highlighter-rouge">C</code> has shape <code class="language-plaintext highlighter-rouge">32x32</code>. The input matrices are stored in FP8 and the output matrix is stored in FP32. In this scenario, each thread stores <code class="language-plaintext highlighter-rouge">8</code> elements of the matrix <code class="language-plaintext highlighter-rouge">A</code>, <code class="language-plaintext highlighter-rouge">8</code> elements of the matrix <code class="language-plaintext highlighter-rouge">B</code> and <code class="language-plaintext highlighter-rouge">16</code> elements of the matrix <code class="language-plaintext highlighter-rouge">C</code>. The operands are distributed according to the scheme below. For illustrative purposes, the elements stored by the first thread within the wavefront are highlighted in red.</p>

<p><img src="/assets/matrix_cores/mfma_fp32_32x32x16_fp8_fp8.png" alt="fp_cvt" width="100%" style="display:block; margin-left:auto; margin-right:auto" /></p>
<p style="text-align:center">
<em>Figure 7: Data layout for __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8. The operands are stored in row-major order.</em>
</p>

<p>The code example below implements this operation as a HIP kernel:</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="cp">#include</span> <span class="cpf">&lt;hip/hip_runtime.h&gt;</span><span class="cp">
#include</span> <span class="cpf">&lt;hip/hip_fp8.h&gt;</span><span class="cp">
</span>
<span class="k">using</span> <span class="n">fp8_t</span> <span class="o">=</span> <span class="n">__hip_fp8_storage_t</span><span class="p">;</span>
<span class="k">using</span> <span class="n">fp8x8_t</span> <span class="o">=</span> <span class="n">__attribute__</span><span class="p">((</span><span class="n">vector_size</span><span class="p">(</span><span class="mi">8</span> <span class="o">*</span> <span class="k">sizeof</span><span class="p">(</span><span class="n">fp8_t</span><span class="p">))))</span> <span class="n">fp8_t</span><span class="p">;</span>
<span class="k">using</span> <span class="n">fp32x16_t</span> <span class="o">=</span> <span class="n">__attribute__</span><span class="p">((</span><span class="n">vector_size</span><span class="p">(</span><span class="mi">16</span> <span class="o">*</span> <span class="k">sizeof</span><span class="p">(</span><span class="kt">float</span><span class="p">))))</span> <span class="kt">float</span><span class="p">;</span>

<span class="n">__global__</span> <span class="kt">void</span>
<span class="nf">mfma_fp32_32x32x16_fp8_fp8</span><span class="p">(</span><span class="k">const</span> <span class="n">fp8_t</span><span class="o">*</span> <span class="n">A</span><span class="p">,</span> <span class="k">const</span> <span class="n">fp8_t</span><span class="o">*</span> <span class="n">B</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">C</span><span class="p">)</span> <span class="p">{</span>
    <span class="n">fp8x8_t</span> <span class="n">a_reg</span><span class="p">;</span>
    <span class="n">fp8x8_t</span> <span class="n">b_reg</span><span class="p">;</span>
    <span class="n">fp32x16_t</span> <span class="n">c_reg</span> <span class="p">{};</span>

    <span class="n">a_reg</span> <span class="o">=</span> <span class="o">*</span><span class="k">reinterpret_cast</span><span class="o">&lt;</span><span class="k">const</span> <span class="n">fp8x8_t</span><span class="o">*&gt;</span><span class="p">(</span><span class="n">A</span> <span class="o">+</span> <span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">/</span> <span class="mi">32</span><span class="p">)</span> <span class="o">*</span> <span class="mi">8</span> <span class="o">+</span> <span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">%</span> <span class="mi">32</span><span class="p">)</span> <span class="o">*</span> <span class="mi">16</span><span class="p">);</span>

    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="mi">8</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
        <span class="n">b_reg</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="o">*</span><span class="p">(</span><span class="n">B</span> <span class="o">+</span> <span class="n">i</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">%</span> <span class="mi">32</span> <span class="o">+</span> <span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">/</span> <span class="mi">32</span><span class="p">)</span> <span class="o">*</span> <span class="mi">8</span> <span class="o">*</span> <span class="mi">32</span><span class="p">);</span>
    <span class="p">}</span>

    <span class="n">c_reg</span> <span class="o">=</span> <span class="n">__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8</span><span class="p">((</span><span class="kt">long</span><span class="p">)</span><span class="n">a_reg</span><span class="p">,</span> <span class="p">(</span><span class="kt">long</span><span class="p">)</span><span class="n">b_reg</span><span class="p">,</span> <span class="n">c_reg</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">);</span>

    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="mi">4</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
        <span class="n">C</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">%</span> <span class="mi">32</span> <span class="o">+</span> <span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">/</span> <span class="mi">32</span><span class="p">)</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">+</span> <span class="n">i</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">*</span> <span class="mi">8</span><span class="p">]</span>          <span class="o">=</span> <span class="n">c_reg</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="mi">4</span><span class="p">];</span>
        <span class="n">C</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">%</span> <span class="mi">32</span> <span class="o">+</span> <span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">/</span> <span class="mi">32</span><span class="p">)</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">+</span> <span class="mi">32</span> <span class="o">*</span> <span class="mi">1</span> <span class="o">+</span> <span class="n">i</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">*</span> <span class="mi">8</span><span class="p">]</span> <span class="o">=</span> <span class="n">c_reg</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">+</span> <span class="mi">1</span><span class="p">];</span>
        <span class="n">C</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">%</span> <span class="mi">32</span> <span class="o">+</span> <span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">/</span> <span class="mi">32</span><span class="p">)</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">+</span> <span class="mi">32</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">+</span> <span class="n">i</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">*</span> <span class="mi">8</span><span class="p">]</span> <span class="o">=</span> <span class="n">c_reg</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">+</span> <span class="mi">2</span><span class="p">];</span>
        <span class="n">C</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">%</span> <span class="mi">32</span> <span class="o">+</span> <span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">/</span> <span class="mi">32</span><span class="p">)</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">+</span> <span class="mi">32</span> <span class="o">*</span> <span class="mi">3</span> <span class="o">+</span> <span class="n">i</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">*</span> <span class="mi">8</span><span class="p">]</span> <span class="o">=</span> <span class="n">c_reg</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">+</span> <span class="mi">3</span><span class="p">];</span>
    <span class="p">}</span>
<span class="p">}</span>
</code></pre></div></div>
<p>To define FP8, we use the <code class="language-plaintext highlighter-rouge">__hip_fp8_storage_t</code> type from <code class="language-plaintext highlighter-rouge">hip_fp8.h</code>. Note that the intrinsic function expects its first two operands to be of type <code class="language-plaintext highlighter-rouge">long</code>. To compile the code, the operands <code class="language-plaintext highlighter-rouge">a_reg</code> and <code class="language-plaintext highlighter-rouge">b_reg</code> are therefore cast to <code class="language-plaintext highlighter-rouge">long</code>.</p>
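
<p>The cast in the kernel reinterprets the eight FP8 bytes as one 64-bit value; it does not convert the numbers. This relies on the compiler's vector-extension rule that a cast between a vector and a scalar of the same size keeps the bits unchanged. The host-side sketch below shows the equivalent packing with <code class="language-plaintext highlighter-rouge">std::memcpy</code>. It assumes one byte per FP8 value and uses <code class="language-plaintext highlighter-rouge">std::uint8_t</code> as a stand-in for the storage type; the function names are hypothetical.</p>

```cpp
#include <cstdint>
#include <cstring>

// Packs eight 8-bit FP8 storage values into one 64-bit integer by copying
// the bytes unchanged (a bit reinterpretation, not a numeric conversion).
std::int64_t pack_fp8x8(const std::uint8_t (&bytes)[8]) {
    std::int64_t packed = 0;
    static_assert(sizeof(packed) == sizeof(bytes), "eight FP8 values must fill 64 bits");
    std::memcpy(&packed, bytes, sizeof(packed));
    return packed;
}

// Inverse of pack_fp8x8: recovers the eight storage bytes.
void unpack_fp8x8(std::int64_t packed, std::uint8_t (&bytes)[8]) {
    std::memcpy(bytes, &packed, sizeof(packed));
}
```

<p>Packing followed by unpacking returns the original bytes, which is a simple way to confirm that no value conversion takes place.</p>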

<h3 id="54-__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f8">5.4. __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f8</h3>

<blockquote>
  <p><strong>Important note:</strong> the MFMA instruction discussed in this example is supported only on AMD CDNA™4 GPUs (gfx950). Please make sure to install AMD ROCm™ version 7.0 or later.</p>
</blockquote>

<p>The AMD CDNA™4 architecture introduces a new family of MFMA instructions with block exponent scaling. The syntax of these instructions differs from the classic MFMA compiler intrinsics:</p>

<p><code class="language-plaintext highlighter-rouge">d_reg = __builtin_amdgcn_mfma_scale_f32_MxNxK_f8f6f4(a_reg, b_reg, c_reg, Atype, Btype, OPSEL_A, scale_a, OPSEL_B, scale_b)</code></p>

<ol>
  <li><code class="language-plaintext highlighter-rouge">MxNxK</code> specifies the shapes of the matrices <code class="language-plaintext highlighter-rouge">A</code>, <code class="language-plaintext highlighter-rouge">B</code>, <code class="language-plaintext highlighter-rouge">C</code>, <code class="language-plaintext highlighter-rouge">D</code>,</li>
  <li><code class="language-plaintext highlighter-rouge">a_reg</code> is a vector containing elements of the matrix <code class="language-plaintext highlighter-rouge">A</code>,</li>
  <li><code class="language-plaintext highlighter-rouge">b_reg</code> is a vector containing elements of the matrix <code class="language-plaintext highlighter-rouge">B</code>,</li>
  <li><code class="language-plaintext highlighter-rouge">c_reg</code> is a vector containing elements of the matrix <code class="language-plaintext highlighter-rouge">C</code>,</li>
  <li><code class="language-plaintext highlighter-rouge">d_reg</code> is a vector containing elements of the matrix <code class="language-plaintext highlighter-rouge">D</code>,</li>
  <li><code class="language-plaintext highlighter-rouge">Atype</code> is an integer that specifies the data type of the matrix <code class="language-plaintext highlighter-rouge">A</code>. The following values are possible: <code class="language-plaintext highlighter-rouge">0 = E4M3 (fp8), 1 = E5M2 (bf8), 2 = E2M3 (fp6), 3 = E3M2 (bf6), 4 = E2M1 (fp4)</code>,</li>
  <li><code class="language-plaintext highlighter-rouge">Btype</code> is an integer that specifies the data type of the matrix <code class="language-plaintext highlighter-rouge">B</code>. The following values are possible: <code class="language-plaintext highlighter-rouge">0 = E4M3 (fp8), 1 = E5M2 (bf8), 2 = E2M3 (fp6), 3 = E3M2 (bf6), 4 = E2M1 (fp4)</code>,</li>
  <li><code class="language-plaintext highlighter-rouge">OPSEL_A</code>, <code class="language-plaintext highlighter-rouge">OPSEL_B</code> are OPSEL codes. These arguments are not relevant for the discussion and therefore will be set to <code class="language-plaintext highlighter-rouge">0</code>,</li>
  <li><code class="language-plaintext highlighter-rouge">scale_a</code>, <code class="language-plaintext highlighter-rouge">scale_b</code> are scalars / vectors containing scale factors of type <code class="language-plaintext highlighter-rouge">E8M0</code>.</li>
</ol>
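
<p>An <code class="language-plaintext highlighter-rouge">E8M0</code> scale factor has eight exponent bits and no sign or mantissa bits, so it can only represent powers of two. The sketch below decodes one on the host, assuming the OCP Microscaling (MX) definition: an exponent bias of 127, with the bit pattern <code class="language-plaintext highlighter-rouge">0xFF</code> reserved for NaN. The function name <code class="language-plaintext highlighter-rouge">decode_e8m0</code> is hypothetical.</p>

```cpp
#include <cmath>
#include <cstdint>
#include <limits>

// Decodes an E8M0 scale factor to float: value = 2^(bits - 127).
// Assumes the OCP MX encoding, where 0xFF represents NaN.
float decode_e8m0(std::uint8_t bits) {
    if (bits == 0xFF) {
        return std::numeric_limits<float>::quiet_NaN();
    }
    return std::ldexp(1.0f, static_cast<int>(bits) - 127);
}
```

<p>Under this encoding, <code class="language-plaintext highlighter-rouge">127</code> decodes to a scale of <code class="language-plaintext highlighter-rouge">1.0</code>, <code class="language-plaintext highlighter-rouge">128</code> to <code class="language-plaintext highlighter-rouge">2.0</code> and <code class="language-plaintext highlighter-rouge">126</code> to <code class="language-plaintext highlighter-rouge">0.5</code>.</p>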

<p>As an example, let’s take a closer look at the instruction <code class="language-plaintext highlighter-rouge">__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4</code>. The inputs to this instruction are</p>

<ol>
  <li>Matrix <code class="language-plaintext highlighter-rouge">A</code> of size <code class="language-plaintext highlighter-rouge">32x64</code></li>
  <li>Matrix <code class="language-plaintext highlighter-rouge">Ax</code> of size <code class="language-plaintext highlighter-rouge">32x2</code></li>
  <li>Matrix <code class="language-plaintext highlighter-rouge">B</code> of size <code class="language-plaintext highlighter-rouge">64x32</code></li>
  <li>Matrix <code class="language-plaintext highlighter-rouge">Bx</code> of size <code class="language-plaintext highlighter-rouge">2x32</code></li>
</ol>

<p>The output matrix <code class="language-plaintext highlighter-rouge">C</code> has shape <code class="language-plaintext highlighter-rouge">32x32</code>. Specifically, this instruction performs the following operation using a single wavefront (64 threads):</p>

<p><img src="/assets/matrix_cores/block_scale_fp8.png" alt="fp_cvt" width="100%" style="display:block; margin-left:auto; margin-right:auto" /></p>
<p style="text-align:center">
<em>Figure 8: Block-scaled matrix multiplication via __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4.</em>
</p>

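<p>The scales <code class="language-plaintext highlighter-rouge">Ax</code> and <code class="language-plaintext highlighter-rouge">Bx</code> are applied per block of 32 elements along the <code class="language-plaintext highlighter-rouge">K</code> dimension: the dot product over each block is computed first, and the result is then multiplied by the corresponding scale factors before it is written to or accumulated into the output.</p>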
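<p>The order of operations described above can be sketched as a small host-side reference. This is only an illustration under stated assumptions: the function name <code class="language-plaintext highlighter-rouge">block_scaled_matmul_ref</code> is hypothetical, element values and scale factors are passed already decoded as <code class="language-plaintext highlighter-rouge">float</code>, and the rounding and accumulation order of the hardware are not modeled.</p>

```cpp
#include <cassert>
#include <vector>

// Hypothetical host-side reference, not an AMD API.
// Assumptions: element values and scale factors are already decoded to float,
// and one scale factor covers a block of 32 consecutive elements along K.
constexpr int kScaleBlock = 32;

// A: M x K, Ax: M x (K/32), B: K x N, Bx: (K/32) x N, C: M x N, all row-major.
std::vector<float> block_scaled_matmul_ref(const std::vector<float>& A,
                                           const std::vector<float>& Ax,
                                           const std::vector<float>& B,
                                           const std::vector<float>& Bx,
                                           const std::vector<float>& C,
                                           int M, int N, int K) {
    assert(K % kScaleBlock == 0);
    const int blocks = K / kScaleBlock;
    std::vector<float> D(C);
    for (int m = 0; m < M; ++m) {
        for (int n = 0; n < N; ++n) {
            float acc = C[m * N + n];
            for (int kb = 0; kb < blocks; ++kb) {
                // Unscaled dot product over one block of 32 elements.
                float dot = 0.0f;
                for (int k = kb * kScaleBlock; k < (kb + 1) * kScaleBlock; ++k) {
                    dot += A[m * K + k] * B[k * N + n];
                }
                // The scales are applied after the dot product, before accumulation.
                acc += dot * Ax[m * blocks + kb] * Bx[kb * N + n];
            }
            D[m * N + n] = acc;
        }
    }
    return D;
}
```

<p>For the instruction discussed here, <code class="language-plaintext highlighter-rouge">M = N = 32</code> and <code class="language-plaintext highlighter-rouge">K = 64</code>, so each row of <code class="language-plaintext highlighter-rouge">A</code> and each column of <code class="language-plaintext highlighter-rouge">B</code> carries two scale factors. A reference of this kind is mainly useful for checking kernel output on small inputs.</p>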

<p>In this example, we will multiply two FP8 matrices using the <code class="language-plaintext highlighter-rouge">__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4</code> intrinsic function. The input matrices <code class="language-plaintext highlighter-rouge">A</code>, <code class="language-plaintext highlighter-rouge">B</code> are stored in FP8 format, while the output matrix is stored in FP32. The scale matrices <code class="language-plaintext highlighter-rouge">Ax</code>, <code class="language-plaintext highlighter-rouge">Bx</code> contain elements of type <code class="language-plaintext highlighter-rouge">E8M0</code>. Each thread stores <code class="language-plaintext highlighter-rouge">32</code> entries from the matrix <code class="language-plaintext highlighter-rouge">A</code>, <code class="language-plaintext highlighter-rouge">1</code> entry from the matrix <code class="language-plaintext highlighter-rouge">Ax</code>, <code class="language-plaintext highlighter-rouge">32</code> entries from the matrix <code class="language-plaintext highlighter-rouge">B</code>, <code class="language-plaintext highlighter-rouge">1</code> entry from the matrix <code class="language-plaintext highlighter-rouge">Bx</code> and <code class="language-plaintext highlighter-rouge">16</code> entries from the matrix <code class="language-plaintext highlighter-rouge">C</code>. The operands are distributed according to the scheme below. Please note that this scheme is valid only if both input matrices <code class="language-plaintext highlighter-rouge">A</code> and <code class="language-plaintext highlighter-rouge">B</code> have FP8 type. For illustrative purposes, the matrix elements stored by the thread with <code class="language-plaintext highlighter-rouge">threadIdx.x = 0</code> are highlighted in light red, while the elements stored by the thread with <code class="language-plaintext highlighter-rouge">threadIdx.x = 32</code> within the wavefront are highlighted in light green.</p>

<p><img src="/assets/matrix_cores/mfma_scale_fp32_32x32x64_fp8_fp8.png" alt="fp_cvt" width="100%" style="display:block; margin-left:auto; margin-right:auto" /></p>
<p style="text-align:center">
<em>Figure 9: Data layout for __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4 with FP8 input matrices. The operands are stored in row-major order.</em>
</p>
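<p>The per-thread ownership shown in the figure can also be written down as index arithmetic. The host-side sketch below mirrors the address calculations of the HIP kernel in this section. The names <code class="language-plaintext highlighter-rouge">Coord</code>, <code class="language-plaintext highlighter-rouge">a_coord</code>, <code class="language-plaintext highlighter-rouge">b_coord</code>, <code class="language-plaintext highlighter-rouge">c_coord</code> and <code class="language-plaintext highlighter-rouge">covers_once</code> are hypothetical, and the mapping is assumed to hold only for FP8 inputs.</p>

```cpp
#include <cassert>
#include <vector>

// Hypothetical helpers that mirror the kernel's index arithmetic on the host.
struct Coord { int row; int col; };

// A (32x64): a lane owns one row and two 16-element chunks along K.
inline Coord a_coord(int lane, int r) {
    assert(lane >= 0 && lane < 64 && r >= 0 && r < 32);
    const int i = r / 16, j = r % 16;
    return {lane % 32, (lane / 32) * 16 + i * 32 + j};
}

// B (64x32): the same K pattern as A, but the lane selects a column.
inline Coord b_coord(int lane, int r) {
    assert(lane >= 0 && lane < 64 && r >= 0 && r < 32);
    const int i = r / 16, j = r % 16;
    return {(lane / 32) * 16 + i * 32 + j, lane % 32};
}

// C (32x32): a lane owns one column and four groups of four consecutive rows.
inline Coord c_coord(int lane, int r) {
    assert(lane >= 0 && lane < 64 && r >= 0 && r < 16);
    const int i = r / 4, j = r % 4;
    return {i * 8 + (lane / 32) * 4 + j, lane % 32};
}

// Returns true if 64 lanes with `regs` register elements each touch every
// element of a rows x cols matrix exactly once.
template <typename F>
bool covers_once(F coord, int regs, int rows, int cols) {
    std::vector<int> hits(rows * cols, 0);
    for (int lane = 0; lane < 64; ++lane) {
        for (int r = 0; r < regs; ++r) {
            const Coord c = coord(lane, r);
            ++hits[c.row * cols + c.col];
        }
    }
    for (int h : hits) {
        if (h != 1) return false;
    }
    return true;
}
```

<p>For example, <code class="language-plaintext highlighter-rouge">a_coord(0, 16)</code> gives row <code class="language-plaintext highlighter-rouge">0</code> and column <code class="language-plaintext highlighter-rouge">32</code>, while <code class="language-plaintext highlighter-rouge">a_coord(32, 0)</code> gives row <code class="language-plaintext highlighter-rouge">0</code> and column <code class="language-plaintext highlighter-rouge">16</code>. A coverage check such as <code class="language-plaintext highlighter-rouge">covers_once(a_coord, 32, 32, 64)</code> is a cheap way to validate a layout before running it on hardware.</p>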

<p>The following code example shows how this operation can be implemented as a HIP kernel:</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="cp">#include</span> <span class="cpf">&lt;hip/hip_runtime.h&gt;</span><span class="cp">
#include</span> <span class="cpf">&lt;hip/hip_ext_ocp.h&gt;</span><span class="cp">
</span>
<span class="k">using</span> <span class="n">fp8_t</span> <span class="o">=</span> <span class="n">__amd_fp8_storage_t</span><span class="p">;</span>
<span class="k">using</span> <span class="n">fp8x32_t</span> <span class="o">=</span> <span class="n">__attribute__</span><span class="p">((</span><span class="n">vector_size</span><span class="p">(</span><span class="mi">32</span> <span class="o">*</span> <span class="k">sizeof</span><span class="p">(</span><span class="n">fp8_t</span><span class="p">))))</span> <span class="n">fp8_t</span><span class="p">;</span>
<span class="k">using</span> <span class="n">fp32x16_t</span> <span class="o">=</span> <span class="n">__attribute__</span><span class="p">((</span><span class="n">vector_size</span><span class="p">(</span><span class="mi">16</span> <span class="o">*</span> <span class="k">sizeof</span><span class="p">(</span><span class="kt">float</span><span class="p">))))</span> <span class="kt">float</span><span class="p">;</span>

<span class="n">__global__</span> <span class="kt">void</span>
<span class="nf">mfma_fp32_32x32x64_fp8_fp8</span><span class="p">(</span><span class="k">const</span> <span class="n">fp8_t</span><span class="o">*</span> <span class="n">A</span><span class="p">,</span> <span class="k">const</span> <span class="n">fp8_t</span><span class="o">*</span> <span class="n">B</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">C</span><span class="p">)</span> <span class="p">{</span>
    <span class="n">fp8x32_t</span> <span class="n">a_reg</span><span class="p">;</span>
    <span class="n">fp8x32_t</span> <span class="n">b_reg</span><span class="p">;</span>
    <span class="n">fp32x16_t</span> <span class="n">c_reg</span> <span class="p">{};</span>

    <span class="k">const</span> <span class="n">fp8_t</span><span class="o">*</span> <span class="n">ldg_a</span> <span class="o">=</span> <span class="n">A</span> <span class="o">+</span> <span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">%</span> <span class="mi">32</span><span class="p">)</span> <span class="o">*</span> <span class="mi">64</span> <span class="o">+</span> <span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">/</span> <span class="mi">32</span><span class="p">)</span> <span class="o">*</span> <span class="mi">16</span><span class="p">;</span>
    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="mi">2</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
        <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">j</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="mi">16</span><span class="p">;</span> <span class="n">j</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
            <span class="n">a_reg</span><span class="p">[</span><span class="n">i</span><span class="o">*</span><span class="mi">16</span> <span class="o">+</span> <span class="n">j</span><span class="p">]</span> <span class="o">=</span> <span class="o">*</span><span class="p">(</span><span class="n">ldg_a</span> <span class="o">+</span> <span class="n">i</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">+</span> <span class="n">j</span><span class="p">);</span>
        <span class="p">}</span>
    <span class="p">}</span>

    <span class="k">const</span> <span class="n">fp8_t</span><span class="o">*</span> <span class="n">ldg_b</span> <span class="o">=</span> <span class="n">B</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">%</span> <span class="mi">32</span> <span class="o">+</span> <span class="mi">32</span> <span class="o">*</span> <span class="mi">16</span> <span class="o">*</span> <span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">/</span> <span class="mi">32</span><span class="p">);</span>

    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">i</span><span class="o">&lt;</span><span class="mi">2</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
        <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">j</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="mi">16</span><span class="p">;</span> <span class="n">j</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
            <span class="n">b_reg</span><span class="p">[</span><span class="n">i</span><span class="o">*</span><span class="mi">16</span> <span class="o">+</span> <span class="n">j</span><span class="p">]</span> <span class="o">=</span> <span class="o">*</span><span class="p">(</span><span class="n">ldg_b</span> <span class="o">+</span> <span class="mi">32</span> <span class="o">*</span> <span class="n">j</span> <span class="o">+</span> <span class="n">i</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">*</span> <span class="mi">32</span><span class="p">);</span>
        <span class="p">}</span>
    <span class="p">}</span>

    <span class="c1">// E8M0 exponent 127 encodes a scale factor of 2^(127 - 127) = 1, i.e. no scaling.</span>
    <span class="kt">uint8_t</span> <span class="n">scale_a</span> <span class="o">=</span> <span class="mi">127</span><span class="p">;</span>
    <span class="kt">uint8_t</span> <span class="n">scale_b</span> <span class="o">=</span> <span class="mi">127</span><span class="p">;</span>

    <span class="n">c_reg</span> <span class="o">=</span> <span class="n">__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4</span><span class="p">(</span><span class="n">a_reg</span><span class="p">,</span> <span class="n">b_reg</span><span class="p">,</span> <span class="n">c_reg</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">scale_a</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">scale_b</span><span class="p">);</span>

    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="mi">4</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
        <span class="n">C</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">%</span> <span class="mi">32</span> <span class="o">+</span> <span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">/</span> <span class="mi">32</span><span class="p">)</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">+</span> <span class="n">i</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">*</span> <span class="mi">8</span><span class="p">]</span>          <span class="o">=</span> <span class="n">c_reg</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="mi">4</span><span class="p">];</span>
        <span class="n">C</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">%</span> <span class="mi">32</span> <span class="o">+</span> <span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">/</span> <span class="mi">32</span><span class="p">)</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">+</span> <span class="mi">32</span> <span class="o">*</span> <span class="mi">1</span> <span class="o">+</span> <span class="n">i</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">*</span> <span class="mi">8</span><span class="p">]</span> <span class="o">=</span> <span class="n">c_reg</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">+</span> <span class="mi">1</span><span class="p">];</span>
        <span class="n">C</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">%</span> <span class="mi">32</span> <span class="o">+</span> <span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">/</span> <span class="mi">32</span><span class="p">)</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">+</span> <span class="mi">32</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">+</span> <span class="n">i</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">*</span> <span class="mi">8</span><span class="p">]</span> <span class="o">=</span> <span class="n">c_reg</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">+</span> <span class="mi">2</span><span class="p">];</span>
        <span class="n">C</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">%</span> <span class="mi">32</span> <span class="o">+</span> <span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">/</span> <span class="mi">32</span><span class="p">)</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">+</span> <span class="mi">32</span> <span class="o">*</span> <span class="mi">3</span> <span class="o">+</span> <span class="n">i</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">*</span> <span class="mi">8</span><span class="p">]</span> <span class="o">=</span> <span class="n">c_reg</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">+</span> <span class="mi">3</span><span class="p">];</span>
    <span class="p">}</span>
<span class="p">}</span>
</code></pre></div></div>
<p>Note that this example uses the <code class="language-plaintext highlighter-rouge">__amd_fp8_storage_t</code> type defined in <code class="language-plaintext highlighter-rouge">hip_ext_ocp.h</code> to represent FP8. This header provides extension APIs for low-precision and micro-scaling formats and exposes a wider capability set than <code class="language-plaintext highlighter-rouge">hip_fp8.h</code>. <code class="language-plaintext highlighter-rouge">gfx950</code> provides hardware acceleration for these APIs, and most of them map one-to-one to hardware instructions. Additionally, we use the <code class="language-plaintext highlighter-rouge">uint8_t</code> type to represent <code class="language-plaintext highlighter-rouge">E8M0</code> scale factors. Since <code class="language-plaintext highlighter-rouge">scale_a</code> and <code class="language-plaintext highlighter-rouge">scale_b</code> encode biased exponents, the actual scale factors are <code class="language-plaintext highlighter-rouge">2^(scale_a - 127)</code> and <code class="language-plaintext highlighter-rouge">2^(scale_b - 127)</code>. If <code class="language-plaintext highlighter-rouge">scale_a = scale_b = 127</code>, both scale factors are equal to <code class="language-plaintext highlighter-rouge">1</code> and no scaling is applied.</p>
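<p>The relation between an <code class="language-plaintext highlighter-rouge">E8M0</code> byte and its scale factor can be sketched with two host-side helpers. Both names, <code class="language-plaintext highlighter-rouge">e8m0_to_float</code> and <code class="language-plaintext highlighter-rouge">float_to_e8m0</code>, are hypothetical and not part of <code class="language-plaintext highlighter-rouge">hip_ext_ocp.h</code>. The NaN handling follows the OCP MX definition of <code class="language-plaintext highlighter-rouge">E8M0</code>; how the hardware treats that encoding is not covered here.</p>

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>

// Hypothetical helpers, not part of hip_ext_ocp.h.

// Decode an E8M0 byte: value = 2^(e - 127).
// The OCP MX format reserves the encoding 0xFF for NaN.
inline float e8m0_to_float(std::uint8_t e) {
    if (e == 0xFF) return std::numeric_limits<float>::quiet_NaN();
    return std::ldexp(1.0f, static_cast<int>(e) - 127);
}

// Encode an exact power of two in [2^-127, 2^127] as an E8M0 byte.
inline std::uint8_t float_to_e8m0(float x) {
    int exp = 0;
    const float mant = std::frexp(x, &exp);  // x = mant * 2^exp, mant in [0.5, 1)
    assert(mant == 0.5f);                    // only exact powers of two are representable
    const int biased = (exp - 1) + 127;
    assert(biased >= 0 && biased <= 254);
    return static_cast<std::uint8_t>(biased);
}
```

<p>With these helpers, a scale factor of <code class="language-plaintext highlighter-rouge">0.25</code> would be passed to the intrinsic as <code class="language-plaintext highlighter-rouge">float_to_e8m0(0.25f)</code>, which evaluates to <code class="language-plaintext highlighter-rouge">125</code>.</p>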

<h3 id="55-__builtin_amdgcn_mfma_scale_f32_32x32x64_f4f4">5.5. __builtin_amdgcn_mfma_scale_f32_32x32x64_f4f4</h3>

<p>In our last example, we demonstrate how to multiply two FP4 matrices using the <code class="language-plaintext highlighter-rouge">__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4</code> intrinsic function. As in the previous example, each thread stores <code class="language-plaintext highlighter-rouge">32</code> entries from the matrix <code class="language-plaintext highlighter-rouge">A</code>, <code class="language-plaintext highlighter-rouge">1</code> entry from the matrix <code class="language-plaintext highlighter-rouge">Ax</code>, <code class="language-plaintext highlighter-rouge">32</code> entries from the matrix <code class="language-plaintext highlighter-rouge">B</code>, <code class="language-plaintext highlighter-rouge">1</code> entry from the matrix <code class="language-plaintext highlighter-rouge">Bx</code> and <code class="language-plaintext highlighter-rouge">16</code> entries from the matrix <code class="language-plaintext highlighter-rouge">C</code>. The data layout for the output matrix remains the same as in the FP8 case. However, the data layout for the input matrices is different and depicted below. For illustrative purposes, the matrix elements stored by the thread with <code class="language-plaintext highlighter-rouge">threadIdx.x = 0</code> are highlighted in light red, while the elements stored by the thread with <code class="language-plaintext highlighter-rouge">threadIdx.x = 32</code> within the wavefront are highlighted in light green.</p>

<p><img src="/assets/matrix_cores/mfma_scale_fp32_32x32x64_fp4_fp4.png" alt="fp_cvt" width="100%" style="display:block; margin-left:auto; margin-right:auto" /></p>
<p style="text-align:center">
<em>Figure 10: Data layout for __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4 with FP4 input matrices. The operands are stored in row-major order.</em>
</p>

<p>The code snippet below demonstrates how to implement this operation as a HIP kernel:</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="cp">#include</span> <span class="cpf">&lt;hip/hip_runtime.h&gt;</span><span class="cp">
#include</span> <span class="cpf">&lt;hip/hip_ext_ocp.h&gt;</span><span class="cp">
</span>
<span class="k">using</span> <span class="n">fp4x2_t</span> <span class="o">=</span> <span class="n">__amd_fp4x2_storage_t</span><span class="p">;</span>
<span class="k">using</span> <span class="n">fp4x64_t</span>  <span class="o">=</span> <span class="n">fp4x2_t</span> <span class="nf">__attribute__</span><span class="p">((</span><span class="n">ext_vector_type</span><span class="p">(</span><span class="mi">32</span><span class="p">)));</span>
<span class="k">using</span> <span class="n">fp32x16_t</span> <span class="o">=</span> <span class="n">__attribute__</span><span class="p">((</span><span class="n">vector_size</span><span class="p">(</span><span class="mi">16</span> <span class="o">*</span> <span class="k">sizeof</span><span class="p">(</span><span class="kt">float</span><span class="p">))))</span> <span class="kt">float</span><span class="p">;</span>

<span class="n">__global__</span> <span class="kt">void</span>
<span class="nf">mfma_fp32_32x32x64_fp4_fp4</span><span class="p">(</span><span class="k">const</span> <span class="n">fp4x2_t</span><span class="o">*</span> <span class="n">A</span><span class="p">,</span> <span class="k">const</span> <span class="n">fp4x2_t</span><span class="o">*</span> <span class="n">B</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">C</span><span class="p">)</span> <span class="p">{</span>

    <span class="n">fp4x64_t</span> <span class="n">a_reg</span> <span class="p">{};</span>
    <span class="n">fp4x64_t</span> <span class="n">b_reg</span> <span class="p">{};</span>
    <span class="n">fp32x16_t</span> <span class="n">c_reg</span> <span class="p">{};</span>

    <span class="k">const</span> <span class="n">fp4x2_t</span><span class="o">*</span> <span class="n">ldg_a</span> <span class="o">=</span> <span class="n">A</span> <span class="o">+</span> <span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">%</span> <span class="mi">32</span><span class="p">)</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">+</span> <span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">/</span> <span class="mi">32</span><span class="p">)</span> <span class="o">*</span> <span class="mi">16</span><span class="p">;</span>

    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="mi">16</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
        <span class="n">a_reg</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="o">*</span><span class="p">(</span><span class="n">ldg_a</span> <span class="o">+</span> <span class="n">i</span><span class="p">);</span>
    <span class="p">}</span>

    <span class="c1">// Two FP4 values share one byte: select the byte that holds this lane's column of B, then the 4-bit half within it.</span>
    <span class="k">const</span> <span class="n">fp4x2_t</span><span class="o">*</span> <span class="n">ldg_b</span> <span class="o">=</span> <span class="n">B</span> <span class="o">+</span> <span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">%</span> <span class="mi">32</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span> <span class="o">+</span> <span class="mi">16</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">*</span> <span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">/</span> <span class="mi">32</span><span class="p">);</span>
    <span class="kt">int</span> <span class="n">b_extract_idx</span> <span class="o">=</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">%</span> <span class="mi">2</span><span class="p">;</span>

    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="mi">16</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
        <span class="kt">uint8_t</span> <span class="n">tmp0</span> <span class="o">=</span> <span class="n">__amd_extract_fp4</span><span class="p">(</span><span class="o">*</span><span class="p">(</span><span class="n">ldg_b</span> <span class="o">+</span> <span class="mi">16</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">i</span><span class="p">),</span> <span class="n">b_extract_idx</span><span class="p">);</span>
        <span class="kt">uint8_t</span> <span class="n">tmp1</span> <span class="o">=</span> <span class="n">__amd_extract_fp4</span><span class="p">(</span><span class="o">*</span><span class="p">(</span><span class="n">ldg_b</span> <span class="o">+</span> <span class="mi">16</span> <span class="o">*</span> <span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)),</span> <span class="n">b_extract_idx</span><span class="p">);</span>
        <span class="n">b_reg</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">__amd_create_fp4x2</span><span class="p">(</span><span class="n">tmp0</span><span class="p">,</span> <span class="n">tmp1</span><span class="p">);</span>
    <span class="p">}</span>

    <span class="kt">uint8_t</span> <span class="n">scale_a</span> <span class="o">=</span> <span class="mi">127</span><span class="p">;</span>
    <span class="kt">uint8_t</span> <span class="n">scale_b</span> <span class="o">=</span> <span class="mi">127</span><span class="p">;</span>

    <span class="n">c_reg</span> <span class="o">=</span> <span class="n">__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4</span><span class="p">(</span><span class="n">a_reg</span><span class="p">,</span> <span class="n">b_reg</span><span class="p">,</span> <span class="n">c_reg</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">scale_a</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">scale_b</span><span class="p">);</span>

    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="mi">4</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
        <span class="n">C</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">%</span> <span class="mi">32</span> <span class="o">+</span> <span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">/</span> <span class="mi">32</span><span class="p">)</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">+</span> <span class="n">i</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">*</span> <span class="mi">8</span><span class="p">]</span>          <span class="o">=</span> <span class="n">c_reg</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="mi">4</span><span class="p">];</span>
        <span class="n">C</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">%</span> <span class="mi">32</span> <span class="o">+</span> <span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">/</span> <span class="mi">32</span><span class="p">)</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">+</span> <span class="mi">32</span> <span class="o">*</span> <span class="mi">1</span> <span class="o">+</span> <span class="n">i</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">*</span> <span class="mi">8</span><span class="p">]</span> <span class="o">=</span> <span class="n">c_reg</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">+</span> <span class="mi">1</span><span class="p">];</span>
        <span class="n">C</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">%</span> <span class="mi">32</span> <span class="o">+</span> <span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">/</span> <span class="mi">32</span><span class="p">)</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">+</span> <span class="mi">32</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">+</span> <span class="n">i</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">*</span> <span class="mi">8</span><span class="p">]</span> <span class="o">=</span> <span class="n">c_reg</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">+</span> <span class="mi">2</span><span class="p">];</span>
        <span class="n">C</span><span class="p">[</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">%</span> <span class="mi">32</span> <span class="o">+</span> <span class="p">(</span><span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">/</span> <span class="mi">32</span><span class="p">)</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">+</span> <span class="mi">32</span> <span class="o">*</span> <span class="mi">3</span> <span class="o">+</span> <span class="n">i</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">*</span> <span class="mi">8</span><span class="p">]</span> <span class="o">=</span> <span class="n">c_reg</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">+</span> <span class="mi">3</span><span class="p">];</span>
    <span class="p">}</span>
<span class="p">}</span>
</code></pre></div></div>

<p>Since memory cannot be addressed at a granularity smaller than 8 bits, we use <code class="language-plaintext highlighter-rouge">__amd_fp4x2_storage_t</code> (an alias for <code class="language-plaintext highlighter-rouge">uint8_t</code>) to store the input matrices and enable pointer arithmetic. Note that the FP4 elements that need to be loaded from the matrix <code class="language-plaintext highlighter-rouge">B</code> are not contiguous in memory. To extract a single FP4 element, we use the <code class="language-plaintext highlighter-rouge">__amd_extract_fp4</code> function provided in <code class="language-plaintext highlighter-rouge">hip_ext_ocp.h</code>. This function returns one FP4 element (as a <code class="language-plaintext highlighter-rouge">uint8_t</code>) from an fp4x2 value, selected by the index passed as the second argument:</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">uint8_t</span> <span class="nf">__amd_extract_fp4</span><span class="p">(</span><span class="k">const</span> <span class="n">__amd_fp4x2_storage_t</span> <span class="n">x</span><span class="p">,</span> <span class="k">const</span> <span class="kt">size_t</span> <span class="n">index</span><span class="p">)</span> <span class="p">{</span>
    <span class="k">if</span> <span class="p">(</span><span class="n">index</span> <span class="o">==</span> <span class="mi">0</span><span class="p">)</span> <span class="k">return</span> <span class="p">(</span><span class="n">x</span> <span class="o">&amp;</span> <span class="mh">0xFu</span><span class="p">);</span>
    <span class="k">return</span> <span class="p">(</span><span class="n">x</span> <span class="o">&gt;&gt;</span> <span class="mi">4</span><span class="p">);</span>
<span class="p">}</span>
</code></pre></div></div>

<p>Two FP4 values are then combined into <code class="language-plaintext highlighter-rouge">__amd_fp4x2_storage_t</code> using <code class="language-plaintext highlighter-rouge">__amd_create_fp4x2</code>:</p>
<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">__amd_fp4x2_storage_t</span> <span class="nf">__amd_create_fp4x2</span><span class="p">(</span><span class="k">const</span> <span class="kt">uint8_t</span> <span class="n">x</span><span class="p">,</span> <span class="k">const</span> <span class="kt">uint8_t</span> <span class="n">y</span><span class="p">)</span> <span class="p">{</span>
    <span class="n">__amd_fp4x2_storage_t</span> <span class="n">ret</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
    <span class="n">ret</span> <span class="o">=</span> <span class="n">x</span> <span class="o">|</span> <span class="p">(</span><span class="n">y</span> <span class="o">&lt;&lt;</span> <span class="mi">4</span><span class="p">);</span>
    <span class="k">return</span> <span class="n">ret</span><span class="p">;</span>
<span class="p">}</span>
</code></pre></div></div>
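
<p>To check the packing convention without a GPU, the two helpers can be mirrored on the host. The sketch below is a minimal stand-in, not the <code class="language-plaintext highlighter-rouge">hip_ext_ocp.h</code> implementation: <code class="language-plaintext highlighter-rouge">extract_fp4</code> and <code class="language-plaintext highlighter-rouge">create_fp4x2</code> are hypothetical names that copy the logic shown above, with a plain <code class="language-plaintext highlighter-rouge">uint8_t</code> standing in for <code class="language-plaintext highlighter-rouge">__amd_fp4x2_storage_t</code>. The input masking and the index check are additions of this sketch.</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code>#include &lt;cassert&gt;
#include &lt;cstddef&gt;
#include &lt;cstdint&gt;

// Hypothetical host-side stand-in for __amd_extract_fp4.
// Index 0 selects the low nibble, index 1 the high nibble.
inline std::uint8_t extract_fp4(std::uint8_t packed, std::size_t index) {
    assert(index &lt; 2);  // a packed byte holds exactly two FP4 values
    if (index == 0) return static_cast&lt;std::uint8_t&gt;(packed &amp; 0xFu);
    return static_cast&lt;std::uint8_t&gt;(packed &gt;&gt; 4);
}

// Hypothetical host-side stand-in for __amd_create_fp4x2.
// x goes to the low nibble, y to the high nibble.
// Masking both inputs to 4 bits is an addition of this sketch.
inline std::uint8_t create_fp4x2(std::uint8_t x, std::uint8_t y) {
    return static_cast&lt;std::uint8_t&gt;((x &amp; 0xFu) | ((y &amp; 0xFu) &lt;&lt; 4));
}
</code></pre></div></div>

<p>Packing the 4-bit patterns <code class="language-plaintext highlighter-rouge">0x3</code> and <code class="language-plaintext highlighter-rouge">0xA</code> yields the byte <code class="language-plaintext highlighter-rouge">0xA3</code>, and extracting index 0 and index 1 from that byte returns the two inputs again.</p>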

<p>The compiler intrinsic function <code class="language-plaintext highlighter-rouge">__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4</code> requires its first two arguments to be 256 bits wide. Since 32 FP4 elements occupy only 128 bits, we define <code class="language-plaintext highlighter-rouge">fp4x64_t</code>, which is 256 bits wide. In this type, 128 bits contain data, while the remaining 128 bits are zero. This allows us to pass <code class="language-plaintext highlighter-rouge">a_reg</code> and <code class="language-plaintext highlighter-rouge">b_reg</code> to the intrinsic function and compile the code successfully.</p>
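
<p>The zero-padding idea can be modeled on the host as follows. This is only a sketch under stated assumptions: <code class="language-plaintext highlighter-rouge">Fp4x64Bytes</code> and <code class="language-plaintext highlighter-rouge">pad_fp4x32</code> are hypothetical names, the 256-bit operand is modeled as a 32-byte array rather than the HIP vector type, and placing the data in the first 16 bytes is an assumption of this sketch.</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code>#include &lt;array&gt;
#include &lt;cassert&gt;
#include &lt;cstdint&gt;
#include &lt;cstring&gt;

// Hypothetical host-side model of a 256-bit operand (32 bytes).
using Fp4x64Bytes = std::array&lt;std::uint8_t, 32&gt;;

// Copies 32 packed FP4 values (16 bytes = 128 bits) into a zeroed 256-bit container.
// Assumption of this sketch: the data occupies the first 16 bytes.
inline Fp4x64Bytes pad_fp4x32(const std::uint8_t* packed) {
    assert(packed != nullptr);
    Fp4x64Bytes reg{};                    // value-initialized: all 32 bytes are zero
    std::memcpy(reg.data(), packed, 16);  // 128 bits of payload
    return reg;                           // the remaining 128 bits stay zero
}
</code></pre></div></div>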

<h2 id="summary">Summary</h2>

<p>In this article, we introduced Matrix Core instructions available on the AMD CDNA™3 and CDNA™4 architectures. We covered floating-point formats in detail, including modern low-precision element data types such as FP8, FP6, FP4, and the scale data type E8M0. We further explained how the floating-point types are represented as binary sequences and demonstrated, with concrete examples, how to convert their binary representations into real values. Next, we listed Matrix Core instructions supported by the modern CDNA™ architectures and discussed how to calculate the theoretical peak performance of Matrix Cores for specific MFMA instructions. To make the discussion more practical, we reviewed the compiler intrinsic functions that allow users to program Matrix Cores inside HIP kernels. Finally, we examined a subset of MFMA instructions in detail, providing code examples and illustrations to explain data layout and demonstrate how to implement simple mixed-precision MFMA operations in HIP. For additional information on Matrix Cores and low-precision data types, please refer to the following resources:</p>

<ol>
  <li><a href="https://rocm.blogs.amd.com/software-tools-optimization/matrix-cores/README.html">Matrix Core Programming on CDNA2 - ROCm Blogs</a></li>
  <li><a href="https://gpuopen.com/learn/using_matrix_core_amd_rdna4">Using the Matrix Cores of AMD RDNA 4 architecture GPUs - GPUOpen Blogs</a></li>
  <li><a href="https://github.com/ROCm/amd_matrix_instruction_calculator">AMD Matrix Instruction Calculator</a></li>
  <li><a href="https://rocm.docs.amd.com/projects/HIP/en/latest/reference/low_fp_types.html">Low-Precision Floating Point Types - ROCm documentation</a></li>
</ol>]]></content><author><name>Amanzhol Salykov, Andy Luo, Carlus Huang, Peng Sun</name></author><summary type="html"><![CDATA[In this blog post, we walk through how to use Matrix Cores in HIP kernels, with a focus on low-precision data types such as FP16, FP8, and FP4, as well as the new family of Matrix Core instructions with exponent block scaling introduced in the AMD CDNA™4 architecture. Through code examples and illustrations, we provide the necessary knowledge to start programming Matrix Cores, covering modern low-precision floating-point types, the Matrix Core compiler intrinsics, and the data layouts required by the Matrix Core instructions.]]></summary></entry><entry><title type="html">Advanced Matrix Multiplication Optimization on NVIDIA GPUs</title><link href="https://salykova.github.io/gemm-gpu" rel="alternate" type="text/html" title="Advanced Matrix Multiplication Optimization on NVIDIA GPUs" /><published>2025-01-12T09:35:01+00:00</published><updated>2025-01-12T09:35:01+00:00</updated><id>https://salykova.github.io/gemm-gpu</id><content type="html" xml:base="https://salykova.github.io/gemm-gpu"><![CDATA[<blockquote>
  <p>This project is inspired by the outstanding works of Andrej Karpathy, George Hotz, Scott Gray, Horace He, Philippe Tillet, Jeremy Howard, Lei Mao and the best CUDA hackers from the <a href="https://github.com/gpu-mode">GPU MODE</a> community (<a href="https://discord.gg/gpumode">Discord server</a>). A special thanks to Mark Saroufim and Andreas Köpf for running GPU MODE and all you’ve done for the community.</p>
</blockquote>

<p>The code is available at <a href="https://github.com/salykova/sgemm.cu">sgemm.cu</a>. This article complements my <a href="https://salykova.github.io/gemm-cpu">blog post</a>, which covers the implementation of FP32 matrix multiplication that outperforms BLAS libraries on modern Intel and AMD CPUs. Today we’ll walk through a GPU implementation of SGEMM (Single-precision GEneral Matrix Multiply) operation defined as <code class="language-plaintext highlighter-rouge">C := alpha*A*B + beta*C</code>. The blog delves into benchmarking code on CUDA devices and explains the algorithm’s design along with optimization techniques. These include inlined PTX, asynchronous memory copies, double-buffering, avoiding shared memory bank conflicts, and efficient coalesced storage through shared memory. I’d also like to mention that the high-level algorithm design used in this project was developed by the excellent engineers at NVIDIA and has been extensively studied in prior works on cuBLAS and CUTLASS. My main contribution was translating it into efficient CUDA/PTX code. The goal of this project wasn’t to build an SGEMM that would magically outperform cuBLAS on all GPUs and all matrix sizes. This is especially pointless, given the open-sourced, lightweight CUTLASS library. Instead, the project primarily targets CUDA learners and aims to bridge the gap between the SGEMM implementations explained in books/blogs and those used in NVIDIA’s BLAS libraries. While the implementation is expected to deliver high performance on Ada/Ampere/Volta/Turing devices, it was specifically fine-tuned for and tested on a local NVIDIA RTX 3090 (=GA102 chip: RTX 3080, A10, A40, A6000). The achieved performance is shown below, comparing results with locked and unlocked GPU core frequencies against cuBLAS and Simon Boehm’s highly cited work (used in llamafile, aka tinyBLAS). I plan to continue publishing educational content on high-performance kernels used in AI/ML. 
Let me know what topics you’d like to see next! Projects currently in development: beating NVIDIA on Tensor Cores, Stream-K GEMM, FlashAttention, xLSTM. If you enjoy educational content like this and would like to see more, please share this article. Your feedback would be greatly appreciated!</p>

<p><strong>P.S. Please feel free to get in touch if you are interested in collaborating. My contact information is available on the homepage.</strong></p>

<p><br />
<img src="/assets/matmul_gpu/unlocked_perf.png" alt="unlocked_perf" width="100%" style="display:block; margin-left:auto; margin-right:auto" />
<br /><br />
<img src="/assets/matmul_gpu/locked_perf.png" alt="locked_perf" width="100%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<h2 id="1-introduction">1. Introduction</h2>

<p>I clearly remember Andrej’s post on the current state of the existing CUDA learning materials vs. the CUDA code used in high-performance libraries:</p>

<p><img src="/assets/matmul_gpu/ak_post.png" alt="ak_post" width="65%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>Indeed, when it comes to SGEMM implementations, there are some excellent educational blog posts, such as</p>

<ol>
  <li><a href="https://siboehm.com/articles/22/CUDA-MMM">How to Optimize a CUDA Matmul Kernel for cuBLAS-like Performance</a> (mentioned by Andrej)</li>
  <li><a href="https://leimao.github.io/article/CUDA-Matrix-Multiplication-Optimization/">CUDA Matrix Multiplication Optimization</a></li>
</ol>

<p>that break down, step by step, how to optimize a CUDA matmul kernel. However, in terms of achieved performance, neither comes close to matching the speed of cuBLAS or CUTLASS, especially with recent CUDA versions and when benchmarked properly. In my experiments, these implementations achieve 50-70% of cuBLAS’ performance at best. Additionally, I found the explanations in both blog posts a bit overcomplicated in the final optimization steps. Nevertheless, I still think these resources are great for anyone starting with CUDA programming since they provide good foundational knowledge.</p>

<p>On the other hand, I’ve seen some really fast SGEMM implementations with cuBLAS-level performance:</p>

<ol>
  <li><a href="https://github.com/Yinghan-Li/YHs_Sample/tree/master/cuda/gemm">YHs GEMM</a></li>
  <li><a href="https://github.com/tpoisonooo/how-to-optimize-gemm/tree/master">how-to-optimize-gemm</a></li>
</ol>

<p>The problem is that they are undocumented, difficult to find and understand, especially for a CUDA beginner. A similar problem exists with CUTLASS. While it is highly performant, there is a lack of introductory or educational materials explaining how it is internally designed and implemented in efficient CUDA/PTX. Another notable project is <a href="https://github.com/NervanaSystems/maxas">MaxAs</a>, an assembler for the Maxwell architecture developed over a decade ago by Scott Gray. This tool enables programming directly in SASS (the assembly language for NVIDIA GPUs), allowing direct communication with the hardware instead of relying on the hardware-agnostic CUDA/PTX. Using MaxAs, Scott wrote an SGEMM implementation that achieved around 98% of the GM204 chip’s theoretical maximum FLOPS, surpassing cuBLAS by an average of 5%. While the results are impressive, programming in SASS is inflexible and requires deep understanding of the underlying hardware. Furthermore, with significant advancements in the compiler since then, programming directly in SASS is only advantageous in exceptional cases (for example, if you build <a href="https://github.com/tinygrad/tinygrad">tinygrad</a>). CUTLASS achieves performance on par with cuBLAS across various GPU architectures and matrix sizes using only CUDA/PTX code.</p>

<p>But can we actually exceed the cuBLAS barrier? In the following chapters, we will briefly review the high-level SGEMM design used in CUTLASS, and discuss how to translate this design into efficient CUDA/PTX. This guide assumes only a basic knowledge of the CUDA programming model and linear algebra. If you are new to CUDA programming, I strongly recommend starting with these short introductory articles:</p>

<ol>
  <li><a href="https://developer.nvidia.com/blog/easy-introduction-cuda-c-and-c/">An Easy Introduction to CUDA C and C++</a></li>
  <li><a href="https://developer.nvidia.com/blog/how-access-global-memory-efficiently-cuda-c-kernels/">How to Access Global Memory Efficiently in CUDA C/C++ Kernels</a></li>
  <li><a href="https://developer.nvidia.com/blog/using-shared-memory-cuda-cc/">Using Shared Memory in CUDA C/C++</a></li>
  <li><a href="https://developer.nvidia.com/blog/cuda-pro-tip-increase-performance-with-vectorized-memory-access/">Increase Performance with Vectorized Memory Access</a></li>
</ol>

<p>Before we proceed with implementation, let’s talk about benchmarking code on NVIDIA GPUs - a topic often overlooked. Properly benchmarking code is just as important as the code itself, particularly when comparing different implementations.</p>

<h2 id="2-how-to-benchmark-code-on-cuda-devices">2. How to Benchmark Code on CUDA Devices?</h2>

<p>The most reliable way to measure kernel duration is by profiling with NVIDIA Nsight Compute and manually extracting performance data. To obtain deterministic and reproducible results, Nsight Compute automatically applies the following settings:</p>

<ol>
  <li>Clock Control: locks GPU clock frequencies to their base values</li>
  <li>Cache Control: flushes all GPU caches before each replay pass</li>
  <li>Persistence mode</li>
</ol>

<p>Alternatively, you can apply these settings manually and measure kernel duration at runtime without relying on external profilers. On Ubuntu, you can retrieve the base core clock frequency using:</p>

<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code>nvidia-smi base-clocks
</code></pre></div></div>

<p>For instance, on an RTX 3090, the base core clock frequency is 1395 MHz. Next, you’ll need the memory clock frequencies, which work in combination with the base core clock:</p>

<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code>nvidia-smi <span class="nt">-q</span> <span class="nt">-d</span> supported_clocks
</code></pre></div></div>

<p>From the list of supported frequencies, choose the fastest memory clock compatible with the base core frequency. Memory clock speeds are generally more stable than core clock speeds. To lock the clock frequencies and enable persistence mode, run the following commands:</p>

<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="nb">sudo </span>nvidia-smi <span class="nt">--persistence-mode</span><span class="o">=</span>1
<span class="c"># NVIDIA RTX 3090</span>
<span class="nb">sudo </span>nvidia-smi <span class="nt">--lock-gpu-clocks</span><span class="o">=</span>1395
<span class="nb">sudo </span>nvidia-smi <span class="nt">--lock-memory-clocks</span><span class="o">=</span>9501
</code></pre></div></div>

<p>To reset the core and memory clock frequencies, you can use:</p>

<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="nb">sudo </span>nvidia-smi <span class="nt">--reset-gpu-clocks</span>
<span class="nb">sudo </span>nvidia-smi <span class="nt">--reset-memory-clocks</span>
<span class="nb">sudo </span>nvidia-smi <span class="nt">--persistence-mode</span><span class="o">=</span>0
</code></pre></div></div>

<p>GPU clock frequencies may drop due to the GPU’s thermal state, but for high-performance applications, throttling is often caused by power limits. Faulty hardware can also lead to throttling. It’s a good idea to monitor the GPU’s state at least during a test run. Use the following command to keep track of power draw, clock speeds, and throttling reasons in real time:</p>

<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code>watch <span class="nt">-n</span> 0.1 nvidia-smi <span class="nt">--query-gpu</span><span class="o">=</span>power.draw,clocks.sm,clocks.mem,clocks_throttle_reasons.active <span class="nt">--format</span><span class="o">=</span>csv
</code></pre></div></div>

<p>A sample output might look like this:</p>

<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code>308.50 W, 1395 MHz, 9501 MHz, 0x0000000000000000
</code></pre></div></div>

<p>The bit mask <code class="language-plaintext highlighter-rouge">0x0000000000000000</code> indicates no throttling, and the clocks are running at their maximum speeds. A value of <code class="language-plaintext highlighter-rouge">0x0000000000000001</code> indicates an idle state. Any other values suggest throttling is occurring. For a full list of bit mask values and their meanings, refer to the <a href="https://docs.nvidia.com/deploy/nvml-api/group__nvmlClocksThrottleReasons.html">NvmlClocksThrottleReasons documentation</a>.</p>
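
<p>If you log the CSV output during a benchmark run, the mask can be classified automatically. The sketch below follows only the three cases described above and does not decode individual throttle reasons; <code class="language-plaintext highlighter-rouge">ClockState</code>, <code class="language-plaintext highlighter-rouge">parse_mask</code> and <code class="language-plaintext highlighter-rouge">classify</code> are hypothetical names introduced for illustration.</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code>#include &lt;cassert&gt;
#include &lt;cstdint&gt;
#include &lt;string&gt;

// Hypothetical three-way classification that mirrors the text above.
enum class ClockState { Running, Idle, Throttled };

// Parses the hexadecimal mask field of the CSV output, e.g. "0x0000000000000000".
// Base 16 accepts the "0x" prefix and skips leading whitespace.
inline std::uint64_t parse_mask(const std::string&amp; text) {
    assert(!text.empty());
    return std::stoull(text, nullptr, 16);
}

inline ClockState classify(std::uint64_t mask) {
    if (mask == 0x0) return ClockState::Running;  // no throttle reason is set
    if (mask == 0x1) return ClockState::Idle;     // GPU is idle
    return ClockState::Throttled;                 // any other value: look up the bits
}
</code></pre></div></div>

<p>A benchmark script could, for example, discard a measurement series as soon as <code class="language-plaintext highlighter-rouge">classify</code> returns <code class="language-plaintext highlighter-rouge">ClockState::Throttled</code> for one of the sampled lines.</p>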

<p>Once you’ve locked the clock frequencies, you can measure the kernel duration directly in CUDA using CUDA events. Here’s an example:</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">cudaEvent_t</span> <span class="n">start</span><span class="p">,</span> <span class="n">stop</span><span class="p">;</span>
<span class="n">cudaEventCreate</span><span class="p">(</span><span class="o">&amp;</span><span class="n">start</span><span class="p">);</span> <span class="n">cudaEventCreate</span><span class="p">(</span><span class="o">&amp;</span><span class="n">stop</span><span class="p">);</span>
<span class="kt">float</span> <span class="n">elapsed_time_ms</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">;</span>

<span class="n">cudaEventRecord</span><span class="p">(</span><span class="n">start</span><span class="p">);</span>
<span class="n">kernel</span><span class="o">&lt;&lt;&lt;</span><span class="p">...</span><span class="o">&gt;&gt;&gt;</span><span class="p">(...);</span>
<span class="n">cudaEventRecord</span><span class="p">(</span><span class="n">stop</span><span class="p">);</span>

<span class="n">cudaEventSynchronize</span><span class="p">(</span><span class="n">stop</span><span class="p">);</span>
<span class="n">cudaEventElapsedTime</span><span class="p">(</span><span class="o">&amp;</span><span class="n">elapsed_time_ms</span><span class="p">,</span> <span class="n">start</span><span class="p">,</span> <span class="n">stop</span><span class="p">);</span>
</code></pre></div></div>

<p>For reliable measurements, multiple replay passes are typically used. In such cases, the GPU cache should be flushed before each kernel replay. This can be done using <code class="language-plaintext highlighter-rouge">cudaMemsetAsync</code> as shown in <a href="https://github.com/NVIDIA/nvbench/blob/main/nvbench/detail/l2flush.cuh">nvbench</a>:</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1">// Flush L2 cache</span>
<span class="kt">int</span> <span class="n">dev_id</span><span class="p">{};</span>
<span class="kt">int</span> <span class="n">m_l2_size</span><span class="p">{};</span>
<span class="kt">void</span><span class="o">*</span> <span class="n">buffer</span><span class="p">;</span>
<span class="n">checkCudaErrors</span><span class="p">(</span><span class="n">cudaGetDevice</span><span class="p">(</span><span class="o">&amp;</span><span class="n">dev_id</span><span class="p">));</span>
<span class="n">checkCudaErrors</span><span class="p">(</span><span class="n">cudaDeviceGetAttribute</span><span class="p">(</span><span class="o">&amp;</span><span class="n">m_l2_size</span><span class="p">,</span> <span class="n">cudaDevAttrL2CacheSize</span><span class="p">,</span> <span class="n">dev_id</span><span class="p">));</span>
<span class="k">if</span> <span class="p">(</span><span class="n">m_l2_size</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">)</span> <span class="p">{</span>
    <span class="n">checkCudaErrors</span><span class="p">(</span><span class="n">cudaMalloc</span><span class="p">(</span><span class="o">&amp;</span><span class="n">buffer</span><span class="p">,</span> <span class="k">static_cast</span><span class="o">&lt;</span><span class="n">std</span><span class="o">::</span><span class="kt">size_t</span><span class="o">&gt;</span><span class="p">(</span><span class="n">m_l2_size</span><span class="p">)));</span>
    <span class="kt">int</span><span class="o">*</span> <span class="n">m_l2_buffer</span> <span class="o">=</span> <span class="k">reinterpret_cast</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">*&gt;</span><span class="p">(</span><span class="n">buffer</span><span class="p">);</span>
    <span class="n">checkCudaErrors</span><span class="p">(</span><span class="n">cudaMemsetAsync</span><span class="p">(</span><span class="n">m_l2_buffer</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="k">static_cast</span><span class="o">&lt;</span><span class="n">std</span><span class="o">::</span><span class="kt">size_t</span><span class="o">&gt;</span><span class="p">(</span><span class="n">m_l2_size</span><span class="p">)));</span>
    <span class="n">checkCudaErrors</span><span class="p">(</span><span class="n">cudaFree</span><span class="p">(</span><span class="n">m_l2_buffer</span><span class="p">));</span>
<span class="p">}</span>
</code></pre></div></div>

<p>Locking the clock frequencies to their base values is a reliable way to measure the speed of your kernel. However, in real-world scenarios, algorithms don’t typically run with locked clocks. To achieve optimal performance, your algorithm needs to be both fast and power-efficient. The less power your algorithm consumes, the higher the clock speeds your hardware can maintain. NVIDIA GPUs often reduce clock frequencies aggressively, well before hitting their power limits, which can significantly degrade application performance. To account for this, we benchmark our implementation under both locked and unlocked clock conditions, testing for both speed and power efficiency. In our benchmarks, we evaluate matrix sizes ranging from 1024 to 12,800 with a step size of 128. For each matrix size, we launch <code class="language-plaintext highlighter-rouge">1000*exp((-matsize+1024)/3100.0)</code> kernel replays and calculate the execution time as the average of the second half of the replays. For example, for the problem size <code class="language-plaintext highlighter-rouge">m=n=k=4096</code>, we run the SGEMM <code class="language-plaintext highlighter-rouge">1000*exp((-4096+1024)/3100.0)=371</code> times and measure the average duration of the last 185 runs, by which point the clocks have stabilized. This profiling strategy leads to consistent and reproducible results, even when GPU clocks are unlocked.</p>
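
<p>The replay schedule described above can be sketched on the host as follows. This is a minimal illustration, not the benchmark code from the repository: <code class="language-plaintext highlighter-rouge">num_replays</code> and <code class="language-plaintext highlighter-rouge">mean_second_half</code> are hypothetical names, and truncating the replay count to an integer is an assumption that happens to reproduce the value 371 quoted for <code class="language-plaintext highlighter-rouge">m=n=k=4096</code>.</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code>#include &lt;cassert&gt;
#include &lt;cmath&gt;
#include &lt;cstddef&gt;
#include &lt;vector&gt;

// Number of kernel replays for a square problem with m = n = k = matsize.
// Assumption of this sketch: the fractional part is truncated.
inline int num_replays(int matsize) {
    return static_cast&lt;int&gt;(1000.0 * std::exp((1024 - matsize) / 3100.0));
}

// Mean over the second half of the per-replay timings.
// The first half acts as warm-up while the clocks settle.
inline double mean_second_half(const std::vector&lt;float&gt;&amp; times_ms) {
    assert(times_ms.size() &gt;= 2);
    const std::size_t count = times_ms.size() / 2;  // 371 replays: 185 samples
    const std::size_t first = times_ms.size() - count;
    double sum = 0.0;
    for (std::size_t i = first; i &lt; times_ms.size(); ++i) sum += times_ms[i];
    return sum / static_cast&lt;double&gt;(count);
}
</code></pre></div></div>

<p>The exponential decay keeps the total benchmark time manageable: small problems, whose individual runs are short and noisy, get many replays, while large problems get only a few.</p>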

<blockquote>
  <p>Avoid using WSL for performance measurements. To ensure accurate and reliable results, please use a native Linux environment.</p>
</blockquote>

<h2 id="3-memory-layout">3. Memory Layout</h2>

<p>Without loss of generality, this implementation assumes that matrices are stored in row-major order. A matrix <code class="language-plaintext highlighter-rouge">A</code> with dimensions <code class="language-plaintext highlighter-rouge">M x N</code> is stored as a contiguous array of length <code class="language-plaintext highlighter-rouge">M*N</code>. An element <code class="language-plaintext highlighter-rouge">A[row][col]</code> is accessed through a raw 1D C pointer as <code class="language-plaintext highlighter-rouge">ptr[row*N + col]</code> with <code class="language-plaintext highlighter-rouge">0&lt;=col&lt;=N-1</code> and <code class="language-plaintext highlighter-rouge">0&lt;=row&lt;=M-1</code>. Matrix multiplication is denoted as $C=AB$, where the shapes of matrices $A, B, C$ are $M \times K, K \times N$ and $M \times N$, respectively.
<img src="/assets/matmul_gpu/row_major.png" alt="mem_layout" width="80%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>To adapt this implementation for matrices stored in column-major order, simply swap the operands $A$ and $B$, because:</p>

\[C^\text{T} = (A B)^\text{T} = B^\text{T} A^\text{T}.\]

<p>Here, $A, B, C$ are matrices stored in row-major order, while $A^\text{T}, B^\text{T}, C^\text{T}$ are the corresponding transposed matrices (i.e., stored in column-major order).</p>
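
<p>The operand swap can be verified with a naive host-side reference. The sketch below is for illustration only (the function names are hypothetical and the loops are unoptimized); it assumes densely packed matrices, i.e., leading dimensions equal to the row or column length.</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code>#include &lt;cassert&gt;

// Naive reference: C (M x N) = A (M x K) * B (K x N), all row-major, densely packed.
inline void gemm_row_major(int M, int N, int K,
                           const float* A, const float* B, float* C) {
    assert(M &gt; 0 &amp;&amp; N &gt; 0 &amp;&amp; K &gt; 0);
    for (int i = 0; i &lt; M; ++i) {
        for (int j = 0; j &lt; N; ++j) {
            float acc = 0.0f;
            for (int k = 0; k &lt; K; ++k) acc += A[i * K + k] * B[k * N + j];
            C[i * N + j] = acc;
        }
    }
}

// Same product for column-major buffers. A column-major M x K buffer has the
// same memory layout as the row-major K x M matrix A^T, so we compute
// C^T = B^T * A^T by swapping the operands and the M/N extents.
inline void gemm_col_major(int M, int N, int K,
                           const float* A, const float* B, float* C) {
    gemm_row_major(N, M, K, B, A, C);
}
</code></pre></div></div>

<p>No data is transposed or copied; only the argument order changes, which is why the same kernel serves both layouts.</p>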

<p>cuBLAS provides an API to calculate SGEMM:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">cublasSgemm</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">A</span><span class="p">,</span> <span class="n">lda</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">ldb</span><span class="p">,</span> <span class="n">C</span><span class="p">,</span> <span class="n">ldc</span><span class="p">);</span> <span class="c1">// simplified form</span>
</code></pre></div></div>

<p>where <code class="language-plaintext highlighter-rouge">m, n, k</code> denote the matrix sizes $M, N, K$. The parameters <code class="language-plaintext highlighter-rouge">lda, ldb, ldc</code> are the <em>leading dimensions</em> of matrices $A, B, C$, respectively. The leading dimension is the length of the fastest-varying dimension as the matrix is laid out in memory, that is, for row-major storage, the stride in elements between the starts of two consecutive rows. For matrices stored in row-major order, the leading dimension is usually the number of columns, so typically <code class="language-plaintext highlighter-rouge">lda=k, ldb=n, ldc=n</code>. However, this isn’t always the case. When you operate on a submatrix of a larger matrix, the leading dimension is that of the parent matrix and is therefore larger than the number of columns of the submatrix.</p>
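
<p>A short host-side sketch makes the submatrix case concrete. The helper names <code class="language-plaintext highlighter-rouge">at</code> and <code class="language-plaintext highlighter-rouge">submatrix</code> are hypothetical and row-major storage is assumed, as in the rest of this article.</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code>#include &lt;cassert&gt;

// Element (row, col) of a row-major matrix whose consecutive rows are
// ld elements apart in memory.
inline float&amp; at(float* base, int ld, int row, int col) {
    assert(col &lt; ld);
    return base[row * ld + col];
}

// Pointer to the submatrix that starts at (row0, col0) of a parent matrix with
// leading dimension ld. The submatrix keeps the parent's leading dimension.
inline float* submatrix(float* parent, int ld, int row0, int col0) {
    return parent + row0 * ld + col0;
}
</code></pre></div></div>

<p>For a 4 x 4 parent matrix, the 2 x 2 block starting at row 1, column 1 has two columns but is still addressed with a leading dimension of 4, because consecutive rows of the block remain four elements apart in memory.</p>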

<p>Matrices may also be padded with zeros to support vectorized memory loads or tensor cores. Vectorized load instructions load multiple elements with a single instruction. Although vectorized loads reduce the total number of instructions and improve bandwidth utilization, they also impose an <a href="https://developer.nvidia.com/blog/cuda-pro-tip-increase-performance-with-vectorized-memory-access/">alignment constraint</a> on the input data: the leading dimension must be divisible by 2 (for 64-bit loads) or 4 (for 128-bit loads). The figure below illustrates the case for 128-bit (=4 floats) loads.</p>

<p><img src="/assets/matmul_gpu/mem_align.png" alt="mem_align" width="80%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>Note how it’s impossible to load the elements of the first row without touching the elements of the next row if the leading dimension is not divisible by 4. Padding with zeros helps, but requires additional memory. Another solution is to check at runtime whether the leading dimension is divisible by 4 and to use vectorized loads if it is and scalar loads otherwise. Additionally, zero padding was commonly used in the past to enable tensor core computations. For instance, in cuBLAS versions &lt; 11, Tensor Core FP16 operations required <code class="language-plaintext highlighter-rouge">m, n, k</code> to be multiples of 8.</p>
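
<p>The runtime check mentioned above could look roughly like this on the host side. It is a sketch with hypothetical helper names; besides the divisibility of the leading dimension, it also tests that the base pointer is 16-byte aligned, since a 128-bit load requires a 16-byte-aligned address.</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code>#include &lt;cassert&gt;
#include &lt;cstdint&gt;

// True if every row of a row-major float matrix starts at an address that is
// suitable for 128-bit (4-float) loads: the base address must be 16-byte
// aligned and the leading dimension must be a multiple of 4.
inline bool can_use_vec4(const float* base, int ld) {
    assert(base != nullptr &amp;&amp; ld &gt; 0);
    const bool aligned = reinterpret_cast&lt;std::uintptr_t&gt;(base) % 16 == 0;
    return aligned &amp;&amp; (ld % 4 == 0);
}

// Smallest leading dimension &gt;= cols that is a multiple of 4,
// i.e., the row length after zero padding.
inline int padded_ld(int cols) {
    return (cols + 3) / 4 * 4;
}
</code></pre></div></div>

<p>A launcher can then dispatch to a vectorized kernel when <code class="language-plaintext highlighter-rouge">can_use_vec4</code> holds for all operands and fall back to a scalar kernel otherwise, which avoids the extra memory that padding would require.</p>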

<h2 id="4-parallel-thread-execution">4. Parallel Thread Execution</h2>

<p>The CUDA compilation trajectory of a <code class="language-plaintext highlighter-rouge">.cu</code> file looks as follows:</p>

<p><img src="/assets/matmul_gpu/ptxas.png" alt="ptxas" width="60%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>During Stage 1, CUDA code is compiled to PTX (parallel thread execution) <a href="https://docs.nvidia.com/cuda/parallel-thread-execution/">instructions</a>, an intermediate representation that can be considered assembly for a <a href="https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/#virtual-architectures">virtual GPU architecture</a>. Such a virtual GPU is defined entirely by the set of capabilities, or features, that it provides to the application. PTX doesn’t run directly on any real architecture. It must be optimized and translated to native target-architecture instructions (Stage 2). NVIDIA provides a mechanism to <a href="https://docs.nvidia.com/cuda/inline-ptx-assembly/index.html">insert PTX code into your CUDA program</a>, so that you can mix CUDA and PTX in the source code and still benefit from code optimizations during PTX generation. By rewriting parts of your code in PTX, you can 1) reduce the total number of generated PTX instructions, 2) specify exactly the PTX instructions you need, 3) tune the instructions through qualifiers, and 4) apply optimizations that are either lacking in the compiler or prohibited by C++ language extensions. <strong>Important! Using inline PTX assembly will not automatically make your code faster than code written in CUDA. It will only be faster if your hand-written PTX is better than the PTX generated by the compiler</strong>.</p>

<p>In this implementation, we will program some parts of the algorithm directly in PTX, so I highly recommend reading this <a href="https://docs.nvidia.com/cuda/inline-ptx-assembly/index.html">short overview of inline PTX assembly</a> if you have never used it before. The PTX instructions are well documented and can be found at <a href="https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#instruction-set">PTX Instruction Set</a>. We will now briefly review the PTX instructions used in this implementation.</p>

<h3 id="41-global-memory-loads">4.1. Global Memory Loads</h3>

<p>For global memory loads we will use the <a href="https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-ld">ld.global.f32</a> instruction. Here, “ld” stands for “load” and “f32” for “32-bit float”. The following CUDA code</p>

<div class="language-c++ highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">float</span> <span class="n">reg</span><span class="p">;</span> <span class="c1">// single float register</span>
<span class="kt">float</span><span class="o">*</span> <span class="n">gmem_ptr</span> <span class="o">=</span> <span class="n">data_in_global_memory</span><span class="p">;</span> <span class="c1">// pointer to global memory</span>
<span class="n">reg</span> <span class="o">=</span> <span class="o">*</span><span class="n">gmem_ptr</span><span class="p">;</span> <span class="c1">// global memory -&gt; register transfer</span>
</code></pre></div></div>

<p>can be implemented in PTX as:</p>

<div class="language-c++ highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">float</span> <span class="n">reg</span><span class="p">;</span> <span class="c1">// single float register</span>
<span class="kt">float</span><span class="o">*</span> <span class="n">gmem_ptr</span> <span class="o">=</span> <span class="n">data_in_global_memory</span><span class="p">;</span> <span class="c1">// pointer to global memory</span>
<span class="k">asm</span> <span class="k">volatile</span><span class="p">(</span><span class="s">"ld.global.f32 %0, [%1];"</span> <span class="o">:</span> <span class="s">"=f"</span><span class="p">(</span><span class="n">reg</span><span class="p">)</span> <span class="o">:</span> <span class="s">"l"</span><span class="p">(</span><span class="n">gmem_ptr</span><span class="p">));</span>
</code></pre></div></div>

<p>The <code class="language-plaintext highlighter-rouge">f</code> in <code class="language-plaintext highlighter-rouge">"=f"</code> denotes the <code class="language-plaintext highlighter-rouge">float</code> data type, and the <code class="language-plaintext highlighter-rouge">=</code> modifier specifies that the register is written to. The <code class="language-plaintext highlighter-rouge">l</code> denotes an unsigned 64-bit integer, which here holds the global memory address. We also use the <code class="language-plaintext highlighter-rouge">volatile</code> keyword to ensure that the instruction is not deleted or moved during PTX generation.</p>

<h3 id="42-global-memory-stores">4.2. Global Memory Stores</h3>

<p>For global memory stores there is the <a href="https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st">st.global.f32</a> instruction:</p>

<div class="language-c++ highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">float</span> <span class="n">reg</span><span class="p">;</span> <span class="c1">// single float register</span>
<span class="kt">float</span><span class="o">*</span> <span class="n">gmem_ptr</span> <span class="o">=</span> <span class="n">data_in_global_memory</span><span class="p">;</span> <span class="c1">// pointer to global memory</span>
<span class="c1">// *gmem_ptr = reg; can be implemented in PTX as:</span>
<span class="k">asm</span> <span class="k">volatile</span><span class="p">(</span><span class="s">"st.global.f32 [%0], %1;"</span> <span class="o">:</span> <span class="o">:</span> <span class="s">"l"</span><span class="p">(</span><span class="n">gmem_ptr</span><span class="p">),</span> <span class="s">"f"</span><span class="p">(</span><span class="n">reg</span><span class="p">));</span>
</code></pre></div></div>

<h3 id="43-global-to-shared-memory-transfers">4.3. Global to Shared Memory Transfers</h3>

<p>When you write something like this in CUDA:</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">__shared__</span> <span class="kt">float</span> <span class="n">smem_ptr</span><span class="p">[</span><span class="n">n</span><span class="p">];</span> <span class="c1">// pointer to shared memory</span>
<span class="kt">float</span><span class="o">*</span> <span class="n">gmem_ptr</span> <span class="o">=</span> <span class="n">data_in_global_memory</span><span class="p">;</span> <span class="c1">// pointer to global memory</span>
<span class="o">*</span><span class="n">smem_ptr</span> <span class="o">=</span> <span class="o">*</span><span class="n">gmem_ptr</span><span class="p">;</span> <span class="c1">// global to shared memory transfer</span>
</code></pre></div></div>

<p>a two-step process occurs. First, the data is fetched from global memory into registers and then that data is copied from registers into shared memory. Additionally the data is cached in all cache levels during the transfer.</p>

<p><img src="/assets/matmul_gpu/standard_ld.png" alt="standard_ld" width="30%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p><em style="display:block; margin-left:auto; margin-right:auto; text-align: center">Global to shared memory transfers</em></p>

<p>For this reason, a global to shared memory transfer in PTX consists of two data movement instructions <code class="language-plaintext highlighter-rouge">ld.global</code> and <code class="language-plaintext highlighter-rouge">st.shared</code>:</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">__shared__</span> <span class="kt">float</span> <span class="n">smem_ptr</span><span class="p">[</span><span class="n">n</span><span class="p">];</span> <span class="c1">// pointer to shared memory</span>
<span class="kt">uint64_t</span> <span class="n">smem_addr</span><span class="p">;</span>
<span class="c1">// convert generic address to shared address (store location for st.shared instruction)</span>
<span class="k">asm</span> <span class="k">volatile</span><span class="p">(</span><span class="s">"cvta.to.shared.u64 %0, %1;"</span> <span class="o">:</span> <span class="s">"=l"</span><span class="p">(</span><span class="n">smem_addr</span><span class="p">)</span> <span class="o">:</span> <span class="s">"l"</span><span class="p">(</span><span class="n">smem_ptr</span><span class="p">));</span>

<span class="kt">float</span><span class="o">*</span> <span class="n">gmem_ptr</span> <span class="o">=</span> <span class="n">data_in_global_memory</span><span class="p">;</span> <span class="c1">// pointer to global memory</span>
<span class="kt">float</span> <span class="n">buffer</span><span class="p">;</span>
<span class="c1">// global memory -&gt; register</span>
<span class="k">asm</span> <span class="k">volatile</span><span class="p">(</span><span class="s">"ld.global.f32 %0, [%1];"</span> <span class="o">:</span> <span class="s">"=f"</span><span class="p">(</span><span class="n">buffer</span><span class="p">)</span> <span class="o">:</span> <span class="s">"l"</span><span class="p">(</span><span class="n">gmem_ptr</span><span class="p">));</span>
<span class="c1">// register -&gt; shared memory</span>
<span class="k">asm</span> <span class="k">volatile</span><span class="p">(</span><span class="s">"st.shared.f32 [%0], %1;"</span> <span class="o">:</span> <span class="o">:</span> <span class="s">"l"</span><span class="p">(</span><span class="n">smem_addr</span><span class="p">),</span> <span class="s">"f"</span><span class="p">(</span><span class="n">buffer</span><span class="p">));</span>
</code></pre></div></div>

<blockquote>
  <p>Prior to the Ampere architecture it was not possible to transfer data from global memory directly to shared memory without staging it in registers. Starting with the Ampere architecture, there are <a href="https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-asynchronous-copy">asynchronous copy instructions</a> that allow this. The usage of these instructions will be demonstrated later.</p>
</blockquote>

<h3 id="44-vectorized-shared-memory-loads-and-stores">4.4. Vectorized Shared Memory Loads and Stores</h3>

<p>In PTX you can also implement vectorized memory operations (loading/storing multiple elements with one instruction). Here, <code class="language-plaintext highlighter-rouge">v4</code> denotes a vector with four elements:</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">float</span> <span class="n">reg0</span><span class="p">,</span> <span class="n">reg1</span><span class="p">,</span> <span class="n">reg2</span><span class="p">,</span> <span class="n">reg3</span><span class="p">;</span>
<span class="kt">uint64_t</span> <span class="n">addr</span><span class="p">;</span>
<span class="p">...</span>
<span class="c1">// Shared memory 128-bit loads</span>
<span class="k">asm</span> <span class="k">volatile</span><span class="p">(</span><span class="s">"ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];"</span>
             <span class="o">:</span> <span class="s">"=f"</span><span class="p">(</span><span class="n">reg0</span><span class="p">),</span> <span class="s">"=f"</span><span class="p">(</span><span class="n">reg1</span><span class="p">),</span> <span class="s">"=f"</span><span class="p">(</span><span class="n">reg2</span><span class="p">),</span> <span class="s">"=f"</span><span class="p">(</span><span class="n">reg3</span><span class="p">)</span>
             <span class="o">:</span> <span class="s">"l"</span><span class="p">(</span><span class="n">addr</span><span class="p">));</span>
<span class="c1">// Shared memory 128-bit stores</span>
<span class="k">asm</span> <span class="k">volatile</span><span class="p">(</span><span class="s">"st.shared.v4.f32 [%0], {%1, %2, %3, %4};"</span>
             <span class="o">:</span>
             <span class="o">:</span> <span class="s">"l"</span><span class="p">(</span><span class="n">addr</span><span class="p">),</span> <span class="s">"f"</span><span class="p">(</span><span class="n">reg0</span><span class="p">),</span> <span class="s">"f"</span><span class="p">(</span><span class="n">reg1</span><span class="p">),</span> <span class="s">"f"</span><span class="p">(</span><span class="n">reg2</span><span class="p">),</span> <span class="s">"f"</span><span class="p">(</span><span class="n">reg3</span><span class="p">));</span>
</code></pre></div></div>

<h3 id="45-predicated-execution">4.5. Predicated Execution</h3>

<p>In PTX, conditional execution is implemented using optional guard predicates. The following CUDA code:</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">float</span> <span class="n">reg</span><span class="p">;</span>
<span class="kt">float</span><span class="o">*</span> <span class="n">ptr</span><span class="p">;</span> <span class="c1">//pointer to global memory</span>
<span class="kt">unsigned</span> <span class="n">guard</span><span class="p">;</span>
<span class="p">...</span>
<span class="k">if</span> <span class="p">(</span><span class="n">guard</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">)</span> <span class="p">{</span>
    <span class="n">reg</span> <span class="o">=</span> <span class="o">*</span><span class="n">ptr</span><span class="p">;</span>
<span class="p">}</span>
</code></pre></div></div>

<p>can be converted to PTX as:</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">float</span> <span class="n">reg</span><span class="p">;</span>
<span class="kt">float</span><span class="o">*</span> <span class="n">ptr</span><span class="p">;</span>
<span class="kt">unsigned</span> <span class="n">guard</span><span class="p">;</span>
<span class="p">...</span>
<span class="k">asm</span> <span class="k">volatile</span><span class="p">(</span><span class="s">".reg .pred p;</span><span class="se">\n\t</span><span class="s">"</span> <span class="c1">// declare predicate 'p'</span>
             <span class="s">"setp.ne.u32 p, %2, 0;</span><span class="se">\n\t</span><span class="s">"</span> <span class="c1">// set 'p' to true if (guard != 0); ne="not equal"</span>
             <span class="s">"@p ld.global.f32 %0, [%1];</span><span class="se">\n\t</span><span class="s">"</span> <span class="c1">// execute instruction if 'p' is true</span>
             <span class="o">:</span> <span class="s">"=f"</span><span class="p">(</span><span class="n">reg</span><span class="p">)</span>
             <span class="o">:</span> <span class="s">"l"</span><span class="p">(</span><span class="n">ptr</span><span class="p">),</span> <span class="s">"r"</span><span class="p">(</span><span class="n">guard</span><span class="p">));</span>
</code></pre></div></div>

<p>We use guard predicates in combination with global load/store instructions to perform global memory access only if it is not out of bounds.</p>

<h2 id="5-sgemm-design">5. SGEMM Design</h2>

<p>Let’s now break down the high-level design of the algorithm. The paper <a href="https://www.semanticscholar.org/paper/Strassen%E2%80%99s-Algorithm-Reloaded-on-GPUs-Huang-Yu/d3214da488806bf4c870080fca18a7f3ecba1e99">Strassen’s Algorithm Reloaded on GPUs</a> contains, in my opinion, one of the best visualizations of the SGEMM design from the CUTLASS library. The SGEMM algorithm can be roughly divided into three main parts:</p>

<ol>
  <li>Transferring data from global to shared memory</li>
  <li>Loading data from shared memory and performing arithmetic operations</li>
  <li>Writing results back to global memory.</li>
</ol>

<p>Each of these steps must be carefully optimized to achieve high overall performance. In the following sections, we’ll explore each step in detail and discuss efficient implementation strategies. It’s worth mentioning that the first step - “transferring data from global memory to shared memory” is the most challenging to grasp. However, once you understand this part, the remaining steps become much easier to follow.</p>

<h3 id="51-transferring-data-from-global-to-shared-memory">5.1. Transferring data from global to shared memory</h3>

<p><img src="/assets/matmul_gpu/matmul_design_tb.png" alt="matmul_gmem_loads" width="80%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p><em style="display:block; margin-left:auto; margin-right:auto; text-align: center">Source: <a href="https://www.semanticscholar.org/paper/Strassen%E2%80%99s-Algorithm-Reloaded-on-GPUs-Huang-Yu/d3214da488806bf4c870080fca18a7f3ecba1e99">Strassen’s Algorithm Reloaded on GPUs</a></em></p>

<p>To parallelize $C=AB$ on GPU, the matrix $C$ is partitioned into sub-matrices $\tilde{C}$ of size $m_s \times n_s$, and the sub-matrices are processed in parallel, with one thread block computing one sub-matrix $\tilde{C}$ independently of the other thread blocks. To compute $\tilde{C}$, we iterate over the dimension $K$. In each iteration, a sub-matrix $\tilde{A}$ of size $m_s \times k_s$ and a sub-matrix $\tilde{B}$ of size $k_s \times n_s$ are loaded from <strong>global</strong> into <strong>shared</strong> memory (see the figure above). These sub-matrices are then multiplied, and the result is used to update $\tilde{C}$ as $\tilde{C} += \tilde{A} \tilde{B}$. The sub-matrices $\tilde{A}, \tilde{B}, \tilde{C}$ are often called <em>blocks</em> or <em>tiles</em>. In total there are $K / k_s$ iterations (assuming the simplest case, where $K$ is divisible by $k_s$). The limited shared memory capacity is the reason why the dimension $K$ is divided into smaller blocks of size $k_s$: full $m_s \times K$ and $K \times n_s$ blocks simply wouldn’t fit into the available shared memory. For now, don’t worry about why the matrices are loaded into shared memory or how exactly $\tilde{A}$ and $\tilde{B}$ are multiplied; we will discuss this in the next chapter. Let’s focus on efficient data movement from global to shared memory as our first step towards fast SGEMM.</p>

<p>The pseudo code of the algorithm, from the perspective of a thread block, is as follows:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1">// The shapes of block_a, block_b, block_c are (ms x ks), (ks x ns), (ms x ns)</span>
<span class="c1">// Each thread block computes one block of C:</span>
<span class="n">block_c</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">__shared__</span> <span class="kt">float</span> <span class="n">block_a</span><span class="p">[</span><span class="n">block_a_size</span><span class="p">]</span>
<span class="n">__shared__</span> <span class="kt">float</span> <span class="n">block_b</span><span class="p">[</span><span class="n">block_b_size</span><span class="p">]</span>
<span class="k">for</span> <span class="p">(</span><span class="n">i</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">i</span><span class="o">&lt;</span><span class="n">K</span><span class="o">/</span><span class="n">ks</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
    <span class="n">block_a</span> <span class="o">=</span> <span class="n">load</span> <span class="n">ith</span> <span class="n">block</span> <span class="n">of</span> <span class="n">matrix</span> <span class="n">A</span> <span class="c1">// from global into shared memory</span>
    <span class="n">block_b</span> <span class="o">=</span> <span class="n">load</span> <span class="n">ith</span> <span class="n">block</span> <span class="n">of</span> <span class="n">matrix</span> <span class="n">B</span> <span class="c1">// from global into shared memory</span>
    <span class="n">block_c</span> <span class="o">+=</span> <span class="n">block_a</span> <span class="o">*</span> <span class="n">block_b</span> <span class="c1">// compute matrix product and update block_c</span>
<span class="p">}</span>
<span class="n">store</span><span class="p">(</span><span class="n">block_c</span><span class="p">)</span> <span class="c1">// store to global memory</span>
</code></pre></div></div>
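<p>To make the loop structure concrete, here is a minimal CPU-side sketch of what one thread block computes. It is not the kernel from this article: the function name <code class="language-plaintext highlighter-rouge">compute_block_c</code>, the row-major layout and the use of <code class="language-plaintext highlighter-rouge">std::vector</code> as a stand-in for shared memory are assumptions made purely for illustration.</p>

```cpp
#include <vector>

// Hypothetical CPU reference for one "thread block": computes the (ms x ns)
// tile of C whose top-left corner is (row0, col0), walking K in steps of ks.
// Assumptions: row-major storage, A is (M x K), B is (K x N), C is (M x N),
// and K is divisible by ks (the simplest case described in the text).
void compute_block_c(const std::vector<float>& A, const std::vector<float>& B,
                     std::vector<float>& C, int N, int K,
                     int row0, int col0, int ms, int ns, int ks) {
    std::vector<float> block_c(ms * ns, 0.0f);              // accumulator tile
    std::vector<float> block_a(ms * ks), block_b(ks * ns);  // stand-ins for shared memory
    for (int i = 0; i < K / ks; ++i) {
        // "load" the i-th tile of A (ms x ks) and of B (ks x ns)
        for (int r = 0; r < ms; ++r)
            for (int k = 0; k < ks; ++k)
                block_a[r * ks + k] = A[(row0 + r) * K + i * ks + k];
        for (int k = 0; k < ks; ++k)
            for (int c = 0; c < ns; ++c)
                block_b[k * ns + c] = B[(i * ks + k) * N + col0 + c];
        // block_c += block_a * block_b
        for (int r = 0; r < ms; ++r)
            for (int c = 0; c < ns; ++c)
                for (int k = 0; k < ks; ++k)
                    block_c[r * ns + c] += block_a[r * ks + k] * block_b[k * ns + c];
    }
    // store the finished tile back to C
    for (int r = 0; r < ms; ++r)
        for (int c = 0; c < ns; ++c)
            C[(row0 + r) * N + col0 + c] = block_c[r * ns + c];
}
```

<p>Calling this function once per tile of $C$ reproduces the full product; on the GPU each call corresponds to one thread block, and all of them run in parallel.</p>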

<p>Data transfers from global memory to shared memory have significantly higher latency compared to arithmetic operations. During this time, threads are forced to stall, idly waiting for the data needed to compute <code class="language-plaintext highlighter-rouge">block_a * block_b</code>. One way to mitigate this latency is by overlapping data transfers with computations, leveraging instruction-level parallelism (ILP). In GEMM implementations, a technique known as <em>double buffering</em> is commonly used to achieve this overlap:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">block_c</span> <span class="o">=</span> <span class="mi">0</span>
<span class="c1">// Shared Memory Double buffering</span>
<span class="n">__shared__</span> <span class="kt">float</span> <span class="n">block_a</span><span class="p">[</span><span class="mi">2</span><span class="p">][</span><span class="n">block_a_size</span><span class="p">]</span> <span class="c1">// 2x shared memory usage</span>
<span class="n">__shared__</span> <span class="kt">float</span> <span class="n">block_b</span><span class="p">[</span><span class="mi">2</span><span class="p">][</span><span class="n">block_b_size</span><span class="p">]</span> <span class="c1">// 2x shared memory usage</span>
<span class="n">block_a</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">load</span> <span class="n">first</span> <span class="n">block</span> <span class="n">of</span> <span class="n">matrix</span> <span class="n">A</span>
<span class="n">block_b</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">load</span> <span class="n">first</span> <span class="n">block</span> <span class="n">of</span> <span class="n">matrix</span> <span class="n">B</span>

<span class="k">for</span> <span class="p">(</span><span class="n">i</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">i</span><span class="o">&lt;</span><span class="p">(</span><span class="n">K</span><span class="o">/</span><span class="n">ks</span><span class="o">-</span><span class="mi">1</span><span class="p">);</span> <span class="n">i</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
    <span class="n">idx</span> <span class="o">=</span> <span class="n">i</span><span class="o">%</span><span class="mi">2</span>
    <span class="n">prefetch_idx</span> <span class="o">=</span> <span class="p">(</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="p">)</span><span class="o">%</span><span class="mi">2</span>
    <span class="c1">// prefetch next blocks</span>
    <span class="n">block_a</span><span class="p">[</span><span class="n">prefetch_idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">load</span> <span class="n">next</span> <span class="n">block</span> <span class="n">of</span> <span class="n">matrix</span> <span class="n">A</span>
    <span class="n">block_b</span><span class="p">[</span><span class="n">prefetch_idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">load</span> <span class="n">next</span> <span class="n">block</span> <span class="n">of</span> <span class="n">matrix</span> <span class="n">B</span>
    <span class="c1">// use blocks loaded in previous iteration to calculate matrix product</span>
    <span class="n">block_c</span> <span class="o">+=</span> <span class="n">block_a</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="o">*</span> <span class="n">block_b</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span>
<span class="p">}</span>
<span class="c1">// final update of the accumulator using last blocks</span>
<span class="n">block_c</span> <span class="o">+=</span> <span class="n">block_a</span><span class="p">[</span><span class="n">prefetch_idx</span><span class="p">]</span> <span class="o">*</span> <span class="n">block_b</span><span class="p">[</span><span class="n">prefetch_idx</span><span class="p">]</span>

<span class="n">store_to_global_memory</span><span class="p">(</span><span class="n">block_c</span><span class="p">)</span>
</code></pre></div></div>

<p>Note that <code class="language-plaintext highlighter-rouge">block_c += block_a[idx] * block_b[idx]</code> doesn’t depend on <code class="language-plaintext highlighter-rouge">block_a[prefetch_idx]</code> or <code class="language-plaintext highlighter-rouge">block_b[prefetch_idx]</code>, which allows the arithmetic instructions to be issued in parallel with the data movement instructions. However, this comes at the cost of doubled shared memory usage, as we need to store two blocks instead of one. The good news is that modern GPUs have <a href="https://docs.nvidia.com/cuda/cuda-c-programming-guide/#features-and-technical-specifications-technical-specifications-per-compute-capability">sufficient shared memory</a> to support double-buffering.</p>
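<p>The buffer indexing can be checked with a small sequential model. The sketch below is an illustration only: tiles of a 1-D array stand in for the blocks of $A$ and $B$, a plain sum stands in for the matrix product, and the name <code class="language-plaintext highlighter-rouge">double_buffered_sum</code> is made up. A CPU loop has no real overlap, so it only shows that every tile is consumed exactly once and that the compute step never reads the buffer being written in the same iteration.</p>

```cpp
#include <vector>

// Hypothetical sequential model of the double-buffered loop.
// Assumes data.size() is divisible by tile and contains at least one tile.
float double_buffered_sum(const std::vector<float>& data, int tile) {
    const int num_tiles = static_cast<int>(data.size()) / tile;
    std::vector<float> buf[2] = {std::vector<float>(tile), std::vector<float>(tile)};

    // copy tile t of the input into buffer dst ("global -> shared")
    auto load = [&](int dst, int t) {
        for (int j = 0; j < tile; ++j) buf[dst][j] = data[t * tile + j];
    };
    // reduce buffer src ("compute on shared memory")
    auto consume = [&](int src) {
        float s = 0.0f;
        for (float v : buf[src]) s += v;
        return s;
    };

    float acc = 0.0f;
    load(0, 0);                     // prologue: first tile goes into buffer 0
    int prefetch_idx = 0;           // buffer holding the most recently loaded tile
    for (int i = 0; i < num_tiles - 1; ++i) {
        const int idx = i % 2;
        prefetch_idx = (i + 1) % 2;
        load(prefetch_idx, i + 1);  // prefetch the next tile into the other buffer
        acc += consume(idx);        // compute on the tile loaded earlier
    }
    acc += consume(prefetch_idx);   // epilogue: the last tile
    return acc;
}
```

<p>One deliberate difference from the pseudocode: <code class="language-plaintext highlighter-rouge">prefetch_idx</code> is initialized to 0 so that the epilogue is still well defined when there is only a single tile and the loop body never runs.</p>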

<p>We’ve already introduced several parameters such as the block sizes $m_s, k_s, n_s$ and the number of threads per thread block. The choice of these parameters depends heavily on the shapes of the operands $A, B, C$, as well as on the underlying GPU architecture. For example, cuBLAS implements multiple SGEMM kernels optimized for various matrix shapes and GPU architectures. At runtime, it selects the most appropriate kernel using a heuristic approach. The block sizes $m_s, k_s, n_s$ affect not only how the data will be fetched from global memory, but also how the work in all subsequent steps (shared memory loads, arithmetic operations, global memory stores) is organized among the threads to achieve the best possible performance. The choice of the block sizes and the number of threads per thread block also affects shared memory / register usage, which can result in decreased performance if not taken into account. As you might expect, identifying optimal parameter values requires an excellent understanding of the hardware and extensive experimentation. Fortunately, SGEMM is a well-studied problem and we can use the results from previous studies of cuBLAS and CUTLASS. For large square matrices (<code class="language-plaintext highlighter-rouge">M=N=K &gt; 1024</code>), combinations of $m_s \times n_s$ such as $128 \times 256$, $128 \times 128$ and $256 \times 128$ typically give the best performance. In my tests, the configuration $m_s \times n_s \times k_s = 128 \times 128 \times 8$ with <strong>256 threads per thread block</strong> achieved the highest TFLOP/s on my local RTX 3090 for problem sizes <code class="language-plaintext highlighter-rouge">1024 &lt;= M=N=K &lt;= 2500</code>. Therefore, we will start with the implementation of a <code class="language-plaintext highlighter-rouge">128x128x8</code> SGEMM kernel.
Now that we know the block dimensions and the number of threads per thread block, let’s discuss how to efficiently organize data loading from global memory and storage into shared memory.</p>

<p>First, we need to load the <code class="language-plaintext highlighter-rouge">128x8</code> submatrix $\tilde{A}$ using 256 threads, so each thread loads <code class="language-plaintext highlighter-rouge">128*8/256 = 4</code> float elements from global memory. There are several ways to organize the loading of the block. For global memory loads/stores you always want your accesses to be contiguous, or <em>coalesced</em>, so that the 32 threads in a warp access 32 consecutive floats in memory. A coalesced access uses the minimum number of memory transactions. However, this is not possible for the $\tilde{A}$ block: each row of the block contains only 8 consecutive elements. Nevertheless, even in such cases it is preferable for consecutive threads in a warp to access consecutive elements in memory, which usually results in better performance. The figure below shows how the loading of the block $\tilde{A}$ is implemented. Different colors represent different threads, and only the first 16 threads are shown for simplicity. Four consecutive rows are loaded by 8 consecutive threads: rows 1-4 are loaded by threads 0-7, rows 5-8 by threads 8-15, rows 9-12 by threads 16-23 and so on, with the last rows 125-128 loaded by threads 248-255. We also transpose the block $\tilde{A}$ while storing it in shared memory to get a better memory access pattern during the next computation step. Note how each thread stores 4 consecutive elements in shared memory. This allows us to use the PTX vectorized store <code class="language-plaintext highlighter-rouge">st.shared.v4.f32</code>.</p>

<p><img src="/assets/matmul_gpu/a_gmem_loads.png" alt="a_gmem_loads" width="100%" style="display:block; margin-left:auto; margin-right:auto" /></p>
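<p>The mapping described above can be written down as plain index arithmetic. The following host-side sketch is an assumption-laden illustration rather than the kernel’s actual code: the struct and function names are hypothetical, indices are 0-based, and the column assignment <code class="language-plaintext highlighter-rouge">tid % 8</code> is inferred from the description (8 consecutive threads cover the 8 consecutive elements of a row).</p>

```cpp
// Hypothetical description of what one thread does with the 128x8 tile of A.
struct ATileAccess {
    int gmem_row;     // first of the 4 rows this thread reads (0-based, within the tile)
    int gmem_col;     // column read in each of those 4 rows
    int smem_offset;  // float offset of the first of 4 consecutive stores (transposed tile)
};

// smem_ld is the leading dimension (in floats) of the transposed tile in shared memory.
ATileAccess a_tile_access(int tid, int smem_ld) {
    ATileAccess acc;
    acc.gmem_row = (tid / 8) * 4;  // threads 0-7 -> rows 0-3, threads 8-15 -> rows 4-7, ...
    acc.gmem_col = tid % 8;        // consecutive threads read consecutive elements of a row
    // transposed store: column index becomes the row, the 4 rows become 4 consecutive floats
    acc.smem_offset = acc.gmem_col * smem_ld + acc.gmem_row;
    return acc;
}
```

<p>Because the four values a thread reads come from the same column of four adjacent rows, they end up adjacent after the transpose, which is what makes a single vectorized store per thread possible.</p>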

<p>Storing to shared memory using this naive scheme would result in <strong>shared memory bank conflicts</strong>. From the CUDA programming guide:</p>
<blockquote>
  <p>To achieve high bandwidth, shared memory is divided into equally-sized memory modules, called banks, which can be accessed simultaneously. Any memory read or write request made of n addresses that fall in n distinct memory banks can therefore be serviced simultaneously, yielding an overall bandwidth that is n times as high as the bandwidth of a single module.
However, if two addresses of a memory request fall in the same memory bank, there is a bank conflict and the access has to be serialized. The hardware splits a memory request with bank conflicts into as many separate conflict-free requests as necessary, decreasing throughput by a factor equal to the number of separate memory requests. If the number of separate memory requests is n, the initial memory request is said to cause n-way bank conflicts.</p>
</blockquote>

<p>Shared memory has 32 banks that are organized such that successive 32-bit words map to successive banks. Imagine a <code class="language-plaintext highlighter-rouge">float32</code> array of size <code class="language-plaintext highlighter-rouge">8x32</code> stored in <strong>row-major</strong> order as shown below.</p>

<p><img src="/assets/matmul_gpu/bank_conflict.png" alt="bank_conflict" width="80%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>In this context, colors and their shades represent memory banks: each row corresponds to 32 distinct memory banks, while each column represents a single memory bank. Here are two important notes about shared memory bank conflicts:</p>

<ol>
  <li>The determination of bank conflicts is made <strong>per memory transaction (or using modern CUDA language - per wave)</strong>, <strong>not</strong> per request, <strong>not</strong> per warp, <strong>not</strong> per instruction.</li>
  <li>Two requests to the same bank and the same 32-bit location in that bank do <strong>not</strong> create a bank conflict (<a href="https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-5-x">illustrated</a> in the CUDA programming guide).</li>
</ol>

<p>When you store (or load) 4 bytes (= 1 float) per thread, which is 4*32=128 bytes per warp, a CUDA device issues a single warp-wide memory transaction, so the shared memory access must be conflict-free across the whole warp (= 32 threads). In our case, we store 16 bytes (= 4 floats) per thread using the vector instructions, for a total of 512 bytes per warp-wide request. The GPU splits the request into 4 memory transactions (threads 0-7 make up one transaction, threads 8-15 another, and so on), each of which is 128 bytes wide. If we stored according to our scheme, then every thread within threads 0-7 would store to the same four columns (red color shades), or in other words to the same four memory banks, causing bank conflicts. The same applies to the other memory transactions, i.e. threads 8-15, threads 16-23 and so on. One possible way to completely avoid bank conflicts is to pad the leading dimension with 16 bytes (= 4 floats) as shown below.</p>

<p><img src="/assets/matmul_gpu/bank_conflict_padding.png" alt="bank_conflict_padding" width="80%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>Now, if we store the data according to our scheme, each thread within threads 0-7 accesses distinct memory banks, so that all 32 memory banks are accessed per memory transaction. The same applies to the remaining memory transactions, i.e. t8-t15, t16-t23 and so on. This is the reason why the leading dimension is 132 and not 128 in the implementation:</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">const</span> <span class="kt">int</span> <span class="n">smem_a_ld</span> <span class="o">=</span> <span class="mi">132</span><span class="p">;</span> <span class="c1">// 128 + 4</span>
</code></pre></div></div>
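<p>The effect of the padding can be verified with a few lines of host code. This is a sketch under the assumptions stated in the text (32 banks, successive 32-bit words in successive banks, 8 threads per 128-byte transaction) and the store offset <code class="language-plaintext highlighter-rouge">(tid % 8) * ld + (tid / 8) * 4</code> inferred from the scheme above; the function name is hypothetical.</p>

```cpp
#include <algorithm>
#include <array>

// Worst-case number of distinct 32-bit words that fall into the same bank when
// threads first_tid .. first_tid+7 each store 4 consecutive floats.
// A result of 1 means the transaction is conflict-free.
int max_words_per_bank(int first_tid, int smem_ld) {
    std::array<int, 32> hits{};  // words per bank; all addresses here are distinct
    for (int tid = first_tid; tid < first_tid + 8; ++tid) {
        const int offset = (tid % 8) * smem_ld + (tid / 8) * 4;  // float index
        for (int j = 0; j < 4; ++j)
            ++hits[(offset + j) % 32];  // successive words map to successive banks
    }
    return *std::max_element(hits.begin(), hits.end());
}
```

<p>With a leading dimension of 128, <code class="language-plaintext highlighter-rouge">max_words_per_bank(0, 128)</code> evaluates to 8, because all eight threads hit banks 0-3. With 132 the row offsets become multiples of 4 modulo 32, the eight threads spread over all 32 banks, and the result drops to 1.</p>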

<p>To implement double-buffering and store two $\tilde{A}$ blocks, we would theoretically need <code class="language-plaintext highlighter-rouge">2*132*8*4</code> bytes of shared memory. However, we round the size up to the next power of two, <code class="language-plaintext highlighter-rouge">2*256*8*4</code> bytes, to enable fast switching between the blocks. Compare the following code with the pseudocode presented at the beginning of the chapter:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1">// Double-buffering (blocks_b is omitted for simplicity)</span>
<span class="n">__shared__</span> <span class="kt">float</span> <span class="n">__align__</span><span class="p">(</span><span class="mi">2</span><span class="o">*</span><span class="mi">256</span><span class="o">*</span><span class="mi">8</span><span class="o">*</span><span class="k">sizeof</span><span class="p">(</span><span class="kt">float</span><span class="p">))</span> <span class="n">blocks_a</span><span class="p">[</span><span class="mi">2</span><span class="o">*</span><span class="mi">256</span><span class="o">*</span><span class="mi">8</span><span class="p">];</span>
<span class="kt">uint64_t</span> <span class="n">lds_a_addr</span><span class="p">;</span>
<span class="kt">uint64_t</span> <span class="n">sts_a_addr</span><span class="p">;</span>
<span class="kt">float</span><span class="o">*</span> <span class="n">lds_a_ptr</span> <span class="o">=</span> <span class="n">blocks_a</span><span class="p">;</span> <span class="c1">// lds = load shared</span>
<span class="kt">float</span><span class="o">*</span> <span class="n">sts_a_ptr</span> <span class="o">=</span> <span class="n">blocks_a</span><span class="p">;</span> <span class="c1">// sts = store shared</span>
<span class="n">lds_a_addr</span> <span class="o">=</span> <span class="n">convert_to_addr</span><span class="p">(</span><span class="n">lds_a_ptr</span><span class="p">);</span> <span class="c1">// convert pointer to address for PTX load/store instructions</span>
<span class="n">sts_a_addr</span> <span class="o">=</span> <span class="n">convert_to_addr</span><span class="p">(</span><span class="n">sts_a_ptr</span><span class="p">);</span> <span class="c1">// convert pointer to address for PTX load/store instructions</span>

<span class="c1">// store first block to first half of shared memory</span>
<span class="n">sts_ptx</span><span class="p">(</span><span class="n">sts_a_addr</span><span class="p">);</span>
<span class="c1">// switch address to second half of shared memory</span>
<span class="n">sts_a_addr</span> <span class="o">^=</span> <span class="mi">8192</span><span class="p">;</span>

<span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">i</span><span class="o">&lt;</span><span class="p">(</span><span class="n">K</span><span class="o">/</span><span class="n">ks</span><span class="o">-</span><span class="mi">1</span><span class="p">);</span> <span class="n">i</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
    <span class="p">...</span>
    <span class="c1">// store next block to second(first) half of shared memory</span>
    <span class="n">sts_ptx</span><span class="p">(</span><span class="n">sts_a_addr</span><span class="p">);</span>
    <span class="p">...</span>
    <span class="c1">// load block from first(second) half of shared memory to compute c+=block_a*block_b</span>
    <span class="n">lds_ptx</span><span class="p">(</span><span class="n">lds_a_addr</span><span class="p">);</span>
    <span class="p">...</span>
    <span class="c1">// swap the addresses for next iteration: lds_a_addr = sts_a_addr, sts_a_addr = lds_a_addr</span>
    <span class="n">lds_a_addr</span> <span class="o">^=</span> <span class="mi">8192</span><span class="p">;</span>
    <span class="n">sts_a_addr</span> <span class="o">^=</span> <span class="mi">8192</span><span class="p">;</span>
    <span class="p">...</span>
<span class="p">}</span>
<span class="p">...</span>
</code></pre></div></div>

<p><img src="/assets/matmul_gpu/db.png" alt="db" width="70%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>First, we require <code class="language-plaintext highlighter-rouge">blocks_a</code> to be <code class="language-plaintext highlighter-rouge">2*256*8*4</code>=<code class="language-plaintext highlighter-rouge">2^14</code>=<code class="language-plaintext highlighter-rouge">16384</code>-byte aligned. This means that the address of the first element of <code class="language-plaintext highlighter-rouge">blocks_a</code> is divisible by 16384, or in other words that the lowest 14 bits of the address are zero:</p>

<p><img src="/assets/matmul_gpu/bit_repr.png" alt="bit_repr" width="60%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>As each block size is <code class="language-plaintext highlighter-rouge">8192=2^13</code> bytes, the addresses of corresponding elements in the two blocks differ only in the bit with value 8192, so switching between the blocks can now be implemented with just a single XOR instruction <code class="language-plaintext highlighter-rouge">^= 8192</code>. The only drawback of this method is the unused shared memory (in this case <code class="language-plaintext highlighter-rouge">2*8*128*4</code> bytes). However, this can be ignored considering the <a href="https://docs.nvidia.com/cuda/cuda-c-programming-guide/#features-and-technical-specifications-technical-specifications-per-compute-capability">maximum amount of shared memory per thread block</a> on modern GPUs.</p>
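<p>The address arithmetic behind the XOR swap can be checked on the host. The following is a minimal sketch in plain C, not code from the kernel: the helper names <code class="language-plaintext highlighter-rouge">toggle_block</code> and <code class="language-plaintext highlighter-rouge">block_addr</code> are hypothetical, and shared memory addresses are modelled as 32-bit integers.</p>

```c
#include <assert.h>
#include <stdint.h>

enum { BLOCK_BYTES = 8192, PAIR_BYTES = 2 * BLOCK_BYTES };

/* Address of the same element in the other half of the double buffer. */
static uint32_t toggle_block(uint32_t addr) {
    return addr ^ BLOCK_BYTES;
}

/* Hypothetical helper: address of byte `offset` inside half 0 or 1 of a
   buffer pair that starts at `base`. */
static uint32_t block_addr(uint32_t base, int half, uint32_t offset) {
    assert(base % PAIR_BYTES == 0); /* the 16384-byte alignment requirement */
    assert(half == 0 || half == 1);
    assert(offset < BLOCK_BYTES);
    return base + (uint32_t)half * BLOCK_BYTES + offset;
}
```

<p>With an aligned base, <code class="language-plaintext highlighter-rouge">toggle_block</code> maps an address in one half to the same offset in the other half and back again. Without the alignment the trick breaks: for a pair starting at address 8192, toggling <code class="language-plaintext highlighter-rouge">8192 + 100</code> yields <code class="language-plaintext highlighter-rouge">100</code>, which lies outside the buffer.</p>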

<p>Loading and storing an <code class="language-plaintext highlighter-rouge">8 x 128</code> submatrix $\tilde{B}$ is much simpler to manage due to its shape. Since the sub-matrix does not need to be transposed, the loading and storing schemes are identical:</p>

<p><img src="/assets/matmul_gpu/b_gmem_loads.png" alt="b_gmem_loads" width="100%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>We use 32 consecutive threads to load 32 consecutive elements, with each thread loading 4 elements, spaced apart by a stride of 32. Note that since we store data in 32 distinct shared memory banks, no padding is required, and bank conflicts are avoided. Furthermore, the block size of <code class="language-plaintext highlighter-rouge">128*8</code> floats (4096 bytes) is naturally a power of two, eliminating the need for additional padding and allowing block switching with a single XOR <code class="language-plaintext highlighter-rouge">^=4096</code> instruction.</p>
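<p>The statement about 32 distinct banks can be illustrated with a small host-side model. This sketch assumes the usual configuration of 32 banks that are each 4 bytes wide. The function names are hypothetical and the code only counts banks, it does not model the hardware in any further detail.</p>

```c
#include <assert.h>
#include <stdint.h>

enum { NUM_BANKS = 32, BANK_WIDTH_BYTES = 4, WARP_SIZE = 32 };

/* Bank that serves the 4-byte word at the given shared memory byte address. */
static int bank_of(uint32_t byte_addr) {
    return (int)((byte_addr / BANK_WIDTH_BYTES) % NUM_BANKS);
}

/* Number of distinct banks touched when lane i of a warp accesses the float
   at base_bytes + i * stride_floats * 4. */
static int distinct_banks(uint32_t base_bytes, uint32_t stride_floats) {
    assert(base_bytes % BANK_WIDTH_BYTES == 0); /* float-aligned base */
    uint32_t seen = 0; /* bit b is set once bank b has been touched */
    int count = 0;
    for (uint32_t lane = 0; lane < WARP_SIZE; lane++) {
        int bank = bank_of(base_bytes + lane * stride_floats * BANK_WIDTH_BYTES);
        if (!(seen & (1u << bank))) {
            seen |= 1u << bank;
            count++;
        }
    }
    return count;
}
```

<p>A warp that writes 32 consecutive floats (stride 1) touches all 32 banks, so the accesses are conflict-free. A stride of 32 floats would instead map every lane to the same bank.</p>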

<h3 id="52-shared-memory-loads-and-arithmetic-operations">5.2. Shared Memory Loads and Arithmetic Operations</h3>

<p>With blocks $\tilde{A}$ and $\tilde{B}$ now residing in shared memory, let’s discuss how to efficiently load from shared memory and compute block $\tilde{C}$. To do this, we’ll dive one level deeper into our parallelization strategy and describe the algorithm from a warp’s perspective:</p>

<p><img src="/assets/matmul_gpu/warp_level_design.png" alt="warp_level_design" width="80%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>Each launched thread block consists of 256 threads, which corresponds to <code class="language-plaintext highlighter-rouge">256/32=8</code> warps. The block $\tilde{C}$, with dimensions $128 \times 128$, is therefore divided into 8 regions $\tilde{C}_W$ labeled $W1, …, W8$ in the figure. Each region $\tilde{C}_W$ has dimensions $m_W \times n_W = 32 \times 64$ and is computed by a single warp: $W1$ is computed by threads <code class="language-plaintext highlighter-rouge">t0-t31</code>, $W2$ is computed by threads <code class="language-plaintext highlighter-rouge">t32-t63</code>, and so on, with $W8$ computed by threads <code class="language-plaintext highlighter-rouge">t224-t255</code>. The figure above uses $W8$ as an example to demonstrate how a single $\tilde{C}_W$ region is computed. We iterate over the dimension $K$ and in each iteration we</p>

<ol>
  <li>load <code class="language-plaintext highlighter-rouge">fragment_a</code> (=column of size $m_W \times 1$) from $\tilde{A}$ into registers</li>
  <li>load <code class="language-plaintext highlighter-rouge">fragment_b</code> (=row of size $1 \times n_W$) from $\tilde{B}$ into registers</li>
  <li>multiply the fragments and update $\tilde{C}_W$</li>
</ol>
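<p>The three steps above amount to a sequence of rank-1 updates. The following CPU sketch in plain C shows this for one warp region. It is only an illustration: the flat array layout and the name <code class="language-plaintext highlighter-rouge">warp_tile_update</code> are assumptions, and the real kernel distributes this work across the 32 threads of the warp.</p>

```c
#include <assert.h>

enum { M_W = 32, N_W = 64, K_S = 8 };

/* a_t: K_S x M_W floats, row k holds fragment_a of iteration k
        (a column of the warp's part of block A, stored contiguously)
   b:   K_S x N_W floats, row k holds fragment_b of iteration k
   c:   M_W x N_W floats, the warp's region of C, updated in place */
static void warp_tile_update(const float *a_t, const float *b, float *c) {
    assert(a_t && b && c);
    for (int k = 0; k < K_S; k++) {
        const float *fragment_a = a_t + k * M_W; /* step 1 */
        const float *fragment_b = b + k * N_W;   /* step 2 */
        for (int i = 0; i < M_W; i++) {          /* step 3: rank-1 update */
            for (int j = 0; j < N_W; j++) {
                c[i * N_W + j] += fragment_a[i] * fragment_b[j];
            }
        }
    }
}
```

<p>After all iterations, element <code class="language-plaintext highlighter-rouge">(i, j)</code> of the region has accumulated the dot product of row <code class="language-plaintext highlighter-rouge">i</code> of the warp's part of block A with column <code class="language-plaintext highlighter-rouge">j</code> of its part of block B.</p>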

<p>As $k_S = 8$, there will be in total 8 iterations. This explanation is from the perspective of a warp. Now, let’s delve one final level deeper and examine how the work within a warp is distributed among its 32 threads.</p>

<p><img src="/assets/matmul_gpu/thread_level_design.png" alt="thread_level_design" width="100%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>Each thread in a warp computes four <code class="language-plaintext highlighter-rouge">4x4</code> sub-matrices (=accumulators) within $\tilde{C}_W$ or, if concatenated, one <code class="language-plaintext highlighter-rouge">8x8</code> accumulator. To do this, each thread loads 8 elements from <code class="language-plaintext highlighter-rouge">fragment_a</code>, 8 elements from <code class="language-plaintext highlighter-rouge">fragment_b</code> (as illustrated for thread <code class="language-plaintext highlighter-rouge">t0</code> in the figure), multiplies them and updates the accumulator using fused multiply-add (FMA) instructions. Since <code class="language-plaintext highlighter-rouge">block_a</code> was transposed in the previous step, the elements in <code class="language-plaintext highlighter-rouge">fragment_a</code> are stored contiguously in memory, allowing faster access through vectorized loads. The threads are arranged in a way that avoids bank conflicts and works around NVIDIA’s shared memory broadcast limitation. This limitation occurs when 4 floats loaded using a 16-byte vector instruction must be broadcast to more than 4 consecutive threads within a warp.</p>

<p>Bringing everything together, the entire SGEMM algorithm can be visualized as follows:
<img src="/assets/matmul_gpu/cutlass_sgemm.png" alt="cutlass_sgemm" width="70%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>As you might expect, the accumulators are frequently updated during the computation and need to be stored in the fastest memory - the register files. Each thread allocates <code class="language-plaintext highlighter-rouge">float accumulator[8][8]</code>, so that the entire block $\tilde{C}$ of size $128 \times 128$ is stored in registers by the <code class="language-plaintext highlighter-rouge">256</code> threads. This works because <code class="language-plaintext highlighter-rouge">256=16*16</code>, and the combined arrangement <code class="language-plaintext highlighter-rouge">(16*8)x(16*8)=128x128</code> matches the size of $\tilde{C}$. Just as we used double buffering to load the blocks $\tilde{A}$ and $\tilde{B}$ (from global memory to shared memory), we now also double buffer the fragments to minimize memory transfer latencies when moving data from shared memory to registers. The pseudocode for the algorithm can be written as follows:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1">// Pseudocode</span>
<span class="n">__shared__</span> <span class="kt">float</span> <span class="n">block_a</span><span class="p">[</span><span class="mi">2</span><span class="p">][</span><span class="n">block_a_size</span><span class="p">]</span>
<span class="n">__shared__</span> <span class="kt">float</span> <span class="n">block_b</span><span class="p">[</span><span class="mi">2</span><span class="p">][</span><span class="n">block_b_size</span><span class="p">]</span>
<span class="kt">float</span> <span class="n">fragment_a</span><span class="p">[</span><span class="mi">2</span><span class="p">][</span><span class="mi">8</span><span class="p">]</span>
<span class="kt">float</span> <span class="n">fragment_b</span><span class="p">[</span><span class="mi">2</span><span class="p">][</span><span class="mi">8</span><span class="p">]</span>
<span class="kt">float</span> <span class="n">accumulator</span><span class="p">[</span><span class="mi">8</span><span class="p">][</span><span class="mi">8</span><span class="p">]</span>

<span class="n">block_a</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">load</span> <span class="n">first</span> <span class="n">block</span> <span class="n">of</span> <span class="n">matrix</span> <span class="n">A</span>
<span class="n">block_b</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">load</span> <span class="n">first</span> <span class="n">block</span> <span class="n">of</span> <span class="n">matrix</span> <span class="n">B</span>
<span class="n">fragment_a</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">load</span> <span class="n">first</span> <span class="n">fragment</span> <span class="n">from</span> <span class="n">block_a</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">fragment_b</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">load</span> <span class="n">first</span> <span class="n">fragment</span> <span class="n">from</span> <span class="n">block_b</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>

<span class="k">for</span> <span class="p">(</span><span class="n">block_k</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">block_k</span><span class="o">&lt;</span><span class="p">(</span><span class="n">K</span><span class="o">/</span><span class="n">ks</span><span class="o">-</span><span class="mi">1</span><span class="p">);</span> <span class="n">block_k</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
    <span class="n">block_idx</span> <span class="o">=</span> <span class="n">block_k</span> <span class="o">%</span> <span class="mi">2</span>
    <span class="n">block_prefetch_idx</span> <span class="o">=</span> <span class="p">(</span><span class="n">block_k</span><span class="o">+</span><span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="mi">2</span>
    <span class="c1">// prefetch next blocks (Shared Memory Double buffering)</span>
    <span class="n">block_a</span><span class="p">[</span><span class="n">block_prefetch_idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">load</span> <span class="n">next</span> <span class="n">block</span> <span class="n">of</span> <span class="n">matrix</span> <span class="n">A</span>
    <span class="n">block_b</span><span class="p">[</span><span class="n">block_prefetch_idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">load</span> <span class="n">next</span> <span class="n">block</span> <span class="n">of</span> <span class="n">matrix</span> <span class="n">B</span>
    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">warp_k</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">warp_k</span><span class="o">&lt;</span><span class="mi">8</span><span class="p">;</span> <span class="n">warp_k</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
        <span class="n">frag_idx</span> <span class="o">=</span> <span class="n">warp_k</span> <span class="o">%</span> <span class="mi">2</span>
        <span class="n">frag_prefetch_idx</span> <span class="o">=</span> <span class="p">(</span><span class="n">warp_k</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="mi">2</span>
        <span class="c1">// prefetch next fragments (Register Double buffering)</span>
        <span class="n">fragment_a</span><span class="p">[</span><span class="n">frag_prefetch_idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">load</span> <span class="n">next</span> <span class="n">fragment</span> <span class="n">from</span> <span class="n">block_a</span><span class="p">[</span><span class="n">block_idx</span><span class="p">]</span>
        <span class="n">fragment_b</span><span class="p">[</span><span class="n">frag_prefetch_idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">load</span> <span class="n">next</span> <span class="n">fragment</span> <span class="n">from</span> <span class="n">block_b</span><span class="p">[</span><span class="n">block_idx</span><span class="p">]</span>
        <span class="c1">// use fragments loaded in previous iteration to calculate matrix product</span>
        <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="mi">8</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
            <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">j</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="mi">8</span><span class="p">;</span> <span class="n">j</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
                <span class="n">accumulator</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">j</span><span class="p">]</span> <span class="o">+=</span> <span class="n">fragment_a</span><span class="p">[</span><span class="n">frag_idx</span><span class="p">][</span><span class="n">i</span><span class="p">]</span> <span class="o">*</span> <span class="n">fragment_b</span><span class="p">[</span><span class="n">frag_idx</span><span class="p">][</span><span class="n">j</span><span class="p">];</span>
            <span class="p">}</span>
        <span class="p">}</span>
    <span class="p">}</span>
    <span class="n">fragment_a</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">load</span> <span class="n">first</span> <span class="n">fragment</span> <span class="n">from</span> <span class="n">block_a</span><span class="p">[</span><span class="n">block_prefetch_idx</span><span class="p">]</span>
    <span class="n">fragment_b</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">load</span> <span class="n">first</span> <span class="n">fragment</span> <span class="n">from</span> <span class="n">block_b</span><span class="p">[</span><span class="n">block_prefetch_idx</span><span class="p">]</span>
<span class="p">}</span>

<span class="c1">// final update of the accumulator using last blocks</span>
<span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">warp_k</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">warp_k</span><span class="o">&lt;</span><span class="mi">8</span><span class="p">;</span> <span class="n">warp_k</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
    <span class="n">frag_idx</span> <span class="o">=</span> <span class="n">warp_k</span> <span class="o">%</span> <span class="mi">2</span>
    <span class="n">frag_prefetch_idx</span> <span class="o">=</span> <span class="p">(</span><span class="n">warp_k</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="mi">2</span>
    <span class="c1">// prefetch next fragments (Register Double buffering)</span>
    <span class="n">fragment_a</span><span class="p">[</span><span class="n">frag_prefetch_idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">load</span> <span class="n">next</span> <span class="n">fragment</span> <span class="n">from</span> <span class="n">block_a</span><span class="p">[</span><span class="n">block_prefetch_idx</span><span class="p">]</span>
    <span class="n">fragment_b</span><span class="p">[</span><span class="n">frag_prefetch_idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">load</span> <span class="n">next</span> <span class="n">fragment</span> <span class="n">from</span> <span class="n">block_b</span><span class="p">[</span><span class="n">block_prefetch_idx</span><span class="p">]</span>
    <span class="c1">// use fragments loaded in previous iteration to calculate matrix product</span>
    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="mi">8</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
        <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">j</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="mi">8</span><span class="p">;</span> <span class="n">j</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
            <span class="n">accumulator</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">j</span><span class="p">]</span> <span class="o">+=</span> <span class="n">fragment_a</span><span class="p">[</span><span class="n">frag_idx</span><span class="p">][</span><span class="n">i</span><span class="p">]</span> <span class="o">*</span> <span class="n">fragment_b</span><span class="p">[</span><span class="n">frag_idx</span><span class="p">][</span><span class="n">j</span><span class="p">];</span>
        <span class="p">}</span>
    <span class="p">}</span>
<span class="p">}</span>

<span class="c1">// After completing the matrix multiplication C=A*B, we perform one final update to the accumulator</span>
<span class="c1">// to compute  C=alpha*A*B before storing the result back to global memory:</span>
<span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">i</span><span class="o">&lt;</span><span class="mi">8</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">j</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">j</span><span class="o">&lt;</span><span class="mi">8</span><span class="p">;</span> <span class="n">j</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
        <span class="n">accumulator</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">j</span><span class="p">]</span> <span class="o">*=</span> <span class="n">alpha</span><span class="p">;</span>
    <span class="p">}</span>
<span class="p">}</span>

<span class="n">store_to_global_memory</span><span class="p">(</span><span class="n">accumulator</span><span class="p">)</span>
</code></pre></div></div>
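<p>The register double buffering of the inner loop can be tried out on the CPU. The sketch below covers a single thread and a single pair of blocks. It is not the kernel's code: the flat <code class="language-plaintext highlighter-rouge">KS x FRAG</code> input layout, the helper <code class="language-plaintext highlighter-rouge">load_fragment</code> and the guard that skips the prefetch in the last iteration are assumptions made to keep the example self-contained (the pseudocode instead fetches the first fragment of the next block after the loop).</p>

```c
#include <assert.h>

enum { KS = 8, FRAG = 8 };

/* Copies one 8-element fragment; stands in for the shared memory loads. */
static void load_fragment(float dst[FRAG], const float *src) {
    for (int i = 0; i < FRAG; i++) dst[i] = src[i];
}

/* block_a, block_b: KS x FRAG floats, the values this thread needs from the
   current blocks (hypothetical flat layout, row k belongs to iteration k)
   accumulator: FRAG x FRAG, updated in place */
static void thread_block_update(const float *block_a, const float *block_b,
                                float accumulator[FRAG][FRAG]) {
    assert(block_a && block_b);
    float fragment_a[2][FRAG];
    float fragment_b[2][FRAG];
    load_fragment(fragment_a[0], block_a);
    load_fragment(fragment_b[0], block_b);
    for (int warp_k = 0; warp_k < KS; warp_k++) {
        int frag_idx = warp_k % 2;
        int frag_prefetch_idx = (warp_k + 1) % 2;
        if (warp_k + 1 < KS) { /* nothing left to prefetch in the last iteration */
            load_fragment(fragment_a[frag_prefetch_idx], block_a + (warp_k + 1) * FRAG);
            load_fragment(fragment_b[frag_prefetch_idx], block_b + (warp_k + 1) * FRAG);
        }
        /* compute with the fragments that were loaded one iteration earlier */
        for (int i = 0; i < FRAG; i++) {
            for (int j = 0; j < FRAG; j++) {
                accumulator[i][j] += fragment_a[frag_idx][i] * fragment_b[frag_idx][j];
            }
        }
    }
}
```

<p>On the CPU the two buffers bring no speedup. The point is only that the ping-pong indices never overwrite a fragment that is still in use, so the result equals the plain sum over all <code class="language-plaintext highlighter-rouge">KS</code> outer products.</p>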

<h3 id="53-coalesced-global-memory-stores-through-shared-memory">5.3. Coalesced Global Memory Stores Through Shared Memory</h3>

<p>Just as with global memory reads, we want our global memory writes to be coalesced. However, directly storing the accumulators to global memory based on our current mapping</p>

<p><img src="/assets/matmul_gpu/acc_map.png" alt="acc_map" width="70%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>would result in scattered, uncoalesced memory accesses, significantly hurting performance. To fix this, we use shared memory as a buffer to rearrange the accumulators, enabling coalesced global memory writes. At this stage, the accumulators have already been computed, so we no longer need shared memory for computation. Transferring data from registers to shared memory is fast, so the overhead of these additional transfers is negligible compared to the performance gains achieved through coalesced writes. We write the accumulator’s elements to shared memory row by row according to the following scheme:</p>

<p><img src="/assets/matmul_gpu/stgx.png" alt="stg" width="100%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>The first row, containing 32 elements, is copied to the first 32 consecutive memory addresses in shared memory. Similarly, the second row is copied to the next 32 consecutive memory addresses, and so on until all 16 rows have been copied to shared memory. Next, we iterate through the rows in shared memory, and in each iteration, we store a row (containing 32 elements) to global memory using coalesced writes:</p>

<p><img src="/assets/matmul_gpu/stg_final.png" alt="stg_final" width="100%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>The process is then repeated for the other three <code class="language-plaintext highlighter-rouge">4x4</code> accumulators of the threads.</p>

<p>To compute $C := \alpha AB + \beta C$, we make a slight adjustment to the process of storing the data to global memory. After copying the accumulator from registers to shared memory, we check if <code class="language-plaintext highlighter-rouge">beta != 0.0</code>. If true, we load (using coalesced loads) the corresponding element from global memory into a register, multiply it by <code class="language-plaintext highlighter-rouge">beta</code> and add the result to the accumulator stored in shared memory. Finally, we store the updated accumulator <code class="language-plaintext highlighter-rouge">alpha*A*B+beta*C</code> from shared memory to global memory using coalesced writes.</p>
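<p>The epilogue described above can be written down per row. This is a scalar CPU sketch under stated assumptions, not the kernel's store routine: <code class="language-plaintext highlighter-rouge">store_row</code> is a hypothetical name, <code class="language-plaintext highlighter-rouge">acc</code> stands for a row staged in shared memory that already holds <code class="language-plaintext highlighter-rouge">alpha*A*B</code>, and <code class="language-plaintext highlighter-rouge">c</code> stands for the matching row of C in global memory.</p>

```c
#include <assert.h>

/* acc: n values that already equal alpha*(A*B)
   c:   n values of C, overwritten with alpha*(A*B) + beta*C */
static void store_row(float *c, const float *acc, float beta, int n) {
    assert(n >= 0);
    for (int i = 0; i < n; i++) {
        float value = acc[i];
        if (beta != 0.0f) {
            value += beta * c[i]; /* C is read only when beta is nonzero */
        }
        c[i] = value;
    }
}
```

<p>Because the load is skipped for <code class="language-plaintext highlighter-rouge">beta == 0.0</code>, the previous contents of C never enter the result in that case, even if C holds NaN values.</p>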

<h3 id="6-performance-analysis">6. Performance Analysis</h3>

<p>So far, we have discussed the design of the <code class="language-plaintext highlighter-rouge">128x128x8</code> SGEMM kernel. Its implementation is available at <a href="https://github.com/salykova/sgemm.cu/blob/main/src/kernels/128x128x8.cuh">128x128x8.cuh</a> and closely follows the pseudo-code outlined earlier. Let’s now benchmark this kernel to evaluate its performance. First, we conduct a benchmark with locked clock frequencies:
<img src="/assets/matmul_gpu/128x128x8_lock.png" alt="128x128x8_lock" width="100%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>The benchmark results show that the implementation outperforms cuBLAS when clock speeds remain constant. However, performance alone is not enough; we also need to consider power consumption. To evaluate both metrics, we run the benchmark with unlocked clock frequencies:
<img src="/assets/matmul_gpu/128x128x8.png" alt="128x128x8" width="100%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>This reveals the effect of throttling due to reaching power limits. While the <code class="language-plaintext highlighter-rouge">128x128x8</code> kernel is, on average, 3–4% faster than cuBLAS, it consumes 12% more power. The increased power consumption causes the GPU to operate near the power limit for matrix sizes <code class="language-plaintext highlighter-rouge">m=n=k&gt;4000</code>, resulting in reduced clock speeds and overall performance degradation. This is why <strong>both</strong> running time and power consumption must be optimized to achieve a balanced and efficient implementation.</p>

<p>We can slightly improve the running time of the kernel by utilizing vectorized global texture loads. The new kernel is available at <a href="https://github.com/salykova/sgemm.cu/blob/main/src/kernels/128x128x8_texld.cuh">128x128x8_texld</a>. Since the vectorized load instructions impose alignment constraints on the input data, we first verify the memory alignment and ensure the leading dimensions of matrices <code class="language-plaintext highlighter-rouge">A</code> and <code class="language-plaintext highlighter-rouge">B</code> are divisible by 4:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">bool</span> <span class="n">is_aligned</span> <span class="o">=</span> <span class="p">(((</span><span class="kt">unsigned</span><span class="p">)</span><span class="n">lda</span> <span class="o">&amp;</span> <span class="mi">3u</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">)</span> <span class="o">&amp;&amp;</span> <span class="p">(((</span><span class="kt">unsigned</span><span class="p">)</span><span class="n">ldb</span> <span class="o">&amp;</span> <span class="mi">3u</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">)</span>
                    <span class="o">&amp;&amp;</span> <span class="p">(((</span><span class="kt">unsigned</span> <span class="kt">long</span><span class="p">)</span><span class="n">A</span> <span class="o">&amp;</span> <span class="mi">15u</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">)</span> <span class="o">&amp;&amp;</span> <span class="p">(((</span><span class="kt">unsigned</span> <span class="kt">long</span><span class="p">)</span><span class="n">B</span> <span class="o">&amp;</span> <span class="mi">15u</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">);</span>
</code></pre></div></div>

<p>If the input data is aligned, we can use the vectorized load instructions. First, we need to create texture objects, texture descriptors, and resource descriptors. These are configured to handle the <code class="language-plaintext highlighter-rouge">float</code> data type with four 32-bit channels (x, y, z, w). The texture objects are then bound to the operands <code class="language-plaintext highlighter-rouge">A, B</code>, and passed to the kernel instead of raw pointers <code class="language-plaintext highlighter-rouge">A, B</code>.</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">cudaResourceDesc</span> <span class="n">resDesc</span><span class="p">;</span>
<span class="n">cudaTextureDesc</span> <span class="n">texDesc</span><span class="p">;</span>
<span class="n">cudaTextureObject_t</span> <span class="n">tex_a</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
<span class="n">cudaTextureObject_t</span> <span class="n">tex_b</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
<span class="p">...</span>
<span class="k">if</span> <span class="p">(</span><span class="n">is_aligned</span><span class="p">)</span> <span class="p">{</span>
    <span class="n">memset</span><span class="p">(</span><span class="o">&amp;</span><span class="n">texDesc</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="k">sizeof</span><span class="p">(</span><span class="n">texDesc</span><span class="p">));</span>
    <span class="n">texDesc</span><span class="p">.</span><span class="n">readMode</span> <span class="o">=</span> <span class="n">cudaReadModeElementType</span><span class="p">;</span>
    <span class="n">texDesc</span><span class="p">.</span><span class="n">normalizedCoords</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
    <span class="n">memset</span><span class="p">(</span><span class="o">&amp;</span><span class="n">resDesc</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="k">sizeof</span><span class="p">(</span><span class="n">resDesc</span><span class="p">));</span>
    <span class="n">resDesc</span><span class="p">.</span><span class="n">resType</span> <span class="o">=</span> <span class="n">cudaResourceTypeLinear</span><span class="p">;</span>
    <span class="n">resDesc</span><span class="p">.</span><span class="n">res</span><span class="p">.</span><span class="n">linear</span><span class="p">.</span><span class="n">desc</span><span class="p">.</span><span class="n">f</span> <span class="o">=</span> <span class="n">cudaChannelFormatKindFloat</span><span class="p">;</span>
    <span class="n">resDesc</span><span class="p">.</span><span class="n">res</span><span class="p">.</span><span class="n">linear</span><span class="p">.</span><span class="n">desc</span><span class="p">.</span><span class="n">x</span> <span class="o">=</span> <span class="mi">32</span><span class="p">;</span>
    <span class="n">resDesc</span><span class="p">.</span><span class="n">res</span><span class="p">.</span><span class="n">linear</span><span class="p">.</span><span class="n">desc</span><span class="p">.</span><span class="n">y</span> <span class="o">=</span> <span class="mi">32</span><span class="p">;</span>
    <span class="n">resDesc</span><span class="p">.</span><span class="n">res</span><span class="p">.</span><span class="n">linear</span><span class="p">.</span><span class="n">desc</span><span class="p">.</span><span class="n">z</span> <span class="o">=</span> <span class="mi">32</span><span class="p">;</span>
    <span class="n">resDesc</span><span class="p">.</span><span class="n">res</span><span class="p">.</span><span class="n">linear</span><span class="p">.</span><span class="n">desc</span><span class="p">.</span><span class="n">w</span> <span class="o">=</span> <span class="mi">32</span><span class="p">;</span>
    <span class="n">resDesc</span><span class="p">.</span><span class="n">res</span><span class="p">.</span><span class="n">linear</span><span class="p">.</span><span class="n">devPtr</span> <span class="o">=</span> <span class="n">A</span><span class="p">;</span>
    <span class="n">resDesc</span><span class="p">.</span><span class="n">res</span><span class="p">.</span><span class="n">linear</span><span class="p">.</span><span class="n">sizeInBytes</span> <span class="o">=</span> <span class="n">m</span> <span class="o">*</span> <span class="n">lda</span> <span class="o">*</span> <span class="k">sizeof</span><span class="p">(</span><span class="kt">float</span><span class="p">);</span>
    <span class="n">cudaCreateTextureObject</span><span class="p">(</span><span class="o">&amp;</span><span class="n">tex_a</span><span class="p">,</span> <span class="o">&amp;</span><span class="n">resDesc</span><span class="p">,</span> <span class="o">&amp;</span><span class="n">texDesc</span><span class="p">,</span> <span class="nb">NULL</span><span class="p">);</span>
    <span class="n">resDesc</span><span class="p">.</span><span class="n">res</span><span class="p">.</span><span class="n">linear</span><span class="p">.</span><span class="n">devPtr</span> <span class="o">=</span> <span class="n">B</span><span class="p">;</span>
    <span class="n">resDesc</span><span class="p">.</span><span class="n">res</span><span class="p">.</span><span class="n">linear</span><span class="p">.</span><span class="n">sizeInBytes</span> <span class="o">=</span> <span class="n">k</span> <span class="o">*</span> <span class="n">ldb</span> <span class="o">*</span> <span class="k">sizeof</span><span class="p">(</span><span class="kt">float</span><span class="p">);</span>
    <span class="n">cudaCreateTextureObject</span><span class="p">(</span><span class="o">&amp;</span><span class="n">tex_b</span><span class="p">,</span> <span class="o">&amp;</span><span class="n">resDesc</span><span class="p">,</span> <span class="o">&amp;</span><span class="n">texDesc</span><span class="p">,</span> <span class="nb">NULL</span><span class="p">);</span>
    <span class="n">sgemm_texld_128x128x8</span><span class="o">&lt;&lt;&lt;</span><span class="n">grid</span><span class="p">,</span> <span class="n">threads</span><span class="o">&gt;&gt;&gt;</span><span class="p">(</span><span class="n">m</span><span class="p">,</span>
                                             <span class="n">n</span><span class="p">,</span>
                                             <span class="n">k</span><span class="p">,</span>
                                             <span class="o">*</span><span class="n">alpha</span><span class="p">,</span>
                                             <span class="n">tex_a</span><span class="p">,</span>
                                             <span class="n">lda</span><span class="p">,</span>
                                             <span class="n">tex_b</span><span class="p">,</span>
                                             <span class="n">ldb</span><span class="p">,</span>
                                             <span class="o">*</span><span class="n">beta</span><span class="p">,</span>
                                             <span class="n">C</span><span class="p">,</span>
                                             <span class="n">ldc</span><span class="p">);</span>
    <span class="n">cudaDestroyTextureObject</span><span class="p">(</span><span class="n">tex_a</span><span class="p">);</span>
    <span class="n">cudaDestroyTextureObject</span><span class="p">(</span><span class="n">tex_b</span><span class="p">);</span>
<span class="p">}</span>
</code></pre></div></div>

<p>Within the kernel, we load data through texture objects using the <code class="language-plaintext highlighter-rouge">tex1Dfetch</code> function, which compiles to a single PTX instruction:</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">float4</span> <span class="n">texld_a_buffer</span><span class="p">;</span>
<span class="n">texld_a_buffer</span> <span class="o">=</span> <span class="n">tex1Dfetch</span><span class="o">&lt;</span><span class="n">float4</span><span class="o">&gt;</span><span class="p">(</span><span class="n">tex_a</span><span class="p">,</span> <span class="n">texld_a_offset</span><span class="p">);</span>
</code></pre></div></div>

<p>We prefer texture loads to ordinary vectorized global loads (<code class="language-plaintext highlighter-rouge">ld.global.v4.f32</code>) because texture loads handle out-of-bounds reads gracefully by returning zeros, which removes the need for predicated execution. This simplification leads to more efficient code:
<img src="/assets/matmul_gpu/128x128x8_texld_lock.png" alt="128x128x8_texld_lock" width="100%" style="display:block; margin-left:auto; margin-right:auto" />
<img src="/assets/matmul_gpu/128x128x8_texld.png" alt="128x128x8_texld" width="100%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>Lastly, we developed a <a href="https://github.com/salykova/sgemm.cu/blob/main/src/kernels/128x256x8.cuh"><code class="language-plaintext highlighter-rouge">128x256x8</code> SGEMM kernel</a> that uses a larger block size $n_S=256$ and asynchronous copy instructions (<code class="language-plaintext highlighter-rouge">cp.async.ca.shared.global</code>), which are supported starting with the Ampere architecture. The main advantage of these instructions is that they let us overlap computation with memory transfers and avoid pipeline stalls. Additionally, they copy data from global memory directly into shared memory, bypassing registers:</p>

<p><img src="/assets/matmul_gpu/cp_async.png" alt="cp_async" width="30%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>Simply replacing the normal global load instructions with <code class="language-plaintext highlighter-rouge">cp.async</code> in the <code class="language-plaintext highlighter-rouge">128x128x8</code> kernel degrades performance, possibly due to the higher latency of the <code class="language-plaintext highlighter-rouge">cp.async</code> instructions or suboptimal compiler optimizations. However, combining the increased block size with <code class="language-plaintext highlighter-rouge">cp.async</code> yields superior results in both speed and power efficiency:
<img src="/assets/matmul_gpu/128x256x8_lock.png" alt="128x256x8_lock" width="100%" style="display:block; margin-left:auto; margin-right:auto" />
<img src="/assets/matmul_gpu/128x256x8.png" alt="128x256x8" width="100%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>Our final implementation combines the <code class="language-plaintext highlighter-rouge">128x128x8</code> and <code class="language-plaintext highlighter-rouge">128x256x8</code> kernels. For smaller matrices <code class="language-plaintext highlighter-rouge">m=n &lt; 2500</code>, we use the <code class="language-plaintext highlighter-rouge">128x128x8</code> kernel, otherwise, the <code class="language-plaintext highlighter-rouge">128x256x8</code> kernel.</p>]]></content><author><name>Amanzhol Salykov</name></author><summary type="html"><![CDATA[This blog post focuses on a GPU implementation of SGEMM (Single-precision GEneral Matrix Multiply) operation defined as C := alpha*A*B + beta*C. We'll review the algorithm’s design and discuss optimization techniques such as inlined PTX, asynchronous memory copies, double-buffering, avoiding shared memory bank conflicts, and efficient coalesced storage through shared memory.]]></summary></entry><entry><title type="html">Advanced Matrix Multiplication Optimization on NVIDIA GPUs</title><link href="https://salykova.github.io/sgemm-gpu" rel="alternate" type="text/html" title="Advanced Matrix Multiplication Optimization on NVIDIA GPUs" /><published>2025-01-12T09:35:01+00:00</published><updated>2025-01-12T09:35:01+00:00</updated><id>https://salykova.github.io/sgemm-gpu</id><content type="html" xml:base="https://salykova.github.io/sgemm-gpu"><![CDATA[<blockquote>
  <p>This project is inspired by the outstanding works of Andrej Karpathy, George Hotz, Scott Gray, Horace He, Philippe Tillet, Jeremy Howard, Lei Mao and the best CUDA hackers from the <a href="https://github.com/gpu-mode">GPU MODE</a> community (<a href="https://discord.gg/gpumode">Discord server</a>). A special thanks to Mark Saroufim and Andreas Köpf for running GPU MODE and all you’ve done for the community.</p>
</blockquote>

<p>The code is available at <a href="https://github.com/salykova/sgemm.cu">sgemm.cu</a>. This article complements my <a href="https://salykova.github.io/matmul">blog post</a>, which covers the implementation of FP32 matrix multiplication that outperforms BLAS libraries on modern Intel and AMD CPUs. Today we’ll walk through a GPU implementation of the SGEMM (Single-precision GEneral Matrix Multiply) operation, defined as <code class="language-plaintext highlighter-rouge">C := alpha*A*B + beta*C</code>. The post covers benchmarking code on CUDA devices and explains the algorithm’s design along with optimization techniques. These include inlined PTX, asynchronous memory copies, double-buffering, avoiding shared memory bank conflicts, and efficient coalesced storage through shared memory. I’d also like to mention that the high-level algorithm design used in this project was developed by the excellent engineers at NVIDIA and has been extensively studied in prior works on cuBLAS and CUTLASS. My main contribution was translating it into efficient CUDA/PTX code. The goal of this project wasn’t to build an SGEMM that would magically outperform cuBLAS on all GPUs and all matrix sizes. That would be especially pointless given the open-sourced, lightweight CUTLASS library. Instead, the project primarily targets CUDA learners and aims to bridge the gap between the SGEMM implementations explained in books/blogs and those used in NVIDIA’s BLAS libraries. While the implementation is expected to deliver high performance on Ada/Ampere/Volta/Turing devices, it was specifically fine-tuned for and tested on a local NVIDIA RTX 3090 (=GA102 chip: RTX 3080, A10, A40, A6000). The achieved performance is shown below, comparing results with locked and unlocked GPU core frequencies against cuBLAS and Simon Boehm’s highly cited work (used in llamafile, aka tinyBLAS). I plan to continue publishing educational content on high-performance kernels used in AI/ML. 
Let me know what topics you’d like to see next! Projects currently in development: beating NVIDIA on Tensor Cores, Stream-K GEMM, FlashAttention, xLSTM. If you enjoy educational content like this and would like to see more, please share this article. Your feedback would be greatly appreciated!</p>

<p><strong>P.S. Please feel free to get in touch if you are interested in collaborating. My contact information is available on the homepage.</strong></p>

<p><br />
<img src="/assets/matmul_gpu/unlocked_perf.png" alt="unlocked_perf" width="100%" style="display:block; margin-left:auto; margin-right:auto" />
<br /><br />
<img src="/assets/matmul_gpu/locked_perf.png" alt="locked_perf" width="100%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<h2 id="1-introduction">1. Introduction</h2>

<p>I clearly remember Andrej’s post on the current state of the existing CUDA learning materials vs. the CUDA code used in high-performance libraries:</p>

<p><img src="/assets/matmul_gpu/ak_post.png" alt="ak_post" width="65%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>Indeed, when it comes to SGEMM implementations, there are some excellent educational blog posts, such as</p>

<ol>
  <li><a href="https://siboehm.com/articles/22/CUDA-MMM">How to Optimize a CUDA Matmul Kernel for cuBLAS-like Performance</a> (mentioned by Andrej)</li>
  <li><a href="https://leimao.github.io/article/CUDA-Matrix-Multiplication-Optimization/">CUDA Matrix Multiplication Optimization</a></li>
</ol>

<p>that break down, step by step, how to optimize a CUDA matmul kernel. However, in terms of achieved performance, none of them come close to matching the speed of cuBLAS or CUTLASS, especially when using recent CUDA versions and if benchmarked properly. From my experiments, these implementations achieve 50-70% of cuBLAS’ performance at best. Additionally, I found the explanations in both blog posts a bit overcomplicated in the final optimization steps. Nevertheless, I still think these resources are great for anyone starting with CUDA programming since they provide good foundational knowledge.</p>

<p>On the other hand, I’ve seen some really fast SGEMM implementations with cuBLAS-level performance:</p>

<ol>
  <li><a href="https://github.com/Yinghan-Li/YHs_Sample/tree/master/cuda/gemm">YHs GEMM</a></li>
  <li><a href="https://github.com/tpoisonooo/how-to-optimize-gemm/tree/master">how-to-optimize-gemm</a></li>
</ol>

<p>The problem is that they are undocumented, difficult to find and understand, especially for a CUDA beginner. A similar problem exists with CUTLASS. While it is highly performant, there is a lack of introductory or educational materials explaining how it is internally designed and implemented in efficient CUDA/PTX. Another notable project is <a href="https://github.com/NervanaSystems/maxas">MaxAs</a>, an assembler for the Maxwell architecture developed over a decade ago by Scott Gray. This tool enables programming directly in SASS (the assembly language for NVIDIA GPUs), allowing direct communication with the hardware instead of relying on the hardware-agnostic CUDA/PTX. Using MaxAs, Scott wrote an SGEMM implementation that achieved around 98% of the GM204 chip’s theoretical maximum FLOPS, surpassing cuBLAS by an average of 5%. While the results are impressive, programming in SASS is inflexible and requires deep understanding of the underlying hardware. Furthermore, with significant advancements in the compiler since then, programming directly in SASS is only advantageous in exceptional cases (for example, if you build <a href="https://github.com/tinygrad/tinygrad">tinygrad</a>). CUTLASS achieves performance on par with cuBLAS across various GPU architectures and matrix sizes using only CUDA/PTX code.</p>

<p>But can we actually exceed the cuBLAS barrier? In the following chapters, we will briefly review the high-level SGEMM design used in CUTLASS, and discuss how to translate this design into efficient CUDA/PTX. This guide assumes only a basic knowledge of the CUDA programming model and linear algebra. If you are new to CUDA programming, I strongly recommend starting with these short introductory articles:</p>

<ol>
  <li><a href="https://developer.nvidia.com/blog/easy-introduction-cuda-c-and-c/">An Easy Introduction to CUDA C and C++</a></li>
  <li><a href="https://developer.nvidia.com/blog/how-access-global-memory-efficiently-cuda-c-kernels/">How to Access Global Memory Efficiently in CUDA C/C++ Kernels</a></li>
  <li><a href="https://developer.nvidia.com/blog/using-shared-memory-cuda-cc/">Using Shared Memory in CUDA C/C++</a></li>
  <li><a href="https://developer.nvidia.com/blog/cuda-pro-tip-increase-performance-with-vectorized-memory-access/">Increase Performance with Vectorized Memory Access</a></li>
</ol>

<p>Before we proceed with implementation, let’s talk about benchmarking code on NVIDIA GPUs - a topic often overlooked. Properly benchmarking code is just as important as the code itself, particularly when comparing different implementations.</p>

<h2 id="2-how-to-benchmark-code-on-cuda-devices">2. How to Benchmark Code on CUDA Devices?</h2>

<p>The most reliable way to measure kernel duration is by profiling with NVIDIA Nsight Compute and manually extracting performance data. To obtain deterministic and reproducible results, Nsight Compute automatically applies the following settings:</p>

<ol>
  <li>Clock Control: locks GPU clock frequencies to their base values</li>
  <li>Cache Control: flushes all GPU caches before each replay pass</li>
  <li>Persistence mode</li>
</ol>

<p>Alternatively, you can apply these settings manually and measure kernel duration at runtime without relying on external profilers. On Ubuntu, you can retrieve the base core clock frequency using:</p>

<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code>nvidia-smi base-clocks
</code></pre></div></div>

<p>For instance, on an RTX 3090, the base core clock frequency is 1395 MHz. Next, you’ll need the memory clock frequencies, which work in combination with the base core clock:</p>

<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code>nvidia-smi <span class="nt">-q</span> <span class="nt">-d</span> supported_clocks
</code></pre></div></div>

<p>From the list of supported frequencies, choose the fastest memory clock compatible with the base core frequency. Memory clock speeds are generally more stable than core clock speeds. To lock the clock frequencies and enable persistence mode, run the following commands:</p>

<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="nb">sudo </span>nvidia-smi <span class="nt">--persistence-mode</span><span class="o">=</span>1
<span class="c"># NVIDIA RTX 3090</span>
<span class="nb">sudo </span>nvidia-smi <span class="nt">--lock-gpu-clocks</span><span class="o">=</span>1395
<span class="nb">sudo </span>nvidia-smi <span class="nt">--lock-memory-clocks</span><span class="o">=</span>9501
</code></pre></div></div>

<p>To reset the core and memory clock frequencies, you can use:</p>

<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="nb">sudo </span>nvidia-smi <span class="nt">--reset-gpu-clocks</span>
<span class="nb">sudo </span>nvidia-smi <span class="nt">--reset-memory-clocks</span>
<span class="nb">sudo </span>nvidia-smi <span class="nt">--persistence-mode</span><span class="o">=</span>0
</code></pre></div></div>

<p>GPU clock frequencies may drop due to the GPU’s thermal state, but for high-performance applications, throttling is often caused by power limits. Faulty hardware can also lead to throttling. It’s a good idea to monitor the GPU’s state at least during a test run. Use the following command to keep track of power draw, clock speeds, and throttling reasons in real time:</p>

<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code>watch <span class="nt">-n</span> 0.1 nvidia-smi <span class="nt">--query-gpu</span><span class="o">=</span>power.draw,clocks.sm,clocks.mem,clocks_throttle_reasons.active <span class="nt">--format</span><span class="o">=</span>csv
</code></pre></div></div>

<p>A sample output might look like this:</p>

<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code>308.50 W, 1395 MHz, 9501 MHz, 0x0000000000000000
</code></pre></div></div>

<p>The bit mask <code class="language-plaintext highlighter-rouge">0x0000000000000000</code> indicates no throttling, and the clocks are running at their maximum speeds. A value of <code class="language-plaintext highlighter-rouge">0x0000000000000001</code> indicates an idle state. Any other values suggest throttling is occurring. For a full list of bit mask values and their meanings, refer to the <a href="https://docs.nvidia.com/deploy/nvml-api/group__nvmlClocksThrottleReasons.html">NvmlClocksThrottleReasons documentation</a>.</p>
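
<p>As a rough illustration of how such a bit mask can be interpreted, here is a small host-side C++ sketch. The helper <code class="language-plaintext highlighter-rouge">decode_throttle_mask</code> and the short reason names are my own; the bit values are the ones I recall from <code class="language-plaintext highlighter-rouge">nvml.h</code>, so please verify them against the NVML documentation linked above before relying on them.</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code>#include &lt;cstdint&gt;
#include &lt;string&gt;
#include &lt;vector&gt;

// Bit values as I recall them from nvml.h (nvmlClocksThrottleReason*).
// Assumption: double-check against the NVML documentation.
struct ThrottleReason {
    std::uint64_t bit;
    const char* name;  // short illustrative label, not an official name
};

static const ThrottleReason kThrottleReasons[] = {
    {0x01, "gpu_idle"},
    {0x02, "applications_clocks_setting"},
    {0x04, "sw_power_cap"},
    {0x08, "hw_slowdown"},
    {0x10, "sync_boost"},
    {0x20, "sw_thermal_slowdown"},
    {0x40, "hw_thermal_slowdown"},
    {0x80, "hw_power_brake_slowdown"},
};

// Hypothetical helper: returns the labels of all reasons set in `mask`.
// An empty result corresponds to 0x0, i.e. no throttling.
std::vector&lt;std::string&gt; decode_throttle_mask(std::uint64_t mask) {
    std::vector&lt;std::string&gt; active;
    for (const ThrottleReason&amp; r : kThrottleReasons) {
        if (mask &amp; r.bit) active.push_back(r.name);
    }
    return active;
}
</code></pre></div></div>

<p>The hexadecimal string printed by <code class="language-plaintext highlighter-rouge">nvidia-smi</code> can be converted with <code class="language-plaintext highlighter-rouge">std::stoull(text, nullptr, 16)</code> before it is passed to the helper. Several bits may be set at once, which is why the sketch returns a list rather than a single reason.</p>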

<p>Once you’ve locked the clock frequencies, you can measure the kernel duration directly in CUDA using CUDA events. Here’s an example:</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">cudaEvent_t</span> <span class="n">start</span><span class="p">,</span> <span class="n">stop</span><span class="p">;</span>
<span class="n">cudaEventCreate</span><span class="p">(</span><span class="o">&amp;</span><span class="n">start</span><span class="p">);</span> <span class="n">cudaEventCreate</span><span class="p">(</span><span class="o">&amp;</span><span class="n">stop</span><span class="p">);</span>
<span class="kt">float</span> <span class="n">elapsed_time_ms</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">;</span>

<span class="n">cudaEventRecord</span><span class="p">(</span><span class="n">start</span><span class="p">);</span>
<span class="n">kernel</span><span class="o">&lt;&lt;&lt;</span><span class="p">...</span><span class="o">&gt;&gt;&gt;</span><span class="p">(...);</span>
<span class="n">cudaEventRecord</span><span class="p">(</span><span class="n">stop</span><span class="p">);</span>

<span class="n">cudaEventSynchronize</span><span class="p">(</span><span class="n">stop</span><span class="p">);</span>
<span class="n">cudaEventElapsedTime</span><span class="p">(</span><span class="o">&amp;</span><span class="n">elapsed_time_ms</span><span class="p">,</span> <span class="n">start</span><span class="p">,</span> <span class="n">stop</span><span class="p">);</span>
</code></pre></div></div>

<p>For reliable measurements, multiple replay passes are typically used. In such cases, the GPU cache should be flushed before each kernel replay. This can be done using <code class="language-plaintext highlighter-rouge">cudaMemsetAsync</code> as shown in <a href="https://github.com/NVIDIA/nvbench/blob/main/nvbench/detail/l2flush.cuh">nvbench</a>:</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1">// Flush L2 cache</span>
<span class="kt">int</span> <span class="n">dev_id</span><span class="p">{};</span>
<span class="kt">int</span> <span class="n">m_l2_size</span><span class="p">{};</span>
<span class="kt">void</span><span class="o">*</span> <span class="n">buffer</span><span class="p">;</span>
<span class="n">checkCudaErrors</span><span class="p">(</span><span class="n">cudaGetDevice</span><span class="p">(</span><span class="o">&amp;</span><span class="n">dev_id</span><span class="p">));</span>
<span class="n">checkCudaErrors</span><span class="p">(</span><span class="n">cudaDeviceGetAttribute</span><span class="p">(</span><span class="o">&amp;</span><span class="n">m_l2_size</span><span class="p">,</span> <span class="n">cudaDevAttrL2CacheSize</span><span class="p">,</span> <span class="n">dev_id</span><span class="p">));</span>
<span class="k">if</span> <span class="p">(</span><span class="n">m_l2_size</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">)</span> <span class="p">{</span>
    <span class="n">checkCudaErrors</span><span class="p">(</span><span class="n">cudaMalloc</span><span class="p">(</span><span class="o">&amp;</span><span class="n">buffer</span><span class="p">,</span> <span class="k">static_cast</span><span class="o">&lt;</span><span class="n">std</span><span class="o">::</span><span class="kt">size_t</span><span class="o">&gt;</span><span class="p">(</span><span class="n">m_l2_size</span><span class="p">)));</span>
    <span class="kt">int</span><span class="o">*</span> <span class="n">m_l2_buffer</span> <span class="o">=</span> <span class="k">reinterpret_cast</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">*&gt;</span><span class="p">(</span><span class="n">buffer</span><span class="p">);</span>
    <span class="n">checkCudaErrors</span><span class="p">(</span><span class="n">cudaMemsetAsync</span><span class="p">(</span><span class="n">m_l2_buffer</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="k">static_cast</span><span class="o">&lt;</span><span class="n">std</span><span class="o">::</span><span class="kt">size_t</span><span class="o">&gt;</span><span class="p">(</span><span class="n">m_l2_size</span><span class="p">)));</span>
    <span class="n">checkCudaErrors</span><span class="p">(</span><span class="n">cudaFree</span><span class="p">(</span><span class="n">m_l2_buffer</span><span class="p">));</span>
<span class="p">}</span>
</code></pre></div></div>

<p>Locking the clock frequencies to their base values is a reliable way to measure the speed of your kernel. However, in real-world scenarios, algorithms don’t typically run with locked clocks. To achieve optimal performance, your algorithm needs to be both fast and power-efficient: the less power it consumes, the higher the clock speeds your hardware can maintain. NVIDIA GPUs often reduce clock frequencies aggressively, well before hitting their power limits, which can significantly degrade application performance. To account for this, we benchmark our implementation under both locked and unlocked clock conditions, testing for both speed and power efficiency. In our benchmarks, we evaluate matrix sizes ranging from 1024 to 12,800 with a step size of 128. For each matrix size, we launch <code class="language-plaintext highlighter-rouge">1000*exp((-matsize+1024)/3100.0)</code> kernel replays and calculate the execution time as the average over the second half of the replays. For example, for the problem size <code class="language-plaintext highlighter-rouge">m=n=k=4096</code>, we run the SGEMM <code class="language-plaintext highlighter-rouge">1000*exp((-4096+1024)/3100.0)=371</code> times and measure the average duration of the last 185 runs, by which point the clocks have stabilized. This profiling strategy leads to consistent and reproducible results, even when GPU clocks are unlocked.</p>
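
<p>The replay schedule can be sketched in a few lines of host code. The function names below are hypothetical, and I assume the result of the formula is truncated to an integer, which is consistent with the 371 replays quoted for <code class="language-plaintext highlighter-rouge">m=n=k=4096</code>.</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code>#include &lt;cmath&gt;

// Number of kernel replays for a square problem of size `matsize`.
// Assumption: the fractional part of the formula is truncated.
int num_replays(int matsize) {
    return static_cast&lt;int&gt;(1000.0 * std::exp((-matsize + 1024) / 3100.0));
}

// Only the second half of the replays contributes to the reported average.
int num_measured(int matsize) {
    return num_replays(matsize) / 2;
}
</code></pre></div></div>

<p>With this schedule, the smallest problem (<code class="language-plaintext highlighter-rouge">matsize=1024</code>) is replayed 1000 times, and the count decays exponentially as the matrices grow, which keeps the total benchmark time manageable.</p>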

<blockquote>
  <p>Avoid using WSL for performance measurements. To ensure accurate and reliable results, please use a native Linux environment.</p>
</blockquote>

<h2 id="3-memory-layout">3. Memory Layout</h2>

<p>Without loss of generality, this implementation assumes that matrices are stored in row-major order. A matrix <code class="language-plaintext highlighter-rouge">A</code> with dimensions <code class="language-plaintext highlighter-rouge">M x N</code> is stored as a contiguous array of length <code class="language-plaintext highlighter-rouge">M*N</code>. An element <code class="language-plaintext highlighter-rouge">A[row][col]</code> is accessed through a raw 1D C pointer as <code class="language-plaintext highlighter-rouge">ptr[row*N + col]</code>, with <code class="language-plaintext highlighter-rouge">0&lt;=col&lt;=N-1</code> and <code class="language-plaintext highlighter-rouge">0&lt;=row&lt;=M-1</code>. Matrix multiplication is denoted as $C=AB$, where the shapes of matrices $A, B, C$ are $M \times K, K \times N$ and $M \times N$, respectively.
<img src="/assets/matmul_gpu/row_major.png" alt="mem_layout" width="80%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>To adapt this implementation for matrices stored in column-major order, simply swap the operands $A$ and $B$, because:</p>

\[C^\text{T} = (A B)^\text{T} = B^\text{T} A^\text{T},\]

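<p>Here, $A, B, C$ are matrices stored in row-major order, while $A^\text{T}, B^\text{T}, C^\text{T}$ are the corresponding transposed matrices. A buffer that holds a matrix in column-major order has exactly the same memory layout as its transpose in row-major order, so a row-major kernel called with the operands swapped produces the column-major result.</p>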
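
<p>To make the operand swap concrete, here is a minimal sketch built around a naive reference GEMM. Both function names are hypothetical and the triple loop is only a stand-in for the optimized kernel; the point is the argument order in the wrapper.</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code>// Naive reference: C = A * B, all matrices row-major,
// with leading dimensions lda, ldb, ldc.
void gemm_row_major(int m, int n, int k,
                    const float* A, int lda,
                    const float* B, int ldb,
                    float* C, int ldc) {
    for (int i = 0; i &lt; m; ++i) {
        for (int j = 0; j &lt; n; ++j) {
            float acc = 0.0f;
            for (int p = 0; p &lt; k; ++p) acc += A[i * lda + p] * B[p * ldb + j];
            C[i * ldc + j] = acc;
        }
    }
}

// Column-major C (m x n) = A (m x k) * B (k x n).
// The buffers, read as row-major, hold the transposes, so we compute
// C^T = B^T * A^T by swapping the operands and the sizes m and n.
void gemm_col_major(int m, int n, int k,
                    const float* A, int lda,
                    const float* B, int ldb,
                    float* C, int ldc) {
    gemm_row_major(n, m, k, B, ldb, A, lda, C, ldc);
}
</code></pre></div></div>

<p>No data is transposed or copied: the wrapper only reorders the arguments, so the same kernel serves both storage orders.</p>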

<p>cuBLAS provides an API to calculate SGEMM:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">cublasSgemm</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">A</span><span class="p">,</span> <span class="n">lda</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">ldb</span><span class="p">,</span> <span class="n">C</span><span class="p">,</span> <span class="n">ldc</span><span class="p">);</span> <span class="c1">// simplified form</span>
</code></pre></div></div>

<p>where <code class="language-plaintext highlighter-rouge">m, n, k</code> denote the matrix sizes $M, N, K$. The parameters <code class="language-plaintext highlighter-rouge">lda, ldb, ldc</code> are the <em>leading dimensions</em> of matrices $A, B, C$, respectively. The leading dimension is the stride, in elements, between the starts of two consecutive rows of a row-major matrix, i.e. the length in memory of the fastest-varying dimension. For matrices stored in row-major order, the leading dimension is usually the number of columns, so typically <code class="language-plaintext highlighter-rouge">lda=k, ldb=n, ldc=n</code>. However, this isn’t always the case. When you operate on a submatrix of a larger matrix, the leading dimension is that of the parent matrix and is therefore larger than the submatrix’s number of columns.</p>
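
<p>The submatrix case is easiest to see in code. The two helpers below are hypothetical and only illustrate the indexing; they assume row-major storage as in the rest of this post.</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code>#include &lt;cstddef&gt;

// Row-major element access with an explicit leading dimension `ld`,
// the distance in elements between the starts of consecutive rows.
inline float&amp; at(float* ptr, int row, int col, int ld) {
    return ptr[static_cast&lt;std::size_t&gt;(row) * ld + col];
}

// Pointer to the top-left element of the submatrix that starts at
// (row0, col0). The submatrix keeps the leading dimension of its parent.
inline float* submatrix(float* parent, int row0, int col0, int ld) {
    return parent + static_cast&lt;std::size_t&gt;(row0) * ld + col0;
}
</code></pre></div></div>

<p>For a <code class="language-plaintext highlighter-rouge">4 x 6</code> parent matrix, a <code class="language-plaintext highlighter-rouge">2 x 3</code> block starting at row 1, column 2 is addressed with <code class="language-plaintext highlighter-rouge">ld=6</code>, not <code class="language-plaintext highlighter-rouge">ld=3</code>, because stepping to the next row of the block still skips a full row of the parent.</p>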

<p>Matrices may also be padded with zeros to support vectorized memory loads or tensor cores. Vectorized load instructions load multiple elements with a single instruction. Although vectorized loads reduce the total number of instructions and improve bandwidth utilization, they also impose an <a href="https://developer.nvidia.com/blog/cuda-pro-tip-increase-performance-with-vectorized-memory-access/">alignment constraint</a> on the input data: the leading dimension must be divisible by 2 (for 64-bit loads) or 4 (for 128-bit loads). The figure below illustrates the case for 128-bit (=4 floats) loads.</p>

<p><img src="/assets/matmul_gpu/mem_align.png" alt="mem_align" width="80%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>Note how it’s impossible to load the elements of the first row without touching the elements of the next row if the leading dimension is not divisible by 4. Padding with zeros helps, but requires additional memory. Another solution would be to check at runtime if the leading dimension is divisible by 4. If it is - then use vectorized loads, if not - scalar loads. Additionally, zero padding was commonly used in the past to enable tensor core computations. For instance, in cuBLAS versions &lt; 11, Tensor Core FP16 operations required <code class="language-plaintext highlighter-rouge">m, n, k</code> to be multiples of 8.</p>
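
<p>A runtime check of this kind might look as follows. This is a sketch with hypothetical helper names, not the dispatch logic used in the repository. Besides the leading dimension, it also checks that the base pointer is 16-byte aligned, which is an extra assumption of mine that 128-bit loads need as well.</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code>#include &lt;cstdint&gt;

// True if every row of a row-major float matrix starts on a 16-byte
// boundary, so that 128-bit (4-float) loads can be used.
bool can_use_vec4_loads(const float* ptr, int ld) {
    const bool base_aligned = reinterpret_cast&lt;std::uintptr_t&gt;(ptr) % 16 == 0;
    return base_aligned &amp;&amp; (ld % 4 == 0);
}

// Smallest leading dimension &gt;= cols that is a multiple of 4,
// i.e. the row length after zero padding.
int padded_ld(int cols) {
    return (cols + 3) / 4 * 4;
}
</code></pre></div></div>

<p>A host-side launcher could call <code class="language-plaintext highlighter-rouge">can_use_vec4_loads</code> once per operand and select the vectorized or the scalar kernel accordingly, while <code class="language-plaintext highlighter-rouge">padded_ld</code> gives the row length to allocate if you choose zero padding instead.</p>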

<h2 id="4-parallel-thread-execution">4. Parallel Thread Execution</h2>

<p>The CUDA compilation trajectory of a <code class="language-plaintext highlighter-rouge">.cu</code> file looks as follows:</p>

<p><img src="/assets/matmul_gpu/ptxas.png" alt="ptxas" width="60%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>During Stage 1, CUDA code is compiled to PTX (Parallel Thread Execution) <a href="https://docs.nvidia.com/cuda/parallel-thread-execution/">instructions</a>, an intermediate high-level code that can be regarded as the assembly language of a <a href="https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/#virtual-architectures">virtual GPU architecture</a>. Such a virtual GPU is defined entirely by the set of capabilities, or features, that it provides to the application. PTX doesn’t run directly on any real architecture; it must be optimized and translated to native target-architecture instructions (Stage 2). NVIDIA provides a mechanism to <a href="https://docs.nvidia.com/cuda/inline-ptx-assembly/index.html">insert PTX code into your CUDA program</a>, so that you can mix CUDA and PTX in the same source and still benefit from the code optimizations applied during PTX generation. By rewriting parts of your code in PTX, you can 1) reduce the total number of generated PTX instructions, 2) specify exactly the PTX instructions you need, 3) tune the instructions through qualifiers, and 4) apply optimizations that are either lacking in the compiler or prohibited by the C++ language extensions. <strong>Important! Using inline PTX assembly will not automatically make your code faster than the equivalent CUDA code. It will only be faster if your hand-written PTX is better than the PTX generated by the compiler</strong>.</p>

<p>In this implementation we will program some parts of the algorithm directly in PTX, so I highly recommend reading this <a href="https://docs.nvidia.com/cuda/inline-ptx-assembly/index.html">short overview of inline PTX assembly</a> if you have never used it before. The PTX instructions are well documented and can be found at <a href="https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#instruction-set">PTX Instruction Set</a>. We will now briefly review the PTX instructions used in this implementation.</p>

<h3 id="41-global-memory-loads">4.1. Global Memory Loads</h3>

<p>For global memory loads we will use the <a href="https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-ld">ld.global.f32</a> instruction. Here, “ld” stands for “load” and “f32” for “32-bit float”. The following CUDA code</p>

<div class="language-c++ highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">float</span> <span class="n">reg</span><span class="p">;</span> <span class="c1">// single float register</span>
<span class="kt">float</span><span class="o">*</span> <span class="n">gmem_ptr</span> <span class="o">=</span> <span class="n">data_in_global_memory</span><span class="p">;</span> <span class="c1">// pointer to global memory</span>
<span class="n">reg</span> <span class="o">=</span> <span class="o">*</span><span class="n">gmem_ptr</span><span class="p">;</span> <span class="c1">// global memory -&gt; register transfer</span>
</code></pre></div></div>

<p>can be implemented in PTX as:</p>

<div class="language-c++ highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">float</span> <span class="n">reg</span><span class="p">;</span> <span class="c1">// single float register</span>
<span class="kt">float</span><span class="o">*</span> <span class="n">gmem_ptr</span> <span class="o">=</span> <span class="n">data_in_global_memory</span><span class="p">;</span> <span class="c1">// pointer to global memory</span>
<span class="k">asm</span> <span class="k">volatile</span><span class="p">(</span><span class="s">"ld.global.f32 %0, [%1];"</span> <span class="o">:</span> <span class="s">"=f"</span><span class="p">(</span><span class="n">reg</span><span class="p">)</span> <span class="o">:</span> <span class="s">"l"</span><span class="p">(</span><span class="n">gmem_ptr</span><span class="p">));</span>
</code></pre></div></div>

<p>The <code class="language-plaintext highlighter-rouge">f</code> in <code class="language-plaintext highlighter-rouge">"=f"</code> denotes the <code class="language-plaintext highlighter-rouge">float</code> datatype, and the <code class="language-plaintext highlighter-rouge">=</code> modifier specifies that the register is written to. The <code class="language-plaintext highlighter-rouge">l</code> constraint denotes an unsigned 64-bit integer register, which here holds the pointer. We also use the <code class="language-plaintext highlighter-rouge">volatile</code> keyword to ensure that the instruction is not deleted or moved during the generation of PTX.</p>

<h3 id="42-global-memory-stores">4.2. Global Memory Stores</h3>

<p>For global memory stores there is the <a href="https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st"><code class="language-plaintext highlighter-rouge">st.global.f32</code></a> instruction:</p>

<div class="language-c++ highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">float</span> <span class="n">reg</span><span class="p">;</span> <span class="c1">// single float register</span>
<span class="kt">float</span><span class="o">*</span> <span class="n">gmem_ptr</span> <span class="o">=</span> <span class="n">data_in_global_memory</span><span class="p">;</span> <span class="c1">// pointer to global memory</span>
<span class="c1">// *gmem_ptr = reg; can be implemented in PTX as:</span>
<span class="k">asm</span> <span class="k">volatile</span><span class="p">(</span><span class="s">"st.global.f32 [%0], %1;"</span> <span class="o">:</span> <span class="o">:</span> <span class="s">"l"</span><span class="p">(</span><span class="n">gmem_ptr</span><span class="p">),</span> <span class="s">"f"</span><span class="p">(</span><span class="n">reg</span><span class="p">));</span>
</code></pre></div></div>

<h3 id="43-global-to-shared-memory-transfers">4.3. Global to Shared Memory Transfers</h3>

<p>When you write something like this in CUDA:</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">__shared__</span> <span class="kt">float</span> <span class="n">smem_ptr</span><span class="p">[</span><span class="n">n</span><span class="p">];</span> <span class="c1">// pointer to shared memory</span>
<span class="kt">float</span><span class="o">*</span> <span class="n">gmem_ptr</span> <span class="o">=</span> <span class="n">data_in_global_memory</span><span class="p">;</span> <span class="c1">// pointer to global memory</span>
<span class="o">*</span><span class="n">smem_ptr</span> <span class="o">=</span> <span class="o">*</span><span class="n">gmem_ptr</span><span class="p">;</span> <span class="c1">// global to shared memory transfer</span>
</code></pre></div></div>

<p>a two-step process occurs. First, the data is fetched from global memory into registers, and then it is copied from registers into shared memory. Additionally, the data is cached in all cache levels during the transfer.</p>

<p><img src="/assets/matmul_gpu/standard_ld.png" alt="standard_ld" width="30%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p><em style="display:block; margin-left:auto; margin-right:auto; text-align: center">Global to shared memory transfers</em></p>

<p>For this reason, a global to shared memory transfer in PTX consists of two data movement instructions <code class="language-plaintext highlighter-rouge">ld.global</code> and <code class="language-plaintext highlighter-rouge">st.shared</code>:</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">__shared__</span> <span class="kt">float</span> <span class="n">smem_ptr</span><span class="p">[</span><span class="n">n</span><span class="p">];</span> <span class="c1">// pointer to shared memory</span>
<span class="kt">uint64_t</span> <span class="n">smem_addr</span><span class="p">;</span>
<span class="c1">// convert generic address to shared address (store location for st.shared instruction)</span>
<span class="k">asm</span> <span class="k">volatile</span><span class="p">(</span><span class="s">"cvta.to.shared.u64 %0, %1;"</span> <span class="o">:</span> <span class="s">"=l"</span><span class="p">(</span><span class="n">smem_addr</span><span class="p">)</span> <span class="o">:</span> <span class="s">"l"</span><span class="p">(</span><span class="n">smem_ptr</span><span class="p">));</span>

<span class="kt">float</span><span class="o">*</span> <span class="n">gmem_ptr</span> <span class="o">=</span> <span class="n">data_in_global_memory</span><span class="p">;</span> <span class="c1">// pointer to global memory</span>
<span class="kt">float</span> <span class="n">buffer</span><span class="p">;</span>
<span class="c1">// global memory -&gt; register</span>
<span class="k">asm</span> <span class="k">volatile</span><span class="p">(</span><span class="s">"ld.global.f32 %0, [%1];"</span> <span class="o">:</span> <span class="s">"=f"</span><span class="p">(</span><span class="n">buffer</span><span class="p">)</span> <span class="o">:</span> <span class="s">"l"</span><span class="p">(</span><span class="n">gmem_ptr</span><span class="p">));</span>
<span class="c1">// register -&gt; shared memory</span>
<span class="k">asm</span> <span class="k">volatile</span><span class="p">(</span><span class="s">"st.shared.f32 [%0], %1;"</span> <span class="o">:</span> <span class="o">:</span> <span class="s">"l"</span><span class="p">(</span><span class="n">smem_addr</span><span class="p">),</span> <span class="s">"f"</span><span class="p">(</span><span class="n">buffer</span><span class="p">));</span>
</code></pre></div></div>

<blockquote>
  <p>Prior to the Ampere architecture it was not possible to transfer data from global memory directly to shared memory without staging it in registers. Starting with the Ampere architecture, there are <a href="https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-asynchronous-copy">asynchronous copy instructions</a> that allow this. The usage of these instructions will be demonstrated later.</p>
</blockquote>

<h3 id="44-vectorized-shared-memory-loads-and-stores">4.4. Vectorized Shared Memory Loads and Stores</h3>

<p>In PTX you can also implement vectorized memory operations (loading/storing multiple elements with one instruction). Here, <code class="language-plaintext highlighter-rouge">v4</code> denotes a vector of four elements, so each instruction below moves four floats (128 bits) at once:</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">float</span> <span class="n">reg0</span><span class="p">,</span> <span class="n">reg1</span><span class="p">,</span> <span class="n">reg2</span><span class="p">,</span> <span class="n">reg3</span><span class="p">;</span>
<span class="kt">uint64_t</span> <span class="n">addr</span><span class="p">;</span>
<span class="p">...</span>
<span class="c1">// Shared memory 128-bit loads</span>
<span class="k">asm</span> <span class="k">volatile</span><span class="p">(</span><span class="s">"ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];"</span>
             <span class="o">:</span> <span class="s">"=f"</span><span class="p">(</span><span class="n">reg0</span><span class="p">),</span> <span class="s">"=f"</span><span class="p">(</span><span class="n">reg1</span><span class="p">),</span> <span class="s">"=f"</span><span class="p">(</span><span class="n">reg2</span><span class="p">),</span> <span class="s">"=f"</span><span class="p">(</span><span class="n">reg3</span><span class="p">)</span>
             <span class="o">:</span> <span class="s">"l"</span><span class="p">(</span><span class="n">addr</span><span class="p">));</span>
<span class="c1">// Shared memory 128-bit stores</span>
<span class="k">asm</span> <span class="k">volatile</span><span class="p">(</span><span class="s">"st.shared.v4.f32 [%0], {%1, %2, %3, %4};"</span>
             <span class="o">:</span>
             <span class="o">:</span> <span class="s">"l"</span><span class="p">(</span><span class="n">addr</span><span class="p">),</span> <span class="s">"f"</span><span class="p">(</span><span class="n">reg0</span><span class="p">),</span> <span class="s">"f"</span><span class="p">(</span><span class="n">reg1</span><span class="p">),</span> <span class="s">"f"</span><span class="p">(</span><span class="n">reg2</span><span class="p">),</span> <span class="s">"f"</span><span class="p">(</span><span class="n">reg3</span><span class="p">));</span>
</code></pre></div></div>

<h3 id="45-predicated-execution">4.5. Predicated Execution</h3>

<p>In PTX, conditional execution is implemented using optional guard predicates. The following CUDA code:</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">float</span> <span class="n">reg</span><span class="p">;</span>
<span class="kt">float</span><span class="o">*</span> <span class="n">ptr</span><span class="p">;</span> <span class="c1">//pointer to global memory</span>
<span class="kt">unsigned</span> <span class="n">guard</span><span class="p">;</span>
<span class="p">...</span>
<span class="k">if</span> <span class="p">(</span><span class="n">guard</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">)</span> <span class="p">{</span>
    <span class="o">*</span><span class="n">ptr</span> <span class="o">=</span> <span class="n">reg</span><span class="p">;</span>
<span class="p">}</span>
</code></pre></div></div>

<p>can be converted to PTX as:</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">float</span> <span class="n">reg</span><span class="p">;</span>
<span class="kt">float</span><span class="o">*</span> <span class="n">ptr</span><span class="p">;</span>
<span class="kt">unsigned</span> <span class="n">guard</span><span class="p">;</span>
<span class="p">...</span>
<span class="k">asm</span> <span class="k">volatile</span><span class="p">(</span><span class="s">"{</span><span class="se">\n\t</span><span class="s">"</span> <span class="c1">// open a scope so that 'p' is local to this asm block</span>
             <span class="s">".reg .pred p;</span><span class="se">\n\t</span><span class="s">"</span> <span class="c1">// declare predicate 'p'</span>
             <span class="s">"setp.ne.u32 p, %2, 0;</span><span class="se">\n\t</span><span class="s">"</span> <span class="c1">// set 'p' to true if (guard != 0); ne="not equal"</span>
             <span class="s">"@p st.global.f32 [%0], %1;</span><span class="se">\n\t</span><span class="s">"</span> <span class="c1">// execute the store only if 'p' is true</span>
             <span class="s">"}"</span>
             <span class="o">:</span>
             <span class="o">:</span> <span class="s">"l"</span><span class="p">(</span><span class="n">ptr</span><span class="p">),</span> <span class="s">"f"</span><span class="p">(</span><span class="n">reg</span><span class="p">),</span> <span class="s">"r"</span><span class="p">(</span><span class="n">guard</span><span class="p">));</span>
</code></pre></div></div>

<p>We use guard predicates in combination with global load/store instructions to perform global memory access only if it is not out of bounds.</p>

<h2 id="5-sgemm-design">5. SGEMM Design</h2>

<p>Let’s now break down the high-level design of the algorithm. The paper <a href="https://www.semanticscholar.org/paper/Strassen%E2%80%99s-Algorithm-Reloaded-on-GPUs-Huang-Yu/d3214da488806bf4c870080fca18a7f3ecba1e99">Strassen’s Algorithm Reloaded on GPUs</a> contains, in my opinion, one of the best visualizations of the SGEMM design from the CUTLASS library. The SGEMM algorithm can be roughly divided into three main parts:</p>

<ol>
  <li>Transferring data from global to shared memory</li>
  <li>Loading data from shared memory and performing arithmetic operations</li>
  <li>Writing results back to global memory.</li>
</ol>

<p>Each of these steps must be carefully optimized to achieve high overall performance. In the following sections, we’ll explore each step in detail and discuss efficient implementation strategies. It’s worth mentioning that the first step - “transferring data from global memory to shared memory” is the most challenging to grasp. However, once you understand this part, the remaining steps become much easier to follow.</p>

<h3 id="51-transferring-data-from-global-to-shared-memory">5.1. Transferring data from global to shared memory</h3>

<p><img src="/assets/matmul_gpu/matmul_design_tb.png" alt="matmul_gmem_loads" width="80%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p><em style="display:block; margin-left:auto; margin-right:auto; text-align: center">Source: <a href="https://www.semanticscholar.org/paper/Strassen%E2%80%99s-Algorithm-Reloaded-on-GPUs-Huang-Yu/d3214da488806bf4c870080fca18a7f3ecba1e99">Strassen’s Algorithm Reloaded on GPUs</a></em></p>

<p>To parallelize $C=AB$ on a GPU, the matrix $C$ is partitioned into sub-matrices $\tilde{C}$ of size $m_s \times n_s$, and the sub-matrices are processed in parallel, with each thread block computing one sub-matrix $\tilde{C}$ independently of the other thread blocks. To compute $\tilde{C}$, we iterate over the dimension $K$. In each iteration, a sub-matrix $\tilde{A}$ of size $m_s \times k_s$ and a sub-matrix $\tilde{B}$ of size $k_s \times n_s$ are loaded from <strong>global</strong> into <strong>shared</strong> memory (see the figure above). These sub-matrices are then multiplied, and the result is used to update $\tilde{C}$ as $\tilde{C} += \tilde{A} \tilde{B}$. The sub-matrices $\tilde{A}, \tilde{B}, \tilde{C}$ are often called <em>blocks</em> or <em>tiles</em>. In total there are $K / k_s$ iterations (assuming the simplest case, where $K$ is divisible by $k_s$). The limited shared memory capacity is the reason why the dimension $K$ is divided into smaller blocks of size $k_s$: full $m_s \times K$ and $K \times n_s$ blocks simply wouldn’t fit into the available shared memory. For now, don’t worry about why the matrices are loaded into shared memory or how exactly $\tilde{A}$ and $\tilde{B}$ are multiplied; we will discuss this in the next chapter. Let’s focus on efficient data movement from global to shared memory as our first step towards a fast SGEMM.</p>

<p>The pseudo code of the algorithm, from the perspective of a thread block, is as follows:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1">// The shapes of block_a, block_b, block_c are (ms x ks), (ks x ns), (ms x ns)</span>
<span class="c1">// Each thread block computes one block of C:</span>
<span class="n">block_c</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">__shared__</span> <span class="kt">float</span> <span class="n">block_a</span><span class="p">[</span><span class="n">block_a_size</span><span class="p">]</span>
<span class="n">__shared__</span> <span class="kt">float</span> <span class="n">block_b</span><span class="p">[</span><span class="n">block_b_size</span><span class="p">]</span>
<span class="k">for</span> <span class="p">(</span><span class="n">i</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">i</span><span class="o">&lt;</span><span class="n">K</span><span class="o">/</span><span class="n">ks</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
    <span class="n">block_a</span> <span class="o">=</span> <span class="n">load</span> <span class="n">ith</span> <span class="n">block</span> <span class="n">of</span> <span class="n">matrix</span> <span class="n">A</span> <span class="c1">// from global into shared memory</span>
    <span class="n">block_b</span> <span class="o">=</span> <span class="n">load</span> <span class="n">ith</span> <span class="n">block</span> <span class="n">of</span> <span class="n">matrix</span> <span class="n">B</span> <span class="c1">// from global into shared memory</span>
    <span class="n">block_c</span> <span class="o">+=</span> <span class="n">block_a</span> <span class="o">*</span> <span class="n">block_b</span> <span class="c1">// compute matrix product and update block_c</span>
<span class="p">}</span>
<span class="n">store</span><span class="p">(</span><span class="n">block_c</span><span class="p">)</span> <span class="c1">// store to global memory</span>
</code></pre></div></div>
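<p>As a plain host-side C++ sketch of the pseudocode above (not GPU code and not the kernel developed in this post), the function below computes one block of $C$ by looping over $K$ in steps of $k_s$. The name <code class="language-plaintext highlighter-rouge">compute_block_c</code>, the row-major layout and the use of ordinary arrays in place of shared memory are assumptions made for illustration.</p>

```cpp
#include <cassert>
#include <vector>

// Host-side model of one "thread block": computes the (ms x ns) block of C
// whose top-left corner is (row0, col0). A is M x K, B is K x N, row-major.
// Assumes K is divisible by ks, as in the text.
std::vector<float> compute_block_c(const std::vector<float>& A,
                                   const std::vector<float>& B,
                                   int K, int N,
                                   int row0, int col0,
                                   int ms, int ns, int ks) {
    assert(K % ks == 0);
    std::vector<float> block_c(ms * ns, 0.0f);
    std::vector<float> block_a(ms * ks); // stands in for shared memory
    std::vector<float> block_b(ks * ns); // stands in for shared memory
    for (int i = 0; i < K / ks; ++i) {
        // "load" the i-th block of A and the i-th block of B
        for (int r = 0; r < ms; ++r)
            for (int k = 0; k < ks; ++k)
                block_a[r * ks + k] = A[(row0 + r) * K + (i * ks + k)];
        for (int k = 0; k < ks; ++k)
            for (int c = 0; c < ns; ++c)
                block_b[k * ns + c] = B[(i * ks + k) * N + (col0 + c)];
        // block_c += block_a * block_b
        for (int r = 0; r < ms; ++r)
            for (int c = 0; c < ns; ++c)
                for (int k = 0; k < ks; ++k)
                    block_c[r * ns + c] += block_a[r * ks + k] * block_b[k * ns + c];
    }
    return block_c;
}
```

<p>Calling this function once per $(m_s \times n_s)$ block of $C$ reproduces the full product. On the GPU, each of these calls corresponds to one thread block.</p>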

<p>Data transfers from global memory to shared memory have significantly higher latency compared to arithmetic operations. During this time, threads are forced to stall, idly waiting for the data needed to compute <code class="language-plaintext highlighter-rouge">block_a * block_b</code>. One way to mitigate this latency is by overlapping data transfers with computations, leveraging instruction-level parallelism (ILP). In GEMM implementations, a technique known as <em>double buffering</em> is commonly used to achieve this overlap:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">block_c</span> <span class="o">=</span> <span class="mi">0</span>
<span class="c1">// Shared Memory Double buffering</span>
<span class="n">__shared__</span> <span class="kt">float</span> <span class="n">block_a</span><span class="p">[</span><span class="mi">2</span><span class="p">][</span><span class="n">block_a_size</span><span class="p">]</span> <span class="c1">// 2x shared memory usage</span>
<span class="n">__shared__</span> <span class="kt">float</span> <span class="n">block_b</span><span class="p">[</span><span class="mi">2</span><span class="p">][</span><span class="n">block_b_size</span><span class="p">]</span> <span class="c1">// 2x shared memory usage</span>
<span class="n">block_a</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">load</span> <span class="n">first</span> <span class="n">block</span> <span class="n">of</span> <span class="n">matrix</span> <span class="n">A</span>
<span class="n">block_b</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">load</span> <span class="n">first</span> <span class="n">block</span> <span class="n">of</span> <span class="n">matrix</span> <span class="n">B</span>

<span class="k">for</span> <span class="p">(</span><span class="n">i</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">i</span><span class="o">&lt;</span><span class="p">(</span><span class="n">K</span><span class="o">/</span><span class="n">ks</span><span class="o">-</span><span class="mi">1</span><span class="p">);</span> <span class="n">i</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
    <span class="n">idx</span> <span class="o">=</span> <span class="n">i</span><span class="o">%</span><span class="mi">2</span>
    <span class="n">prefetch_idx</span> <span class="o">=</span> <span class="p">(</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="p">)</span><span class="o">%</span><span class="mi">2</span>
    <span class="c1">// prefetch next blocks</span>
    <span class="n">block_a</span><span class="p">[</span><span class="n">prefetch_idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">load</span> <span class="n">next</span> <span class="n">block</span> <span class="n">of</span> <span class="n">matrix</span> <span class="n">A</span>
    <span class="n">block_b</span><span class="p">[</span><span class="n">prefetch_idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">load</span> <span class="n">next</span> <span class="n">block</span> <span class="n">of</span> <span class="n">matrix</span> <span class="n">B</span>
    <span class="c1">// use blocks loaded in previous iteration to calculate matrix product</span>
    <span class="n">block_c</span> <span class="o">+=</span> <span class="n">block_a</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="o">*</span> <span class="n">block_b</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span>
<span class="p">}</span>
<span class="c1">// final update of the accumulator using last blocks</span>
<span class="n">block_c</span> <span class="o">+=</span> <span class="n">block_a</span><span class="p">[</span><span class="n">prefetch_idx</span><span class="p">]</span> <span class="o">*</span> <span class="n">block_b</span><span class="p">[</span><span class="n">prefetch_idx</span><span class="p">]</span>

<span class="n">store_to_global_memory</span><span class="p">(</span><span class="n">block_c</span><span class="p">)</span>
</code></pre></div></div>

<p>Note that <code class="language-plaintext highlighter-rouge">block_c += block_a[idx] * block_b[idx]</code> doesn’t depend on <code class="language-plaintext highlighter-rouge">block_a[prefetch_idx]</code> or <code class="language-plaintext highlighter-rouge">block_b[prefetch_idx]</code>, allowing the arithmetic instructions to be issued in parallel with the data movement instructions. However, this comes at the cost of doubled shared memory usage, as we need to store two blocks instead of one. The good news is that modern GPUs have <a href="https://docs.nvidia.com/cuda/cuda-c-programming-guide/#features-and-technical-specifications-technical-specifications-per-compute-capability">sufficient shared memory</a> to support double-buffering.</p>
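<p>The buffer bookkeeping can be tried out on the host. The following is a minimal sketch, assuming a dot product in place of the block product to keep it short. The function name <code class="language-plaintext highlighter-rouge">double_buffered_dot</code> and the two-slot arrays are illustrative. Nothing overlaps on a CPU, so only the index pattern is shown.</p>

```cpp
#include <cassert>
#include <vector>

// Host-side model of the double-buffering index pattern. Computes dot(a, b)
// in chunks of ks elements; buf_a/buf_b stand in for the two shared-memory
// buffers. On a CPU nothing overlaps; only the buffer bookkeeping is shown.
float double_buffered_dot(const std::vector<float>& a,
                          const std::vector<float>& b, int ks) {
    const int K = static_cast<int>(a.size());
    assert(b.size() == a.size() && ks > 0 && K > 0 && K % ks == 0);
    const int num_blocks = K / ks;
    std::vector<float> buf_a[2] = {std::vector<float>(ks), std::vector<float>(ks)};
    std::vector<float> buf_b[2] = {std::vector<float>(ks), std::vector<float>(ks)};

    auto load = [&](int block, int slot) {
        for (int k = 0; k < ks; ++k) {
            buf_a[slot][k] = a[block * ks + k];
            buf_b[slot][k] = b[block * ks + k];
        }
    };
    auto compute = [&](int slot) {
        float s = 0.0f;
        for (int k = 0; k < ks; ++k) s += buf_a[slot][k] * buf_b[slot][k];
        return s;
    };

    float acc = 0.0f;
    load(0, 0); // prologue: the first block goes into slot 0
    for (int i = 0; i < num_blocks - 1; ++i) {
        const int idx = i % 2;
        const int prefetch_idx = (i + 1) % 2;
        load(i + 1, prefetch_idx); // prefetch the next block into the other slot
        acc += compute(idx);       // consume the block loaded one step earlier
    }
    // epilogue: the last block sits in slot (num_blocks - 1) % 2
    acc += compute((num_blocks - 1) % 2);
    return acc;
}
```

<p>The epilogue computes the slot index from <code class="language-plaintext highlighter-rouge">num_blocks</code> instead of reusing <code class="language-plaintext highlighter-rouge">prefetch_idx</code>, so the sketch also works when there is only a single block and the loop body never runs.</p>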

<p>We’ve already introduced several parameters, such as the block sizes $m_s, k_s, n_s$ and the number of threads per thread block. The choice of these parameters depends heavily on the shapes of the operands $A, B, C$, as well as on the underlying GPU architecture. For example, cuBLAS implements multiple SGEMM kernels optimized for various matrix shapes and GPU architectures. At runtime, it selects the most appropriate kernel using a heuristic approach. The block sizes $m_s, k_s, n_s$ affect not only how the data is fetched from global memory, but also how the work in all subsequent steps (shared memory loads, arithmetic operations, global memory stores) is organized among the threads to achieve the best possible performance. The choice of the block sizes and the number of threads per thread block also impacts shared memory and register usage, which can decrease performance if not taken into account. As you might expect, identifying optimal parameter values requires an excellent understanding of the hardware and extensive experimentation. Fortunately, SGEMM is a well-studied problem and we can use the results from previous studies of cuBLAS and CUTLASS. For large square matrices (<code class="language-plaintext highlighter-rouge">M=N=K &gt; 1024</code>), $m_s \times n_s$ combinations such as $128 \times 256$, $128 \times 128$ and $256 \times 128$ lead to optimal performance. From my tests, the configuration $m_s \times n_s \times k_s = 128 \times 128 \times 8$ with <strong>256 threads per thread block</strong> achieved the highest TFLOP/s on my local RTX 3090 for problem sizes <code class="language-plaintext highlighter-rouge">1024 &lt;= M=N=K &lt;= 2500</code>. Therefore, we will start with the implementation of a <code class="language-plaintext highlighter-rouge">128x128x8</code> SGEMM kernel.
Now that we know the block dimensions and the number of threads per thread block, let’s discuss how to efficiently organize data loading from global memory and storage into shared memory.</p>

<p>First, we need to load the <code class="language-plaintext highlighter-rouge">128x8</code> submatrix $\tilde{A}$ using 256 threads. This results in each thread loading <code class="language-plaintext highlighter-rouge">128*8/256 = 4</code> float elements from global memory. There are several different ways to organize the loading of the block. For global memory loads/stores you always want your accesses to be contiguous or <em>coalesced</em>, so that the 32 threads in a warp access 32 consecutive floats in memory. If a memory access is coalesced, the minimum number of memory transactions is used. However, this is not possible in the case of the $\tilde{A}$ block: each row of the block contains only 8 consecutive elements. Nevertheless, even in such cases, having consecutive threads in a warp access consecutive elements in memory is preferable and usually results in better performance. The figure below shows how the loading of the block $\tilde{A}$ is implemented. Here, different colors represent different threads, and only the first 16 threads are shown for simplicity. Four consecutive rows are loaded by 8 consecutive threads: rows 1-4 are loaded by threads 0-7, rows 5-8 by threads 8-15, rows 9-12 by threads 16-23 and so on, with the last rows 125-128 being loaded by threads 248-255. We also transpose the block $\tilde{A}$ while storing it in shared memory to get a better memory access pattern during the next computation step. Note how each thread stores 4 consecutive elements in shared memory. This allows us to use the PTX vectorized store <code class="language-plaintext highlighter-rouge">st.shared.v4.f32</code>.</p>

<p><img src="/assets/matmul_gpu/a_gmem_loads.png" alt="a_gmem_loads" width="100%" style="display:block; margin-left:auto; margin-right:auto" /></p>
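<p>One thread-to-element mapping consistent with this description can be written down on the host. In the sketch below, the helper names <code class="language-plaintext highlighter-rouge">a_coord</code> and <code class="language-plaintext highlighter-rouge">transpose_block_a</code> are hypothetical. It also assumes that each thread handles one column across four consecutive rows, which is what makes its four transposed elements contiguous in shared memory.</p>

```cpp
#include <cassert>
#include <vector>

constexpr int MS = 128, KS = 8, THREADS = 256;
constexpr int ELEMS_PER_THREAD = MS * KS / THREADS; // = 4

// Position of the j-th element (j = 0..3) handled by thread t inside the
// row-major 128x8 block of A: 8 consecutive threads cover 4 consecutive rows,
// and consecutive threads read consecutive elements of a row.
struct Coord { int row; int col; };

Coord a_coord(int t, int j) {
    assert(t >= 0 && t < THREADS && j >= 0 && j < ELEMS_PER_THREAD);
    return { (t / KS) * ELEMS_PER_THREAD + j, t % KS };
}

// Copy a row-major 128x8 block into a transposed (8 x ld) buffer, thread by
// thread. 'ld' is the leading dimension of the transposed buffer (>= 128).
std::vector<float> transpose_block_a(const std::vector<float>& block, int ld) {
    assert(static_cast<int>(block.size()) == MS * KS && ld >= MS);
    std::vector<float> smem(KS * ld, 0.0f);
    for (int t = 0; t < THREADS; ++t)
        for (int j = 0; j < ELEMS_PER_THREAD; ++j) {
            const Coord c = a_coord(t, j);
            // the four j values of one thread land in consecutive words
            smem[c.col * ld + c.row] = block[c.row * KS + c.col];
        }
    return smem;
}
```

<p>With this mapping, thread 0 covers rows 0-3 of column 0 and thread 255 covers rows 124-127 of column 7 (zero-based), so the 256 threads together touch every element of the block exactly once.</p>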

<p>Storing to shared memory using this naive scheme would result in <strong>shared memory bank conflicts</strong>. From the CUDA programming guide:</p>
<blockquote>
  <p>To achieve high bandwidth, shared memory is divided into equally-sized memory modules, called banks, which can be accessed simultaneously. Any memory read or write request made of n addresses that fall in n distinct memory banks can therefore be serviced simultaneously, yielding an overall bandwidth that is n times as high as the bandwidth of a single module.
However, if two addresses of a memory request fall in the same memory bank, there is a bank conflict and the access has to be serialized. The hardware splits a memory request with bank conflicts into as many separate conflict-free requests as necessary, decreasing throughput by a factor equal to the number of separate memory requests. If the number of separate memory requests is n, the initial memory request is said to cause n-way bank conflicts.</p>
</blockquote>

<p>Shared memory has 32 banks that are organized such that successive 32-bit words map to successive banks. Imagine a <code class="language-plaintext highlighter-rouge">float32</code> array of size <code class="language-plaintext highlighter-rouge">8x32</code> stored in <strong>row-major</strong> order as shown below.</p>

<p><img src="/assets/matmul_gpu/bank_conflict.png" alt="bank_conflict" width="80%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>In this context, colors and their shades represent memory banks: each row corresponds to 32 distinct memory banks, while each column represents a single memory bank. Here are two important notes about shared memory bank conflicts:</p>

<ol>
  <li>The determination of bank conflicts is made <strong>per memory transaction (or using modern CUDA language - per wave)</strong>, <strong>not</strong> per request, <strong>not</strong> per warp, <strong>not</strong> per instruction.</li>
  <li>Two requests to the same bank and the same 32-bit location in that bank do <strong>not</strong> create a bank conflict (<a href="https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-5-x">illustrated</a> in the CUDA programming guide).</li>
</ol>

<p>When you store (or load) 4 bytes (= 1 float) per thread, which is 4*32=128 bytes per warp, a CUDA device issues a single warp-wide memory transaction, so the shared memory access must be conflict-free across the whole warp (= 32 threads). In our case, we store 16 bytes (= 4 floats) per thread using the vector instructions. Warp-wide, that is a total of 512 bytes per request. The GPU splits the request into 4 memory transactions (threads 0-7 make up one transaction, threads 8-15 another, and so on), each of which is 128 bytes wide. If we stored according to our scheme, every thread within threads 0-7 would store to the same four columns (red color shades), in other words to the same four memory banks, causing bank conflicts. The same applies to the other memory transactions, i.e. threads 8-15, threads 16-23 and so on. One possible way to completely avoid bank conflicts is to pad the leading dimension with 16 bytes (= 4 floats), as shown below.</p>

<p><img src="/assets/matmul_gpu/bank_conflict_padding.png" alt="bank_conflict_padding" width="80%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>Now, if we store the data according to our scheme, each thread within threads 0-7 accesses distinct memory banks, resulting in all 32 memory banks being accessed per memory transaction. The same applies to the remaining memory transactions, i.e. t8-t15, t16-t23 and so on. This is the reason why the leading dimension is 132 and not 128 in the implementation:</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">const</span> <span class="kt">int</span> <span class="n">smem_a_ld</span> <span class="o">=</span> <span class="mi">132</span><span class="p">;</span> <span class="c1">// 128 + 4</span>
</code></pre></div></div>
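<p>The effect of the padding can be checked with a few lines of host code. This is a sketch under the assumption that thread <code class="language-plaintext highlighter-rouge">t</code> stores its four floats starting at word <code class="language-plaintext highlighter-rouge">(t % 8) * ld + (t / 8) * 4</code> of the shared buffer and that the bank of a 32-bit word is its index modulo 32. The function name <code class="language-plaintext highlighter-rouge">distinct_banks</code> is hypothetical.</p>

```cpp
#include <cassert>
#include <set>

// Number of distinct banks touched by one 128-byte transaction, i.e. by 8
// consecutive threads that each store 4 consecutive floats. Thread t writes
// words (t % 8) * ld + (t / 8) * 4 + j, j = 0..3 (assumed layout).
// Bank of a 32-bit word = word index modulo 32.
int distinct_banks(int first_thread, int ld) {
    assert(first_thread >= 0 && first_thread % 8 == 0 && ld > 0);
    std::set<int> banks;
    for (int t = first_thread; t < first_thread + 8; ++t)
        for (int j = 0; j < 4; ++j) {
            const int word = (t % 8) * ld + (t / 8) * 4 + j;
            banks.insert(word % 32);
        }
    return static_cast<int>(banks.size());
}
```

<p>Under these assumptions, <code class="language-plaintext highlighter-rouge">distinct_banks(0, 128)</code> evaluates to 4, meaning the 32 stored words of the transaction pile up on four banks. <code class="language-plaintext highlighter-rouge">distinct_banks(0, 132)</code> evaluates to 32, one bank per word.</p>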

<p>To implement double-buffering and store two $\tilde{A}$ blocks, we would theoretically need <code class="language-plaintext highlighter-rouge">2*132*8*4</code> bytes of shared memory. However, we increase the size to the next power of 2, <code class="language-plaintext highlighter-rouge">2*256*8*4</code> bytes, to enable fast switching between the two blocks. Compare the following code with the pseudocode presented at the beginning of the chapter:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1">// Double-buffering (blocks_b is omitted for simplicity)</span>
<span class="n">__shared__</span> <span class="kt">float</span> <span class="n">__align__</span><span class="p">(</span><span class="mi">2</span><span class="o">*</span><span class="mi">256</span><span class="o">*</span><span class="mi">8</span><span class="o">*</span><span class="k">sizeof</span><span class="p">(</span><span class="kt">float</span><span class="p">))</span> <span class="n">blocks_a</span><span class="p">[</span><span class="mi">2</span><span class="o">*</span><span class="mi">256</span><span class="o">*</span><span class="mi">8</span><span class="p">];</span>
<span class="kt">uint64_t</span> <span class="n">lds_a_addr</span><span class="p">;</span>
<span class="kt">uint64_t</span> <span class="n">sts_a_addr</span><span class="p">;</span>
<span class="kt">float</span><span class="o">*</span> <span class="n">lds_a_ptr</span> <span class="o">=</span> <span class="n">blocks_a</span><span class="p">;</span> <span class="c1">// lds = load shared</span>
<span class="kt">float</span><span class="o">*</span> <span class="n">sts_a_ptr</span> <span class="o">=</span> <span class="n">blocks_a</span><span class="p">;</span> <span class="c1">// sts = store shared</span>
<span class="n">lds_a_addr</span> <span class="o">=</span> <span class="n">convert_to_addr</span><span class="p">(</span><span class="n">lds_a_ptr</span><span class="p">);</span> <span class="c1">// convert pointer to address for PTX load/store instructions</span>
<span class="n">sts_a_addr</span> <span class="o">=</span> <span class="n">convert_to_addr</span><span class="p">(</span><span class="n">sts_a_ptr</span><span class="p">);</span> <span class="c1">// convert pointer to address for PTX load/store instructions</span>

<span class="c1">// store first block to first half of shared memory</span>
<span class="n">sts_ptx</span><span class="p">(</span><span class="n">sts_a_addr</span><span class="p">);</span>
<span class="c1">// switch address to second half of shared memory</span>
<span class="n">sts_a_addr</span> <span class="o">^=</span> <span class="mi">8192</span><span class="p">;</span>

<span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">i</span><span class="o">&lt;</span><span class="p">(</span><span class="n">K</span><span class="o">/</span><span class="n">ks</span><span class="o">-</span><span class="mi">1</span><span class="p">);</span> <span class="n">i</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
    <span class="p">...</span>
    <span class="c1">// store next block to second(first) half of shared memory</span>
    <span class="n">sts_ptx</span><span class="p">(</span><span class="n">sts_a_addr</span><span class="p">);</span>
    <span class="p">...</span>
    <span class="c1">// load block from first(second) half of shared memory to compute c+=block_a*block_b</span>
    <span class="n">lds_ptx</span><span class="p">(</span><span class="n">lds_a_addr</span><span class="p">);</span>
    <span class="p">...</span>
    <span class="c1">// swap the addresses for next iteration: lds_a_addr = sts_a_addr, sts_a_addr = lds_a_addr</span>
    <span class="n">lds_a_addr</span> <span class="o">^=</span> <span class="mi">8192</span><span class="p">;</span>
    <span class="n">sts_a_addr</span> <span class="o">^=</span> <span class="mi">8192</span><span class="p">;</span>
    <span class="p">...</span>
<span class="p">}</span>
<span class="p">...</span>
</code></pre></div></div>

<p><img src="/assets/matmul_gpu/db.png" alt="db" width="70%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>First, we require <code class="language-plaintext highlighter-rouge">blocks_a</code> to be <code class="language-plaintext highlighter-rouge">2*256*8*4</code>=<code class="language-plaintext highlighter-rouge">2^14</code>=<code class="language-plaintext highlighter-rouge">16384</code>-byte aligned. This means the address of the first element of <code class="language-plaintext highlighter-rouge">blocks_a</code> is divisible by 16384 or, in other words, the lowest 14 bits of the address are zero:</p>

<p><img src="/assets/matmul_gpu/bit_repr.png" alt="bit_repr" width="60%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>As each block occupies <code class="language-plaintext highlighter-rouge">8192=2^13</code> bytes, bit 13 of the address selects the block, so switching between the blocks can now be implemented with just a single XOR instruction <code class="language-plaintext highlighter-rouge">^= 8192</code>. The only drawback of this method is the unused shared memory (in this case <code class="language-plaintext highlighter-rouge">2*8*128*4</code> bytes). However, this can be ignored considering the <a href="https://docs.nvidia.com/cuda/cuda-c-programming-guide/#features-and-technical-specifications-technical-specifications-per-compute-capability">maximum amount of shared memory per thread block</a> on modern GPUs.</p>
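<p>The address toggle can be tried out on the host. The following is a minimal C sketch, assuming a 16384-byte-aligned base address; the helper names <code class="language-plaintext highlighter-rouge">is_aligned_pow2</code> and <code class="language-plaintext highlighter-rouge">toggle_half</code> are illustrative and do not appear in the kernel:</p>

```c
#include <stdbool.h>
#include <stdint.h>

/* Size in bytes of one half of the double buffer (2^13), as in the text. */
#define HALF_BYTES 8192u

/* True if addr is a multiple of align, where align is a power of two. */
static bool is_aligned_pow2(uint64_t addr, uint64_t align) {
    return (addr & (align - 1u)) == 0;
}

/* Flip bit 13: maps an address in one half to the same offset in the other half. */
static uint64_t toggle_half(uint64_t addr) {
    return addr ^ HALF_BYTES;
}
```

<p>Applying <code class="language-plaintext highlighter-rouge">toggle_half</code> twice returns the original address. If the base were only 8192-byte aligned, bit 13 could already be set and the XOR would move the address backwards, out of the buffer, which is why the stricter alignment is required.</p>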

<p>Loading and storing an <code class="language-plaintext highlighter-rouge">8 x 128</code> submatrix $\tilde{B}$ is much simpler to manage due to its shape. Since the submatrix does not need to be transposed, the loading and storing schemes are identical:</p>

<p><img src="/assets/matmul_gpu/b_gmem_loads.png" alt="b_gmem_loads" width="100%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>We use 32 consecutive threads to load 32 consecutive elements, with each thread loading 4 elements, spaced apart by a stride of 32. Note that since we store data in 32 distinct shared memory banks, no padding is required, and bank conflicts are avoided. Furthermore, the block size of <code class="language-plaintext highlighter-rouge">128*8</code> floats (4096 bytes) is naturally a power of two, eliminating the need for additional padding and allowing block switching with a single XOR <code class="language-plaintext highlighter-rouge">^=4096</code> instruction.</p>
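<p>To see why consecutive stores are conflict-free, the bank mapping can be modelled on the host. This is a simplified sketch, assuming 32 banks of 4-byte words where successive words map to successive banks; the function names are hypothetical and the model ignores the broadcast case in which several lanes read the same word:</p>

```c
#include <stdbool.h>
#include <stdint.h>

#define NUM_BANKS 32u /* assumed number of shared memory banks */
#define WARP_SIZE 32u

/* Bank that holds the 32-bit word at the given word index. */
static uint32_t bank_of_word(uint32_t word_index) {
    return word_index % NUM_BANKS;
}

/* True if the 32 lanes of a warp, each accessing word (base + lane * stride),
   all fall into distinct banks. */
static bool warp_is_conflict_free(uint32_t base, uint32_t stride) {
    uint32_t used = 0; /* bit b set means bank b is already taken */
    for (uint32_t lane = 0; lane < WARP_SIZE; ++lane) {
        uint32_t bit = 1u << bank_of_word(base + lane * stride);
        if (used & bit) return false;
        used |= bit;
    }
    return true;
}
```

<p>With a stride of 1 (32 consecutive floats per warp) every lane lands in its own bank, whereas a stride of 32 would send all lanes to the same bank.</p>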

<h3 id="52-shared-memory-loads-and-arithmetic-operations">5.2. Shared Memory Loads and Arithmetic Operations</h3>

<p>With blocks $\tilde{A}$ and $\tilde{B}$ now residing in shared memory, let’s discuss how to efficiently load from shared memory and compute block $\tilde{C}$. To do this, we’ll dive one level deeper into our parallelization strategy and describe the algorithm from a warp’s perspective:</p>

<p><img src="/assets/matmul_gpu/warp_level_design.png" alt="warp_level_design" width="80%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>Each launched thread block consists of 256 threads, which corresponds to <code class="language-plaintext highlighter-rouge">256/32=8</code> warps. The block $\tilde{C}$, with dimensions $128 \times 128$, is therefore divided into 8 regions $\tilde{C}_W$ labeled $W1, …, W8$ in the figure. Each region $\tilde{C}_W$ has dimensions $m_W \times n_W = 32 \times 64$ and is computed by a single warp: $W1$ is computed by threads <code class="language-plaintext highlighter-rouge">t0-t31</code>, $W2$ is computed by threads <code class="language-plaintext highlighter-rouge">t32-t63</code>, and so on, with $W8$ computed by threads <code class="language-plaintext highlighter-rouge">t224-t255</code>. The figure above uses $W8$ as an example to demonstrate how a single $\tilde{C}_W$ region is computed. We iterate over the dimension $K$ and in each iteration we</p>

<ol>
  <li>load <code class="language-plaintext highlighter-rouge">fragment_a</code> (=column of size $m_W \times 1$) from $\tilde{A}$ into registers</li>
  <li>load <code class="language-plaintext highlighter-rouge">fragment_b</code> (=row of size $1 \times n_W$) from $\tilde{B}$ into registers</li>
  <li>multiply the fragments and update $\tilde{C}_W$</li>
</ol>
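<p>The three steps above amount to a sequence of rank-1 updates of the warp region. Below is a host-side C sketch of that loop, under the assumption that the transposed block of $\tilde{A}$ and the block of $\tilde{B}$ are stored densely as <code class="language-plaintext highlighter-rouge">8 x 128</code> row-major arrays (the padding used in the real kernel is left out). The function name and parameters are illustrative only:</p>

```c
#define KS 8        /* k_S: depth of one shared-memory block */
#define BLOCK_M 128 /* rows covered by the A block */
#define BLOCK_N 128 /* columns covered by the B block */
#define WARP_M 32   /* m_W */
#define WARP_N 64   /* n_W */

/* a_t: A block stored transposed (KS x BLOCK_M), b: B block (KS x BLOCK_N).
   row0/col0 select the warp region inside the 128 x 128 block.
   c_w: the warp region (WARP_M x WARP_N, row-major), updated in place. */
static void update_warp_region(const float *a_t, const float *b,
                               int row0, int col0, float *c_w) {
    for (int k = 0; k < KS; ++k) {
        const float *fragment_a = a_t + k * BLOCK_M + row0; /* step 1: m_W x 1 column */
        const float *fragment_b = b + k * BLOCK_N + col0;   /* step 2: 1 x n_W row */
        for (int i = 0; i < WARP_M; ++i)                    /* step 3: rank-1 update */
            for (int j = 0; j < WARP_N; ++j)
                c_w[i * WARP_N + j] += fragment_a[i] * fragment_b[j];
    }
}
```

<p>On the GPU the two inner loops are not executed by one thread; they are spread across the 32 threads of the warp.</p>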

<p>As $k_S = 8$, there will be in total 8 iterations. This explanation is from the perspective of a warp. Now, let’s delve one final level deeper and examine how the work within a warp is distributed among its 32 threads.</p>

<p><img src="/assets/matmul_gpu/thread_level_design.png" alt="thread_level_design" width="100%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>Each thread in a warp computes four <code class="language-plaintext highlighter-rouge">4x4</code> sub-matrices (=accumulators) within $\tilde{C}_W$ or, if concatenated, a single <code class="language-plaintext highlighter-rouge">8x8</code> accumulator. To do this, each thread loads 8 elements from <code class="language-plaintext highlighter-rouge">fragment_a</code> and 8 elements from <code class="language-plaintext highlighter-rouge">fragment_b</code> (as illustrated for thread <code class="language-plaintext highlighter-rouge">t0</code> in the figure), multiplies them and updates the accumulator using fused multiply-add (FMA) instructions. Since <code class="language-plaintext highlighter-rouge">block_a</code> was transposed in the previous step, the elements in <code class="language-plaintext highlighter-rouge">fragment_a</code> are stored contiguously in memory, allowing faster access through vectorized loads. The threads are arranged in a way that avoids bank conflicts and works around NVIDIA’s shared memory broadcast limitation. This limitation occurs when 4 floats loaded using a 16-byte vector instruction must be broadcast to more than 4 consecutive threads within a warp.</p>

<p>Bringing everything together, the entire SGEMM algorithm can be visualized as follows:
<img src="/assets/matmul_gpu/cutlass_sgemm.png" alt="cutlass_sgemm" width="70%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>As you might expect, the accumulators are frequently updated during the computation and need to be stored in the fastest memory - the register files. Each thread allocates <code class="language-plaintext highlighter-rouge">float accumulator[8][8]</code>, so that the entire block $\tilde{C}$ of size $128 \times 128$ is stored in registers by the <code class="language-plaintext highlighter-rouge">256</code> threads. This works because <code class="language-plaintext highlighter-rouge">256=16*16</code>, and the combined arrangement <code class="language-plaintext highlighter-rouge">(16*8)x(16*8)=128x128</code> matches the size of $\tilde{C}$. Just as we used double buffering to load the blocks $\tilde{A}$ and $\tilde{B}$ (from global memory to shared memory), we now also double buffer the fragments to minimize memory transfer latencies when moving data from shared memory to registers. The pseudocode for the algorithm can be written as follows:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1">// Pseudocode</span>
<span class="n">__shared__</span> <span class="kt">float</span> <span class="n">block_a</span><span class="p">[</span><span class="mi">2</span><span class="p">][</span><span class="n">block_a_size</span><span class="p">]</span>
<span class="n">__shared__</span> <span class="kt">float</span> <span class="n">block_b</span><span class="p">[</span><span class="mi">2</span><span class="p">][</span><span class="n">block_b_size</span><span class="p">]</span>
<span class="kt">float</span> <span class="n">fragment_a</span><span class="p">[</span><span class="mi">2</span><span class="p">][</span><span class="mi">8</span><span class="p">]</span>
<span class="kt">float</span> <span class="n">fragment_b</span><span class="p">[</span><span class="mi">2</span><span class="p">][</span><span class="mi">8</span><span class="p">]</span>
<span class="kt">float</span> <span class="n">accumulator</span><span class="p">[</span><span class="mi">8</span><span class="p">][</span><span class="mi">8</span><span class="p">]</span>

<span class="n">block_a</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">load</span> <span class="n">first</span> <span class="n">block</span> <span class="n">of</span> <span class="n">matrix</span> <span class="n">A</span>
<span class="n">block_b</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">load</span> <span class="n">first</span> <span class="n">block</span> <span class="n">of</span> <span class="n">matrix</span> <span class="n">B</span>
<span class="n">fragment_a</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">load</span> <span class="n">first</span> <span class="n">fragment</span> <span class="n">from</span> <span class="n">block_a</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">fragment_b</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">load</span> <span class="n">first</span> <span class="n">fragment</span> <span class="n">from</span> <span class="n">block_b</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>

<span class="k">for</span> <span class="p">(</span><span class="n">block_k</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">block_k</span><span class="o">&lt;</span><span class="p">(</span><span class="n">K</span><span class="o">/</span><span class="n">ks</span><span class="o">-</span><span class="mi">1</span><span class="p">);</span> <span class="n">block_k</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
    <span class="n">block_idx</span> <span class="o">=</span> <span class="n">block_k</span> <span class="o">%</span> <span class="mi">2</span>
    <span class="n">block_prefetch_idx</span> <span class="o">=</span> <span class="p">(</span><span class="n">block_k</span><span class="o">+</span><span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="mi">2</span>
    <span class="c1">// prefetch next blocks (Shared Memory Double buffering)</span>
    <span class="n">block_a</span><span class="p">[</span><span class="n">block_prefetch_idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">load</span> <span class="n">next</span> <span class="n">block</span> <span class="n">of</span> <span class="n">matrix</span> <span class="n">A</span>
    <span class="n">block_b</span><span class="p">[</span><span class="n">block_prefetch_idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">load</span> <span class="n">next</span> <span class="n">block</span> <span class="n">of</span> <span class="n">matrix</span> <span class="n">B</span>
    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">warp_k</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">warp_k</span><span class="o">&lt;</span><span class="mi">8</span><span class="p">;</span> <span class="n">warp_k</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
        <span class="n">frag_idx</span> <span class="o">=</span> <span class="n">warp_k</span> <span class="o">%</span> <span class="mi">2</span>
        <span class="n">frag_prefetch_idx</span> <span class="o">=</span> <span class="p">(</span><span class="n">warp_k</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="mi">2</span>
        <span class="c1">// prefetch next fragments (Register Double buffering)</span>
        <span class="n">fragment_a</span><span class="p">[</span><span class="n">frag_prefetch_idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">load</span> <span class="n">next</span> <span class="n">fragment</span> <span class="n">from</span> <span class="n">block_a</span><span class="p">[</span><span class="n">block_idx</span><span class="p">]</span>
        <span class="n">fragment_b</span><span class="p">[</span><span class="n">frag_prefetch_idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">load</span> <span class="n">next</span> <span class="n">fragment</span> <span class="n">from</span> <span class="n">block_b</span><span class="p">[</span><span class="n">block_idx</span><span class="p">]</span>
        <span class="c1">// use fragments loaded in previous iteration to calculate matrix product</span>
        <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="mi">8</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
            <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">j</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="mi">8</span><span class="p">;</span> <span class="n">j</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
                <span class="n">accumulator</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">j</span><span class="p">]</span> <span class="o">+=</span> <span class="n">fragment_a</span><span class="p">[</span><span class="n">frag_idx</span><span class="p">][</span><span class="n">i</span><span class="p">]</span> <span class="o">*</span> <span class="n">fragment_b</span><span class="p">[</span><span class="n">frag_idx</span><span class="p">][</span><span class="n">j</span><span class="p">];</span>
            <span class="p">}</span>
        <span class="p">}</span>
    <span class="p">}</span>
    <span class="n">fragment_a</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">load</span> <span class="n">first</span> <span class="n">fragment</span> <span class="n">from</span> <span class="n">block_a</span><span class="p">[</span><span class="n">block_prefetch_idx</span><span class="p">]</span>
    <span class="n">fragment_b</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">load</span> <span class="n">first</span> <span class="n">fragment</span> <span class="n">from</span> <span class="n">block_b</span><span class="p">[</span><span class="n">block_prefetch_idx</span><span class="p">]</span>
<span class="p">}</span>

<span class="c1">// final update of the accumulator using last blocks</span>
<span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">warp_k</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">warp_k</span><span class="o">&lt;</span><span class="mi">8</span><span class="p">;</span> <span class="n">warp_k</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
    <span class="n">frag_idx</span> <span class="o">=</span> <span class="n">warp_k</span> <span class="o">%</span> <span class="mi">2</span>
    <span class="n">frag_prefetch_idx</span> <span class="o">=</span> <span class="p">(</span><span class="n">warp_k</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="mi">2</span>
    <span class="c1">// prefetch next fragments (Register Double buffering)</span>
    <span class="n">fragment_a</span><span class="p">[</span><span class="n">frag_prefetch_idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">load</span> <span class="n">next</span> <span class="n">fragment</span> <span class="n">from</span> <span class="n">block_a</span><span class="p">[</span><span class="n">block_prefetch_idx</span><span class="p">]</span>
    <span class="n">fragment_b</span><span class="p">[</span><span class="n">frag_prefetch_idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">load</span> <span class="n">next</span> <span class="n">fragment</span> <span class="n">from</span> <span class="n">block_b</span><span class="p">[</span><span class="n">block_prefetch_idx</span><span class="p">]</span>
    <span class="c1">// use fragments loaded in previous iteration to calculate matrix product</span>
    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="mi">8</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
        <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">j</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="mi">8</span><span class="p">;</span> <span class="n">j</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
            <span class="n">accumulator</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">j</span><span class="p">]</span> <span class="o">+=</span> <span class="n">fragment_a</span><span class="p">[</span><span class="n">frag_idx</span><span class="p">][</span><span class="n">i</span><span class="p">]</span> <span class="o">*</span> <span class="n">fragment_b</span><span class="p">[</span><span class="n">frag_idx</span><span class="p">][</span><span class="n">j</span><span class="p">];</span>
        <span class="p">}</span>
    <span class="p">}</span>
<span class="p">}</span>

<span class="c1">// After completing the matrix multiplication C=A*B, we perform one final update to the accumulator</span>
<span class="c1">// to compute  C=alpha*A*B before storing the result back to global memory:</span>
<span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">i</span><span class="o">&lt;</span><span class="mi">8</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">j</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">j</span><span class="o">&lt;</span><span class="mi">8</span><span class="p">;</span> <span class="n">j</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
        <span class="n">accumulator</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">j</span><span class="p">]</span> <span class="o">*=</span> <span class="n">alpha</span><span class="p">;</span>
    <span class="p">}</span>
<span class="p">}</span>

<span class="n">store_to_global_memory</span><span class="p">(</span><span class="n">accumulator</span><span class="p">)</span>
</code></pre></div></div>
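<p>The register double buffering in the inner loop can be checked in isolation with plain C. The sketch below simulates a single thread and one block; <code class="language-plaintext highlighter-rouge">load_fragment</code> and <code class="language-plaintext highlighter-rouge">accumulate_block</code> are hypothetical helpers, and the data layout (8 fragments of 8 floats per operand) is an assumption made for illustration. Unlike the pseudocode, it skips the prefetch in the last iteration instead of reloading after the loop. On a CPU this brings no speedup; it only illustrates the buffer indexing:</p>

```c
#define KS 8   /* fragments per block (k_S) */
#define FRAG 8 /* elements per fragment held by one thread */

/* Copy one fragment (FRAG floats) from the block into a register buffer. */
static void load_fragment(float dst[FRAG], const float *src) {
    for (int i = 0; i < FRAG; ++i) dst[i] = src[i];
}

/* Accumulate one block into acc using two alternating fragment buffers.
   block_a and block_b each hold KS fragments of FRAG floats for this thread. */
static void accumulate_block(const float *block_a, const float *block_b,
                             float acc[FRAG][FRAG]) {
    float fragment_a[2][FRAG];
    float fragment_b[2][FRAG];

    /* Prologue: fill buffer 0 before the loop starts. */
    load_fragment(fragment_a[0], block_a);
    load_fragment(fragment_b[0], block_b);

    for (int k = 0; k < KS; ++k) {
        int cur = k % 2;
        int next = (k + 1) % 2;
        /* Prefetch the next fragment into the idle buffer (skipped on the last step). */
        if (k + 1 < KS) {
            load_fragment(fragment_a[next], block_a + (k + 1) * FRAG);
            load_fragment(fragment_b[next], block_b + (k + 1) * FRAG);
        }
        /* Compute with the fragment that was loaded one step earlier. */
        for (int i = 0; i < FRAG; ++i)
            for (int j = 0; j < FRAG; ++j)
                acc[i][j] += fragment_a[cur][i] * fragment_b[cur][j];
    }
}
```

<p>Because the load into <code class="language-plaintext highlighter-rouge">next</code> and the FMAs on <code class="language-plaintext highlighter-rouge">cur</code> touch different buffers, the two can overlap on the GPU.</p>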

<h3 id="53-coalesced-global-memory-stores-through-shared-memory">5.3. Coalesced Global Memory Stores Through Shared Memory</h3>

<p>Just as with global memory reads, we want our global memory writes to be coalesced. However, directly storing the accumulators to global memory based on our current mapping</p>

<p><img src="/assets/matmul_gpu/acc_map.png" alt="acc_map" width="70%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>would result in scattered, uncoalesced memory accesses, significantly hurting performance. To fix this, we use shared memory as a buffer to rearrange the accumulators, enabling coalesced global memory writes. At this stage, the accumulators have already been computed, so we no longer need shared memory for computation. Transferring data from registers to shared memory is fast, and the overhead of these additional transfers is negligible compared to the performance gains achieved through coalesced writes. We write the accumulator’s elements to shared memory row by row according to the following scheme:</p>

<p><img src="/assets/matmul_gpu/stgx.png" alt="stg" width="100%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>The first row, containing 32 elements, is copied to the first 32 consecutive memory addresses in shared memory. Similarly, the second row is copied to the next 32 consecutive memory addresses, and so on until all 16 rows have been copied to shared memory. Next, we iterate through the rows in shared memory, and in each iteration, we store a row (containing 32 elements) to global memory using coalesced writes:</p>

<p><img src="/assets/matmul_gpu/stg_final.png" alt="stg_final" width="100%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>The process is then repeated for the remaining three <code class="language-plaintext highlighter-rouge">4x4</code> accumulators of each thread.</p>

<p>To compute $C := \alpha AB + \beta C$, we make a slight adjustment to the process of storing the data to global memory. After copying the accumulator from registers to shared memory, we check if <code class="language-plaintext highlighter-rouge">beta != 0.0</code>. If true, we load (using coalesced loads) the corresponding element from global memory into a register, multiply it by <code class="language-plaintext highlighter-rouge">beta</code> and add the result to the accumulator stored in shared memory. Finally, we store the updated accumulator <code class="language-plaintext highlighter-rouge">alpha*A*B+beta*C</code> from shared memory to global memory using coalesced writes.</p>
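<p>The adjusted store step can be sketched in plain C as follows. The sketch assumes the staged values already hold <code class="language-plaintext highlighter-rouge">alpha*A*B</code> (as after the final accumulator update in the pseudocode) and models shared and global memory as ordinary arrays; <code class="language-plaintext highlighter-rouge">store_with_beta</code> is a hypothetical name:</p>

```c
#include <stddef.h>

/* staged: n values of alpha*A*B taken from the staging buffer (shared memory on the GPU).
   c:      the corresponding n elements of C (global memory on the GPU). */
static void store_with_beta(const float *staged, float *c, size_t n, float beta) {
    for (size_t i = 0; i < n; ++i) {
        float value = staged[i];
        if (beta != 0.0f) {
            value += beta * c[i]; /* read C only when it contributes to the result */
        }
        c[i] = value;
    }
}
```

<p>When <code class="language-plaintext highlighter-rouge">beta</code> is zero, the old contents of <code class="language-plaintext highlighter-rouge">C</code> are never read, which saves the extra global memory loads.</p>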

<h3 id="6-performance-analysis">6. Performance Analysis</h3>

<p>So far, we have discussed the design of the <code class="language-plaintext highlighter-rouge">128x128x8</code> SGEMM kernel. Its implementation is available at <a href="https://github.com/salykova/sgemm.cu/blob/main/src/kernels/128x128x8.cuh">128x128x8.cuh</a> and closely follows the pseudo-code outlined earlier. Let’s now benchmark this kernel to evaluate its performance. First, we conduct a benchmark with locked clock frequencies:
<img src="/assets/matmul_gpu/128x128x8_lock.png" alt="128x128x8_lock" width="100%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>The benchmark results show that the implementation outperforms cuBLAS when clock speeds remain constant. However, performance alone is not enough; we also need to consider power consumption. To evaluate both metrics, we run the benchmark with unlocked clock frequencies:
<img src="/assets/matmul_gpu/128x128x8.png" alt="128x128x8" width="100%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>This reveals the effect of throttling due to reaching power limits. While the <code class="language-plaintext highlighter-rouge">128x128x8</code> kernel is, on average, 3–4% faster than cuBLAS, it consumes 12% more power. The increased power consumption causes the GPU to operate near the power limit for matrix sizes <code class="language-plaintext highlighter-rouge">m=n=k&gt;4000</code>, resulting in reduced clock speeds and overall performance degradation. This is why optimizing <strong>both</strong> running time and power consumption is required for achieving a balanced and efficient implementation.</p>

<p>We can slightly improve the running time of the kernel by utilizing vectorized global texture loads. The new kernel is available at <a href="https://github.com/salykova/sgemm.cu/blob/main/src/kernels/128x128x8_texld.cuh">128x128x8_texld</a>. Since the vectorized load instructions impose alignment constraints on the input data, we first verify the memory alignment and ensure the leading dimensions of matrices <code class="language-plaintext highlighter-rouge">A</code> and <code class="language-plaintext highlighter-rouge">B</code> are divisible by 4:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">bool</span> <span class="n">is_aligned</span> <span class="o">=</span> <span class="p">(((</span><span class="kt">unsigned</span><span class="p">)</span><span class="n">lda</span> <span class="o">&amp;</span> <span class="mi">3u</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">)</span> <span class="o">&amp;&amp;</span> <span class="p">(((</span><span class="kt">unsigned</span><span class="p">)</span><span class="n">ldb</span> <span class="o">&amp;</span> <span class="mi">3u</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">)</span>
                    <span class="o">&amp;&amp;</span> <span class="p">(((</span><span class="kt">unsigned</span> <span class="kt">long</span><span class="p">)</span><span class="n">A</span> <span class="o">&amp;</span> <span class="mi">15u</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">)</span> <span class="o">&amp;&amp;</span> <span class="p">(((</span><span class="kt">unsigned</span> <span class="kt">long</span><span class="p">)</span><span class="n">B</span> <span class="o">&amp;</span> <span class="mi">15u</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">);</span>
</code></pre></div></div>
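<p>The same check can be written as small standalone helpers. This is a sketch rather than the code used in the repository: the function names are hypothetical, and it uses <code class="language-plaintext highlighter-rouge">uintptr_t</code> for the pointer cast, which is an assumption about what is most portable, not something the original code does:</p>

```c
#include <stdbool.h>
#include <stdint.h>

/* True if ptr is a multiple of 16 bytes, the size of four packed floats. */
static bool is_16_byte_aligned(const void *ptr) {
    return ((uintptr_t)ptr & 15u) == 0;
}

/* True if a leading dimension (counted in floats) is a multiple of 4, so that
   every row starts on a 16-byte boundary whenever the base pointer does. */
static bool ld_is_multiple_of_4(int ld) {
    return ((unsigned)ld & 3u) == 0;
}

/* Combined check for both operands, mirroring the condition shown above. */
static bool can_use_vector_loads(const float *a, int lda, const float *b, int ldb) {
    return ld_is_multiple_of_4(lda) && ld_is_multiple_of_4(ldb)
        && is_16_byte_aligned(a) && is_16_byte_aligned(b);
}
```

<p>Both conditions are needed: an aligned base pointer with a leading dimension of, say, 1022 floats would leave every second row misaligned for 16-byte loads.</p>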

<p>If the input data is aligned, we can use the vectorized load instructions. First, we need to create texture objects, texture descriptors, and resource descriptors. These are configured to handle the <code class="language-plaintext highlighter-rouge">float</code> data type with four 32-bit channels (x, y, z, w). The texture objects are then bound to the operands <code class="language-plaintext highlighter-rouge">A, B</code>, and passed to the kernel instead of the raw pointers <code class="language-plaintext highlighter-rouge">A, B</code>.</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">cudaResourceDesc</span> <span class="n">resDesc</span><span class="p">;</span>
<span class="n">cudaTextureDesc</span> <span class="n">texDesc</span><span class="p">;</span>
<span class="n">cudaTextureObject_t</span> <span class="n">tex_a</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
<span class="n">cudaTextureObject_t</span> <span class="n">tex_b</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
<span class="p">...</span>
<span class="k">if</span> <span class="p">(</span><span class="n">is_aligned</span><span class="p">)</span> <span class="p">{</span>
    <span class="n">memset</span><span class="p">(</span><span class="o">&amp;</span><span class="n">texDesc</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="k">sizeof</span><span class="p">(</span><span class="n">texDesc</span><span class="p">));</span>
    <span class="n">texDesc</span><span class="p">.</span><span class="n">readMode</span> <span class="o">=</span> <span class="n">cudaReadModeElementType</span><span class="p">;</span>
    <span class="n">texDesc</span><span class="p">.</span><span class="n">normalizedCoords</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
    <span class="n">memset</span><span class="p">(</span><span class="o">&amp;</span><span class="n">resDesc</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="k">sizeof</span><span class="p">(</span><span class="n">resDesc</span><span class="p">));</span>
    <span class="n">resDesc</span><span class="p">.</span><span class="n">resType</span> <span class="o">=</span> <span class="n">cudaResourceTypeLinear</span><span class="p">;</span>
    <span class="n">resDesc</span><span class="p">.</span><span class="n">res</span><span class="p">.</span><span class="n">linear</span><span class="p">.</span><span class="n">desc</span><span class="p">.</span><span class="n">f</span> <span class="o">=</span> <span class="n">cudaChannelFormatKindFloat</span><span class="p">;</span>
    <span class="n">resDesc</span><span class="p">.</span><span class="n">res</span><span class="p">.</span><span class="n">linear</span><span class="p">.</span><span class="n">desc</span><span class="p">.</span><span class="n">x</span> <span class="o">=</span> <span class="mi">32</span><span class="p">;</span>
    <span class="n">resDesc</span><span class="p">.</span><span class="n">res</span><span class="p">.</span><span class="n">linear</span><span class="p">.</span><span class="n">desc</span><span class="p">.</span><span class="n">y</span> <span class="o">=</span> <span class="mi">32</span><span class="p">;</span>
    <span class="n">resDesc</span><span class="p">.</span><span class="n">res</span><span class="p">.</span><span class="n">linear</span><span class="p">.</span><span class="n">desc</span><span class="p">.</span><span class="n">z</span> <span class="o">=</span> <span class="mi">32</span><span class="p">;</span>
    <span class="n">resDesc</span><span class="p">.</span><span class="n">res</span><span class="p">.</span><span class="n">linear</span><span class="p">.</span><span class="n">desc</span><span class="p">.</span><span class="n">w</span> <span class="o">=</span> <span class="mi">32</span><span class="p">;</span>
    <span class="n">resDesc</span><span class="p">.</span><span class="n">res</span><span class="p">.</span><span class="n">linear</span><span class="p">.</span><span class="n">devPtr</span> <span class="o">=</span> <span class="n">A</span><span class="p">;</span>
    <span class="n">resDesc</span><span class="p">.</span><span class="n">res</span><span class="p">.</span><span class="n">linear</span><span class="p">.</span><span class="n">sizeInBytes</span> <span class="o">=</span> <span class="n">m</span> <span class="o">*</span> <span class="n">lda</span> <span class="o">*</span> <span class="k">sizeof</span><span class="p">(</span><span class="kt">float</span><span class="p">);</span>
    <span class="n">cudaCreateTextureObject</span><span class="p">(</span><span class="o">&amp;</span><span class="n">tex_a</span><span class="p">,</span> <span class="o">&amp;</span><span class="n">resDesc</span><span class="p">,</span> <span class="o">&amp;</span><span class="n">texDesc</span><span class="p">,</span> <span class="nb">NULL</span><span class="p">);</span>
    <span class="n">resDesc</span><span class="p">.</span><span class="n">res</span><span class="p">.</span><span class="n">linear</span><span class="p">.</span><span class="n">devPtr</span> <span class="o">=</span> <span class="n">B</span><span class="p">;</span>
    <span class="n">resDesc</span><span class="p">.</span><span class="n">res</span><span class="p">.</span><span class="n">linear</span><span class="p">.</span><span class="n">sizeInBytes</span> <span class="o">=</span> <span class="n">k</span> <span class="o">*</span> <span class="n">ldb</span> <span class="o">*</span> <span class="k">sizeof</span><span class="p">(</span><span class="kt">float</span><span class="p">);</span>
    <span class="n">cudaCreateTextureObject</span><span class="p">(</span><span class="o">&amp;</span><span class="n">tex_b</span><span class="p">,</span> <span class="o">&amp;</span><span class="n">resDesc</span><span class="p">,</span> <span class="o">&amp;</span><span class="n">texDesc</span><span class="p">,</span> <span class="nb">NULL</span><span class="p">);</span>
    <span class="n">sgemm_texld_128x128x8</span><span class="o">&lt;&lt;&lt;</span><span class="n">grid</span><span class="p">,</span> <span class="n">threads</span><span class="o">&gt;&gt;&gt;</span><span class="p">(</span><span class="n">m</span><span class="p">,</span>
                                             <span class="n">n</span><span class="p">,</span>
                                             <span class="n">k</span><span class="p">,</span>
                                             <span class="o">*</span><span class="n">alpha</span><span class="p">,</span>
                                             <span class="n">tex_a</span><span class="p">,</span>
                                             <span class="n">lda</span><span class="p">,</span>
                                             <span class="n">tex_b</span><span class="p">,</span>
                                             <span class="n">ldb</span><span class="p">,</span>
                                             <span class="o">*</span><span class="n">beta</span><span class="p">,</span>
                                             <span class="n">C</span><span class="p">,</span>
                                             <span class="n">ldc</span><span class="p">);</span>
    <span class="n">cudaDestroyTextureObject</span><span class="p">(</span><span class="n">tex_a</span><span class="p">);</span>
    <span class="n">cudaDestroyTextureObject</span><span class="p">(</span><span class="n">tex_b</span><span class="p">);</span>
<span class="p">}</span>
</code></pre></div></div>

<p>Within the kernel, we load data through texture objects using the <code class="language-plaintext highlighter-rouge">tex1Dfetch</code> function, which compiles to a single PTX instruction:</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">float4</span> <span class="n">texld_a_buffer</span><span class="p">;</span>
<span class="n">texld_a_buffer</span> <span class="o">=</span> <span class="n">tex1Dfetch</span><span class="o">&lt;</span><span class="n">float4</span><span class="o">&gt;</span><span class="p">(</span><span class="n">tex_a</span><span class="p">,</span> <span class="n">texld_a_offset</span><span class="p">);</span>
</code></pre></div></div>

<p>We use global texture loads over normal vectorized global loads (<code class="language-plaintext highlighter-rouge">ld.global.v4.f32</code>) because texture loads handle out-of-bounds reads gracefully by returning zeros, avoiding the need for predicated execution. This simplification leads to more efficient code:
<img src="/assets/matmul_gpu/128x128x8_texld_lock.png" alt="128x128x8_texld_lock" width="100%" style="display:block; margin-left:auto; margin-right:auto" />
<img src="/assets/matmul_gpu/128x128x8_texld.png" alt="128x128x8_texld" width="100%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>Lastly, we developed an <a href="https://github.com/salykova/sgemm.cu/blob/main/src/kernels/128x256x8.cuh"><code class="language-plaintext highlighter-rouge">128x256x8</code> SGEMM kernel</a> leveraging a larger block size $n_S=256$ and asynchronous copy instructions (<code class="language-plaintext highlighter-rouge">cp.async.ca.shared.global</code>), which are supported starting with the Ampere architecture. The main advantage of these instructions is that computation can overlap with memory transfers, which helps avoid pipeline stalls. Additionally, they copy data from global memory directly into shared memory, bypassing registers:</p>

<p><img src="/assets/matmul_gpu/cp_async.png" alt="cp_async" width="30%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>Simply replacing the normal global load instructions with <code class="language-plaintext highlighter-rouge">cp.async</code> in the <code class="language-plaintext highlighter-rouge">128x128x8</code> kernel degrades performance, possibly due to higher latencies of the <code class="language-plaintext highlighter-rouge">cp.async</code> instructions or suboptimal compiler optimizations. However, combining the increased block size with <code class="language-plaintext highlighter-rouge">cp.async</code> yields superior results in both speed and power efficiency:
<img src="/assets/matmul_gpu/128x256x8_lock.png" alt="128x256x8_lock" width="100%" style="display:block; margin-left:auto; margin-right:auto" />
<img src="/assets/matmul_gpu/128x256x8.png" alt="128x256x8" width="100%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>Our final implementation combines the <code class="language-plaintext highlighter-rouge">128x128x8</code> and <code class="language-plaintext highlighter-rouge">128x256x8</code> kernels. For smaller matrices <code class="language-plaintext highlighter-rouge">m=n &lt; 2500</code>, we use the <code class="language-plaintext highlighter-rouge">128x128x8</code> kernel, otherwise, the <code class="language-plaintext highlighter-rouge">128x256x8</code> kernel.</p>]]></content><author><name>Amanzhol Salykov</name></author><summary type="html"><![CDATA[This blog post focuses on a GPU implementation of SGEMM (Single-precision GEneral Matrix Multiply) operation defined as C := alpha*A*B + beta*C. We'll review the algorithm’s design and discuss optimization techniques such as inlined PTX, asynchronous memory copies, double-buffering, avoiding shared memory bank conflicts, and efficient coalesced storage through shared memory.]]></summary></entry><entry><title type="html">Advanced Matrix Multiplication Optimization on Modern Multi-Core Processors</title><link href="https://salykova.github.io/matmul-cpu" rel="alternate" type="text/html" title="Advanced Matrix Multiplication Optimization on Modern Multi-Core Processors" /><published>2025-01-12T09:00:01+00:00</published><updated>2025-01-12T09:00:01+00:00</updated><id>https://salykova.github.io/matmul-cpu</id><content type="html" xml:base="https://salykova.github.io/matmul-cpu"><![CDATA[<p><strong>TL;DR</strong> The code is available at <a href="https://github.com/salykova/sgemm.c">sgemm.c</a>. This blog post walks through optimizing multi-threaded FP32 matrix multiplication on modern processors using FMA3 and AVX2 vector instructions. The implementation delivers strong performance on a variety of x86-64 CPUs, both in single-threaded and multithreaded scenarios. However, to reach peak performance, you’ll need to fine-tune hyperparameters - such as the <em>number of threads, kernel size, and tile sizes</em>. 
Additionally, on AVX-512 CPUs, the BLAS libraries might be notably faster due to AVX-512 instructions, which were intentionally omitted here to support a broader range of processors. Performance results for Intel Core Ultra 265 and AMD Ryzen 7 9700X are shown below.</p>

<p><strong>P.S. Please feel free to get in touch if you are interested in collaborating. My contact information is available on the homepage.</strong></p>

<p><img src="/assets/matmul_cpu/intel_core_perf.png" alt="" width="80%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p><img src="/assets/matmul_cpu/amd_ryzen_perf.png" alt="" width="80%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<h2 id="1-introduction">1. Introduction</h2>

<p>Matrix multiplication is an essential part of nearly all modern neural networks. Despite using matmul daily in PyTorch, NumPy, or JAX, I’ve never really thought about how it is designed and implemented internally to take full advantage of hardware capabilities. NumPy, for instance, relies on external <a href="https://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms">BLAS</a> (Basic Linear Algebra Subprograms) libraries. These libraries contain high-performance, optimized implementations of common linear algebra operations, such as the dot product, matrix multiplication, vector addition, and scalar multiplication. Examples of BLAS libraries include:</p>

<ol>
  <li><a href="https://en.wikipedia.org/wiki/Math_Kernel_Library">Intel MKL</a> - optimized for Intel CPUs</li>
  <li><a href="https://developer.apple.com/documentation/accelerate">Accelerate</a> - optimized for Apple CPUs</li>
  <li><a href="https://en.wikipedia.org/wiki/BLIS_(software)">BLIS</a> - open-source, multi-vendor, BLAS-like Library Instantiation Software</li>
  <li><a href="https://en.wikipedia.org/wiki/GotoBLAS">GotoBLAS</a> - open-source, multi-vendor</li>
  <li><a href="https://en.wikipedia.org/wiki/OpenBLAS">OpenBLAS</a> - open-source, multi-vendor, fork of GotoBLAS</li>
</ol>

<p>A closer look at the OpenBLAS <a href="https://github.com/OpenMathLib/OpenBLAS/blob/develop/kernel/x86_64/sgemm_kernel_8x4_haswell.c">code</a> reveals a mix of C and low-level assembly. In fact, OpenBLAS, GotoBLAS, and BLIS are written in C/FORTRAN/Assembly and contain matmul implementations manually optimized for different CPU microarchitectures.
My goal was to implement matrix multiplication in pure C (without low-level assembly code) so that it works for any matrix size, runs on all modern x86-64 processors, and competes with existing BLAS libraries. At the same time, I wanted to keep the code simple and easy to extend. After some research, I found a few great step-by-step tutorials on implementing fast matrix multiplication from scratch, covering both theory and practice:</p>

<ol>
  <li><a href="https://siboehm.com/articles/22/Fast-MMM-on-CPU">Fast Multidimensional Matrix Multiplication on CPU from Scratch</a> by Simon Boehm.</li>
  <li><a href="https://en.algorithmica.org/hpc/algorithms/matmul/">Matrix Multiplication</a> by Sergey Slotin.</li>
  <li><a href="https://en.wikipedia.org/wiki/George_Hotz">Geohot’s</a> stream <a href="https://www.youtube.com/watch?v=VgSQ1GOC86s">Can you multiply a matrix?</a></li>
</ol>

<p>I highly recommend these clear and well-explained tutorials with alternative implementations. They helped me better understand the topic and, in some sense, motivated me to write my own implementation. The reason is that all three solutions above work only for specific matrix sizes and do not match the performance of the BLAS libraries. Unsatisfied with these results, I kept researching and came across two fascinating papers: “<a href="https://www.cs.utexas.edu/~flame/pubs/GotoTOMS_final.pdf">Anatomy of High-Performance Matrix Multiplication</a>” and “<a href="https://www.cs.utexas.edu/~flame/pubs/blis3_ipdps14.pdf">Anatomy of High-Performance Many-Threaded Matrix Multiplication</a>”. The first introduces GotoBLAS, a high-performance BLAS implementation by <a href="https://en.wikipedia.org/wiki/Kazushige_Goto">Kazushige Goto</a>. The second reviews the matmul design used in the BLIS library (an extended version of GotoBLAS) and explores different parallelization strategies. Given its superior high-level design, I had a feeling that the matmul implementation from the BLIS library could outperform existing BLAS implementations even if written in pure C and not manually fine-tuned using inline assembly. In the next chapters we’ll implement the algorithm from scratch, step by step, and compare it against OpenBLAS. Before diving into optimizations, let’s first go over how to install OpenBLAS and properly benchmark the code on a CPU.</p>

<h2 id="2-how-to-install-and-benchmark-openblas">2. How to Install and Benchmark OpenBLAS</h2>

<p>I benchmarked the code on the following machine:</p>

<ul>
  <li>CPU: AMD Ryzen 7 9700X</li>
  <li>RAM: 32GB DDR5 6000 MHz CL36</li>
  <li>OpenBLAS 0.3.26</li>
  <li>Compiler: GCC 13.3</li>
  <li>Compiler flags: <code class="language-plaintext highlighter-rouge">-O3 -march=native -mno-avx512f -fopenmp</code></li>
  <li>OS: Ubuntu 24.04.1 LTS</li>
</ul>

<p><strong>Important!</strong> To obtain reproducible and accurate results, minimize the number of active tasks, particularly when benchmarking multi-threaded code. Windows systems generally deliver lower performance compared to Linux due to a higher number of active background tasks.</p>

<p>To benchmark OpenBLAS, start by installing it according to the <a href="https://github.com/OpenMathLib/OpenBLAS/wiki/Installation-Guide">installation guide</a>. During installation, make sure to set an appropriate <code class="language-plaintext highlighter-rouge">TARGET</code> and disable AVX512 instructions for a fair comparison. For Zen4/5 processors compile OpenBLAS with:</p>

<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code>make <span class="nv">TARGET</span><span class="o">=</span>ZEN
</code></pre></div></div>

<p>Otherwise, OpenBLAS defaults to AVX512 instructions. After installation, you can run FP32 matmul using the OpenBLAS API:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="cp">#include</span> <span class="cpf">&lt;cblas.h&gt;</span><span class="cp">
</span><span class="n">cblas_sgemm</span><span class="p">(</span><span class="n">CblasColMajor</span><span class="p">,</span> <span class="n">CblasNoTrans</span><span class="p">,</span> <span class="n">CblasNoTrans</span><span class="p">,</span> <span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">A</span><span class="p">,</span> <span class="n">m</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">C</span><span class="p">,</span> <span class="n">m</span><span class="p">);</span>
</code></pre></div></div>

<p>The benchmark evaluates the custom implementation and the OpenBLAS API on square matrices, ranging from <code class="language-plaintext highlighter-rouge">m=n=k=200</code> to <code class="language-plaintext highlighter-rouge">m=n=k=10000</code> in steps of <code class="language-plaintext highlighter-rouge">200</code>. To obtain consistent and accurate results, matrix multiplication is repeated <code class="language-plaintext highlighter-rouge">n_iter</code> times, and performance is measured as median execution time.</p>

<p>To multiply two <code class="language-plaintext highlighter-rouge">float32</code> matrices, $A$ of size $M \times K$ and $B$ of size $K \times N$, we compute each element of the resulting matrix $C$ of size $M \times N$ as the dot product between a row of $A$ and a column of $B$. This requires $K$ (additions) + $K$ (multiplications) = $2K$ Floating Point Operations (FLOP) per element of $C$ or $2MNK$ FLOP in total. A metric often used to evaluate matmul performance is FLOP per second (FLOP/s or FLOPS), which can be derived from the execution time as <code class="language-plaintext highlighter-rouge">FLOPS=FLOP/exec_time=(2*m*n*k)/exec_time</code>.</p>
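
<p>The measurement described above can be sketched in a few lines of C. This is only an illustration under stated assumptions: the helper names <code class="language-plaintext highlighter-rouge">median_time</code> and <code class="language-plaintext highlighter-rouge">matmul_gflops</code> are hypothetical and not part of sgemm.c, and the timings are assumed to be collected in seconds by whatever timer you prefer.</p>

```c
#include <assert.h>
#include <stdlib.h>

/* Comparison callback for qsort: ascending order of doubles. */
static int cmp_double(const void* a, const void* b) {
  const double x = *(const double*)a;
  const double y = *(const double*)b;
  return (x > y) - (x < y);
}

/* Median of n timing samples (in seconds). Sorts the array in place. */
double median_time(double* samples, int n) {
  assert(n > 0);
  qsort(samples, (size_t)n, sizeof(double), cmp_double);
  return (n % 2) ? samples[n / 2]
                 : 0.5 * (samples[n / 2 - 1] + samples[n / 2]);
}

/* GFLOPS of an m x n x k matmul that took exec_time_s seconds:
   FLOPS = (2*m*n*k) / exec_time, scaled by 1e-9. */
double matmul_gflops(int m, int n, int k, double exec_time_s) {
  assert(exec_time_s > 0.0);
  const double flop = 2.0 * (double)m * (double)n * (double)k;
  return flop / exec_time_s * 1e-9;
}
```

<p>For instance, a <code class="language-plaintext highlighter-rouge">m=n=k=1000</code> multiplication with a median time of 2 seconds corresponds to roughly 1 GFLOPS.</p>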

<p><img src="/assets/matmul_cpu/matmul_naive.png" alt="" style="display:block; margin-left:auto; margin-right:auto" /></p>

<h2 id="3-theoretical-limit">3. Theoretical Limit</h2>

<p>The image below shows a simplified model of the computer’s memory hierarchy (for now, ignore the layers between the registers and the main memory (RAM); we will discuss them later).</p>

<p><img src="/assets/matmul_cpu/mem_system_nc.png" alt="" width="80%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>To perform arithmetic operations on data stored in RAM (off-chip memory, slow and large capacity), the data must first be transferred to the CPU and placed in CPU registers (on-chip memory, fast and small capacity). Modern x86-64 CPUs support SIMD (Single Instruction Multiple Data) extensions, which allow multiple pieces of data to be processed in parallel. There are various SIMD extensions, but the ones relevant to our discussion are <a href="https://en.wikipedia.org/wiki/Advanced_Vector_Extensions">Advanced Vector Extensions</a> (AVX2) and <a href="https://en.wikipedia.org/wiki/FMA_instruction_set">Fused Multiply-Add</a> (FMA3). Both AVX2 and FMA3 operate on data stored in special 256-bit <code class="language-plaintext highlighter-rouge">YMM</code> registers. Each <code class="language-plaintext highlighter-rouge">YMM</code> register can hold 8 packed single-precision (32-bit) floats. The FMA3 instructions perform an element-wise multiply-add operation on data stored in the <code class="language-plaintext highlighter-rouge">YMM</code> registers. The corresponding assembly instruction is called <code class="language-plaintext highlighter-rouge">VFMADD231PS</code> (PS stands for PackedSingle) and takes three vector registers (<code class="language-plaintext highlighter-rouge">YMM1</code>, <code class="language-plaintext highlighter-rouge">YMM2</code>, <code class="language-plaintext highlighter-rouge">YMM3</code>) as input to compute <code class="language-plaintext highlighter-rouge">YMM1 = YMM2 * YMM3 + YMM1</code>.</p>

<p><img src="/assets/matmul_cpu/fmadd.png" alt="" width="60%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>According to the <a href="https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html">Intel Intrinsics Guide</a> or <a href="https://uops.info/table.html">https://uops.info/table.html</a>, for my CPU the throughput (TP) of the fused multiply-add instruction is 0.5 cycles/instruction or, in other words, 2 instructions/cycle:
<img src="/assets/matmul_cpu/fmadd_uops.png" alt="" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>Theoretically, the Ryzen 7 9700X can perform 32 FLOP per cycle per core: 8 (floats in <code class="language-plaintext highlighter-rouge">YMM</code> register) * 2 (add + mul) * 2 (1/TP). Therefore, the theoretical peak FLOPS in single-threaded mode can be roughly estimated as <code class="language-plaintext highlighter-rouge">CPU_CLOCK_SPEED * 32</code> or <code class="language-plaintext highlighter-rouge">n_cores * CPU_CLOCK_SPEED * 32</code> in multi-threaded mode. For example, assuming a sustainable clock speed of 4.7 GHz for an 8-core 9700X processor, the theoretical peak in a multi-threaded setting would be 8 * 4.7 GHz * 32 ≈ 1203 GFLOPS.</p>
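
<p>The same estimate can be written as a small C helper. This is a rough sketch, not a measurement: the function name <code class="language-plaintext highlighter-rouge">peak_gflops</code> and its parameters are hypothetical, and the clock speed is an assumed sustained value rather than a guaranteed one.</p>

```c
#include <assert.h>

/* Rough peak FP32 estimate in GFLOPS.
   simd_width    : floats per vector register (8 for a 256-bit YMM register)
   fma_per_cycle : FMA instructions issued per cycle (1/TP, e.g. 2.0) */
double peak_gflops(int n_cores, double clock_ghz, int simd_width,
                   double fma_per_cycle) {
  assert(n_cores > 0 && clock_ghz > 0.0 && simd_width > 0);
  /* Each FMA lane performs one multiply and one add, i.e. 2 FLOP. */
  const double flop_per_cycle = simd_width * 2.0 * fma_per_cycle;
  return n_cores * clock_ghz * flop_per_cycle;
}
```

<p>With the numbers used above, <code class="language-plaintext highlighter-rouge">peak_gflops(8, 4.7, 8, 2.0)</code> evaluates to about 1203 GFLOPS, and <code class="language-plaintext highlighter-rouge">peak_gflops(1, 4.7, 8, 2.0)</code> to about 150 GFLOPS for a single core.</p>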

<h2 id="4-naive-implementation">4. Naive Implementation</h2>

<p>In this tutorial we assume that matrices are stored in column-major order: e.g. a matrix <code class="language-plaintext highlighter-rouge">A</code> of shape <code class="language-plaintext highlighter-rouge">MxN</code> is stored as a contiguous array of length <code class="language-plaintext highlighter-rouge">M*N</code> and an element <code class="language-plaintext highlighter-rouge">A[row][col]</code> is accessed via a raw C pointer as <code class="language-plaintext highlighter-rouge">ptr[col*M + row]</code>, where <code class="language-plaintext highlighter-rouge">0 &lt;= col &lt;= N-1</code> and <code class="language-plaintext highlighter-rouge">0 &lt;= row &lt;= M-1</code>.
<img src="/assets/matmul_cpu/mem_layout.png" alt="" width="80%" style="display:block; margin-left:auto; margin-right:auto" /></p>
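
<p>To make the column-major convention concrete, here is a minimal sketch of the index computation. The helper names <code class="language-plaintext highlighter-rouge">idx_colmajor</code> and <code class="language-plaintext highlighter-rouge">get_elem</code> are hypothetical and only used for illustration.</p>

```c
#include <assert.h>

/* Linear offset of element (row, col) in a column-major matrix with M rows. */
static inline int idx_colmajor(int row, int col, int M) {
  return col * M + row;
}

/* Read A[row][col] from a column-major M x N matrix stored at ptr. */
float get_elem(const float* ptr, int row, int col, int M, int N) {
  assert(row >= 0 && row < M && col >= 0 && col < N);
  return ptr[idx_colmajor(row, col, M)];
}
```

<p>For a <code class="language-plaintext highlighter-rouge">3x2</code> matrix stored as <code class="language-plaintext highlighter-rouge">{1, 2, 3, 4, 5, 6}</code>, the first column is <code class="language-plaintext highlighter-rouge">(1, 2, 3)</code> and the second is <code class="language-plaintext highlighter-rouge">(4, 5, 6)</code>, so elements of the same column are adjacent in memory.</p>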

<p>The simplest implementation of $C=AB$ can be described as follows:
<img src="/assets/matmul_cpu/matmul_naive.png" alt="" style="display:block; margin-left:auto; margin-right:auto" /></p>
<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">void</span> <span class="nf">matmul_naive</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">A</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">B</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">C</span><span class="p">,</span> <span class="k">const</span> <span class="kt">int</span> <span class="n">M</span><span class="p">,</span> <span class="k">const</span> <span class="kt">int</span> <span class="n">N</span><span class="p">,</span> <span class="k">const</span> <span class="kt">int</span> <span class="n">K</span><span class="p">)</span> <span class="p">{</span>
  <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">M</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">j</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">;</span> <span class="n">j</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
      <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">p</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">p</span> <span class="o">&lt;</span> <span class="n">K</span><span class="p">;</span> <span class="n">p</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
        <span class="n">C</span><span class="p">[</span><span class="n">j</span> <span class="o">*</span> <span class="n">M</span> <span class="o">+</span> <span class="n">i</span><span class="p">]</span> <span class="o">+=</span> <span class="n">A</span><span class="p">[</span><span class="n">p</span> <span class="o">*</span> <span class="n">M</span> <span class="o">+</span> <span class="n">i</span><span class="p">]</span> <span class="o">*</span> <span class="n">B</span><span class="p">[</span><span class="n">j</span> <span class="o">*</span> <span class="n">K</span> <span class="o">+</span> <span class="n">p</span><span class="p">];</span>
      <span class="p">}</span>
    <span class="p">}</span>
  <span class="p">}</span>
<span class="p">}</span>
</code></pre></div></div>
<p>Here, we iterate over all rows (the outermost loop) and all columns (the second loop) of <code class="language-plaintext highlighter-rouge">C</code> and for each element of <code class="language-plaintext highlighter-rouge">C</code> we calculate the dot product (the innermost loop) between the corresponding row of matrix <code class="language-plaintext highlighter-rouge">A</code> and column of matrix <code class="language-plaintext highlighter-rouge">B</code>. It’s always good to start with a simple and robust algorithm that can later be used to test optimized implementations.</p>

<h2 id="5-kernel">5. Kernel</h2>

<p>The key idea of high-performance matrix multiplication on CPU is to develop a function that efficiently computes a sub-matrix of $C$. Then, by iterating over $C$ and applying this function to all non-overlapping sub-matrices, we can significantly speed up the entire matrix multiplication operation. To do this, we first partition the matrix $C$ of shape $M \times N$ into smaller non-overlapping sub-matrices of shape $m_R \times n_R$, with $n_R \ll N$ and $m_R \ll M$. To calculate $C=AB$, we iterate over $C$ and compute each of its non-overlapping $m_R \times n_R$ sub-matrices as shown below:</p>

<p><img src="/assets/matmul_cpu/matmul_kernel.png" alt="" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>The function that computes an $m_R \times n_R$ sub-matrix $\bar{C}$ of $C$ is called a <strong>kernel</strong> (a.k.a. <strong>micro-kernel</strong> in BLIS terminology). This function is the core of high-performance matrix multiplication. When we say a matrix multiplication algorithm is optimized for a specific CPU architecture, it usually refers to kernel optimization. For example, OpenBLAS contains <a href="https://github.com/OpenMathLib/OpenBLAS/tree/develop/kernel">kernels</a> optimized for different CPU microarchitectures.</p>

<p>Let’s take a closer look at the kernel. To compute an $m_R \times n_R$ sub-matrix $\bar{C}$ of $C$, we need to multiply corresponding $m_R \times K$ sub-matrix $\bar{A}$ of $A$ with $K \times n_R$ sub-matrix $\bar{B}$ of $B$ as shown in the figure below:</p>

<p><img src="/assets/matmul_cpu/kernel.png" alt="" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>If we were to do this in a naive manner using the dot product, we would need to fetch $2K$ elements from RAM to calculate a single element of $\bar{C}$ or $2K m_R n_R$ elements in total to compute $\bar{C}$. There is, however, an alternative strategy that can reduce the number of fetched elements.</p>

<p>First, we initialize the matrix $\bar{C}$ with zeros and store it in registers. Since both $n_R$ and $m_R$ are small, the entire matrix fits within the registers. Here, the subscript $R$ in $n_R$ and $m_R$ denotes “register”. Next, we iterate over the dimension $K$, and in each iteration, we:</p>

<ol>
  <li>load 1 column of $\bar{A}$ and 1 row of $\bar{B}$ from RAM into the registers. Again, note that both the row and column vectors are limited in size and can be stored in the registers.</li>
  <li>compute the outer product between the two vectors and add the result of the outer product to the matrix $\bar{C}$.</li>
</ol>

<p><img src="/assets/matmul_cpu/kernel_rank.png" alt="" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>After $K$ iterations, the computation of the matrix $\bar{C}$ is complete and it can be stored into RAM. $\bar{C}$ is often referred to as the <em>accumulator</em>, because it accumulates the outer products along the dimension $K$. A single accumulation step of the outer product between two vectors is also known as a <strong>rank-1 update</strong>.</p>

<blockquote>
  <p>Outer product between a column vector and a row vector.
<img src="/assets/matmul_cpu/outer_product.png" alt="" style="display:block; margin-left:auto; margin-right:auto" /></p>
</blockquote>
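
<p>Before turning to SIMD, the scheme above can be sketched as a scalar reference in plain C. This is a minimal illustration under assumptions: the function name <code class="language-plaintext highlighter-rouge">kernel_rank1_ref</code>, the fixed sizes <code class="language-plaintext highlighter-rouge">MR=16</code> and <code class="language-plaintext highlighter-rouge">NR=6</code>, and the leading-dimension parameters are hypothetical choices, and the local array only stands in for registers.</p>

```c
#include <assert.h>

#define MR 16 /* rows of the accumulator (assumed kernel height) */
#define NR 6  /* columns of the accumulator (assumed kernel width) */

/*
 * Scalar reference for the rank-1 update scheme (no SIMD).
 * A_bar: MR x K panel, column-major with leading dimension lda.
 * B_bar: K x NR panel, column-major with leading dimension ldb.
 * C_bar: MR x NR block, column-major with leading dimension ldc.
 */
void kernel_rank1_ref(const float* A_bar, const float* B_bar, float* C_bar,
                      int K, int lda, int ldb, int ldc) {
  assert(K >= 0);
  /* Accumulator initialized with zeros; acc[j] holds column j of C_bar. */
  float acc[NR][MR] = {{0.0f}};

  for (int p = 0; p < K; p++) {
    /* Column p of A_bar is contiguous; row p of B_bar is strided by ldb. */
    const float* a_col = &A_bar[p * lda];
    for (int j = 0; j < NR; j++) {
      const float b = B_bar[j * ldb + p]; /* element j of row p of B_bar */
      for (int i = 0; i < MR; i++) {
        acc[j][i] += a_col[i] * b; /* outer-product contribution */
      }
    }
  }

  /* Store the finished accumulator to memory once, after K iterations. */
  for (int j = 0; j < NR; j++) {
    for (int i = 0; i < MR; i++) {
      C_bar[j * ldc + i] = acc[j][i];
    }
  }
}
```

<p>Each iteration of the outer loop reads one column of $\bar{A}$ and one row of $\bar{B}$ exactly once, which is where the reduction in memory traffic comes from. A scalar version like this is also handy for checking an optimized kernel against known-good results.</p>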

<p>In total, we fetch $(m_R + n_R)K$ elements from RAM into registers. Compared to the naive approach, this reduces the number of memory accesses by a factor of</p>

\[\frac{2m_Rn_RK}{(m_R + n_R)K} = \frac{2m_Rn_R}{m_R + n_R}\]

<p>This factor is maximized when both $m_R$, $n_R$ are large and equal. However, the values of $m_R$ and $n_R$ are typically constrained by the available register memory.</p>
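
<p>As a quick worked example of this formula (plain arithmetic, with two kernel shapes chosen for illustration), an $8 \times 8$ kernel and a $16 \times 6$ kernel give:</p>

\[\frac{2 \cdot 8 \cdot 8}{8 + 8} = \frac{128}{16} = 8, \qquad \frac{2 \cdot 16 \cdot 6}{16 + 6} = \frac{192}{22} \approx 8.7\]

<p>So, under this simple model, both shapes cut the number of elements fetched from RAM by roughly a factor of 8 to 9 relative to the naive dot-product approach.</p>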

<p>Now, let’s discuss in detail how the outer product and accumulation can be efficiently implemented using SIMD FMA instructions. Unfortunately, there are no SIMD instructions that compute the outer product in a single step. Therefore, we need to decompose the outer product into simpler operations. The figure below illustrates the process:</p>

<p><img src="/assets/matmul_cpu/kernel_registers.png" alt="" width="80%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>Here, we compute the outer product between a column vector of size $m_R$ and a row vector of size $n_R$ to update an accumulator $\bar{C}$ of size $m_R \times n_R$. The accumulator is stored in the <code class="language-plaintext highlighter-rouge">YMM</code> registers, with each column of the accumulator spanning one or multiple <code class="language-plaintext highlighter-rouge">YMM</code> registers. The column vector is also stored in the <code class="language-plaintext highlighter-rouge">YMM</code> registers (highlighted in yellow). Since each <code class="language-plaintext highlighter-rouge">YMM</code> register holds 8 floats, the dimension $m_R$ must be divisible by 8. The accumulator is updated column by column. During the first iteration we broadcast the first element of the row vector to a vector of size $m_R$ and place it in the <code class="language-plaintext highlighter-rouge">YMM</code> registers (highlighted in green). Then, we multiply the column vector element-wise with the broadcasted vector and accumulate the result into the first column of the accumulator $\bar{C}$ using an FMA instruction. We repeat this process for the remaining elements of the row vector to update the corresponding columns of the accumulator. After $n_R$ iterations, the rank-1 update of the accumulator is complete.</p>

<p>The last thing we need to discuss before implementing the kernel in C is how to choose the kernel size i.e. $m_R$ and $n_R$. CPUs with AVX support have <strong>16 YMM registers</strong>. From our previous discussion, we know that we need $(m_R/8) \cdot n_R$ registers to store the accumulator $\bar{C}$, $m_R/8$ registers to store the column vector and 1 register (because we can reuse the same register for all FMA operations) for the broadcasted vector. We want $m_R$ and $n_R$ to be as large as possible while satisfying the following conditions:</p>

<ul>
  <li>$\Big(\cfrac{m_R}{8} \cdot n_R + \cfrac{m_R}{8} + 1\Big) \le 16$</li>
  <li>$m_R$ is a multiple of 8</li>
</ul>

<p>In theory we want $m_R = n_R$ to minimize the number of fetched elements. However, in practice, a non-square kernel with $m_R = 16, n_R = 6$ showed the best performance on my CPU. Therefore, we will implement this kernel in the next section. Feel free to experiment with other kernel sizes, such as $8 \times 8, 8 \times 12$, $8 \times 13$, $8 \times 14$, $32 \times 2$ and compare their performance on your CPU.</p>

<p>Let’s implement the algorithm discussed above using the $16 \times 6$ kernel. The code of this implementation can be found at <a href="https://github.com/salykova/sgemm.c/blob/main/tutorial/matmul_kernel.h">matmul_kernel.h</a>. To use SIMD intrinsics in C we first need to include the <code class="language-plaintext highlighter-rouge">immintrin.h</code> header:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="cp">#include</span> <span class="cpf">&lt;immintrin.h&gt;</span><span class="cp">
</span></code></pre></div></div>

<p>The implementation of the algorithm is straightforward: we iterate over matrix $C$ and apply the kernel function to each of its non-overlapping $16 \times 6$ sub-matrices $\bar{C}$.</p>
<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">void</span> <span class="nf">matmul_kernel</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">A</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">B</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">C</span><span class="p">,</span> <span class="k">const</span> <span class="kt">int</span> <span class="n">M</span><span class="p">,</span> <span class="k">const</span> <span class="kt">int</span> <span class="n">N</span><span class="p">,</span> <span class="k">const</span> <span class="kt">int</span> <span class="n">K</span><span class="p">)</span> <span class="p">{</span>
  <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">M</span><span class="p">;</span> <span class="n">i</span> <span class="o">+=</span> <span class="mi">16</span><span class="p">)</span> <span class="p">{</span>
    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">j</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">;</span> <span class="n">j</span> <span class="o">+=</span> <span class="mi">6</span><span class="p">)</span> <span class="p">{</span>
        <span class="n">kernel_16x6</span><span class="p">(</span><span class="o">&amp;</span><span class="n">A</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="o">&amp;</span><span class="n">B</span><span class="p">[</span><span class="n">j</span> <span class="o">*</span> <span class="n">K</span><span class="p">],</span> <span class="o">&amp;</span><span class="n">C</span><span class="p">[</span><span class="n">j</span> <span class="o">*</span> <span class="n">M</span> <span class="o">+</span> <span class="n">i</span><span class="p">],</span> <span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">K</span><span class="p">);</span>
    <span class="p">}</span>
  <span class="p">}</span>
<span class="p">}</span>
</code></pre></div></div>

<p>The kernel function is declared as follows:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">void</span> <span class="nf">kernel_16x6</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">A_start</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">B_start</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">C_start</span><span class="p">,</span> <span class="kt">int</span> <span class="n">M</span><span class="p">,</span> <span class="kt">int</span> <span class="n">N</span><span class="p">,</span> <span class="kt">int</span> <span class="n">K</span><span class="p">);</span>
</code></pre></div></div>

<p>The function takes as input pointers to the starting positions of $\bar{A}, \bar{B}$, and $\bar{C}$, along with the matrix problem size. It then computes the $16 \times 6$ sub-matrix $\bar{C}$ of $C$ according to $\bar{C} = \bar{A} \bar{B}$.</p>
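<p>The pointer offsets passed to the kernel follow from the indexing used in the code, which reads all three matrices in column-major order. A minimal sketch of that arithmetic is given below; <code class="language-plaintext highlighter-rouge">colmajor_index</code> is a hypothetical helper introduced only for illustration and is not part of the original code.</p>

```c
#include <stddef.h>

/* Hypothetical helper: offset of element (row, col) in a column-major
   matrix whose columns are ld elements apart (the leading dimension). */
size_t colmajor_index(size_t row, size_t col, size_t ld) {
    return col * ld + row;
}

/* For the tile whose top-left corner in C is (i, j):
     A (M x K, ld = M): tile starts at (i, 0) -> offset i
     B (K x N, ld = K): tile starts at (0, j) -> offset j * K
     C (M x N, ld = M): tile starts at (i, j) -> offset j * M + i */
```

<p>These are the same offsets that appear in the driver loop as <code class="language-plaintext highlighter-rouge">&amp;A[i]</code>, <code class="language-plaintext highlighter-rouge">&amp;B[j * K]</code> and <code class="language-plaintext highlighter-rouge">&amp;C[j * M + i]</code>.</p>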

<p>Inside the kernel function, first, we declare the variables stored in the <code class="language-plaintext highlighter-rouge">YMM</code> registers:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">__m256</span> <span class="n">C_accum</span><span class="p">[</span><span class="mi">6</span><span class="p">][</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="p">{};</span> <span class="c1">// zero-initialized</span>
<span class="n">__m256</span> <span class="n">b_packFloat8</span><span class="p">;</span>
<span class="n">__m256</span> <span class="n">a0_packFloat8</span><span class="p">;</span>
<span class="n">__m256</span> <span class="n">a1_packFloat8</span><span class="p">;</span>
</code></pre></div></div>

<p>A variable of type <code class="language-plaintext highlighter-rouge">__m256</code> is a 256-bit vector that represents the contents of a <code class="language-plaintext highlighter-rouge">YMM</code> register, which holds eight 32-bit floating-point values. <code class="language-plaintext highlighter-rouge">C_accum</code> is the accumulator stored in the <code class="language-plaintext highlighter-rouge">YMM</code> registers. The variable <code class="language-plaintext highlighter-rouge">b_packFloat8</code> contains a broadcasted element from a row vector of $\bar{B}$, while <code class="language-plaintext highlighter-rouge">a0_packFloat8</code> and <code class="language-plaintext highlighter-rouge">a1_packFloat8</code> represent a column vector of $\bar{A}$. Since the column vector contains 16 floats, it requires two <code class="language-plaintext highlighter-rouge">YMM</code> registers for storage.</p>

<p>SIMD intrinsics are well documented and can be found in the <a href="https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html">Intel Intrinsics Guide</a>. For example, here is the entry for <a href="https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_loadu_ps&amp;ig_expand=4100">_mm256_loadu_ps</a>:</p>

<p><img src="/assets/matmul_cpu/mm256_loadu.png" alt="" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>The kernel iterates over the dimension $K$ and in each iteration performs a rank-1 update of the accumulator:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">p</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">p</span> <span class="o">&lt;</span> <span class="n">K</span><span class="p">;</span> <span class="n">p</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
  <span class="c1">// Load column vector of size 16</span>
  <span class="c1">// {</span>
  <span class="n">a0_packFloat8</span> <span class="o">=</span> <span class="n">_mm256_loadu_ps</span><span class="p">(</span><span class="o">&amp;</span><span class="n">A_start</span><span class="p">[</span><span class="n">p</span> <span class="o">*</span> <span class="n">M</span><span class="p">]);</span>
  <span class="n">a1_packFloat8</span> <span class="o">=</span> <span class="n">_mm256_loadu_ps</span><span class="p">(</span><span class="o">&amp;</span><span class="n">A_start</span><span class="p">[</span><span class="n">p</span> <span class="o">*</span> <span class="n">M</span> <span class="o">+</span> <span class="mi">8</span><span class="p">]);</span>
  <span class="c1">// }</span>
  <span class="c1">// Broadcast scalar element to vector of size 8</span>
  <span class="c1">// {</span>
  <span class="n">b_packFloat8</span> <span class="o">=</span> <span class="n">_mm256_broadcast_ss</span><span class="p">(</span><span class="o">&amp;</span><span class="n">B_start</span><span class="p">[</span><span class="n">p</span><span class="p">]);</span>
  <span class="c1">// }</span>
  <span class="c1">// Update the first column of the accumulator</span>
  <span class="c1">// {</span>
  <span class="n">C_accum</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">_mm256_fmadd_ps</span><span class="p">(</span><span class="n">a0_packFloat8</span><span class="p">,</span> <span class="n">b_packFloat8</span><span class="p">,</span> <span class="n">C_accum</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]);</span>
  <span class="n">C_accum</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">_mm256_fmadd_ps</span><span class="p">(</span><span class="n">a1_packFloat8</span><span class="p">,</span> <span class="n">b_packFloat8</span><span class="p">,</span> <span class="n">C_accum</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">1</span><span class="p">]);</span>
  <span class="c1">// }</span>
  <span class="p">...</span>
  <span class="p">...</span>
  <span class="p">...</span>
  <span class="n">b_packFloat8</span> <span class="o">=</span> <span class="n">_mm256_broadcast_ss</span><span class="p">(</span><span class="o">&amp;</span><span class="n">B_start</span><span class="p">[</span><span class="mi">5</span> <span class="o">*</span> <span class="n">K</span> <span class="o">+</span> <span class="n">p</span><span class="p">]);</span>
  <span class="c1">// update the last column of the accumulator</span>
  <span class="c1">// {</span>
  <span class="n">C_accum</span><span class="p">[</span><span class="mi">5</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">_mm256_fmadd_ps</span><span class="p">(</span><span class="n">a0_packFloat8</span><span class="p">,</span> <span class="n">b_packFloat8</span><span class="p">,</span> <span class="n">C_accum</span><span class="p">[</span><span class="mi">5</span><span class="p">][</span><span class="mi">0</span><span class="p">]);</span>
  <span class="n">C_accum</span><span class="p">[</span><span class="mi">5</span><span class="p">][</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">_mm256_fmadd_ps</span><span class="p">(</span><span class="n">a1_packFloat8</span><span class="p">,</span> <span class="n">b_packFloat8</span><span class="p">,</span> <span class="n">C_accum</span><span class="p">[</span><span class="mi">5</span><span class="p">][</span><span class="mi">1</span><span class="p">]);</span>
  <span class="c1">// }</span>
<span class="p">}</span>
</code></pre></div></div>
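<p>When experimenting with the vectorized loop, it can help to have a plain scalar version of the same $K$ rank-1 updates to compare against. The sketch below is one possible reference; <code class="language-plaintext highlighter-rouge">kernel_16x6_ref</code> is a hypothetical name, and the unused <code class="language-plaintext highlighter-rouge">N</code> parameter is dropped for brevity.</p>

```c
/* Scalar reference for the 16x6 kernel (hypothetical helper for testing).
   A_start: 16 x K block, column-major, leading dimension M.
   B_start: K x 6 block, column-major, leading dimension K.
   C_start: 16 x 6 block, column-major, leading dimension M. */
void kernel_16x6_ref(const float* A_start, const float* B_start,
                     float* C_start, int M, int K) {
    float C_accum[6][16] = {{0}};  /* one 16-element column per column of the tile */

    for (int p = 0; p < K; p++) {
        /* Rank-1 update: C_accum += (column p of A) * (row p of B) */
        for (int j = 0; j < 6; j++) {
            /* The scalar that the SIMD version broadcasts to all 8 lanes */
            const float b = B_start[j * K + p];
            for (int i = 0; i < 16; i++) {
                C_accum[j][i] += A_start[p * M + i] * b;
            }
        }
    }

    /* Store the accumulator column by column */
    for (int j = 0; j < 6; j++) {
        for (int i = 0; i < 16; i++) {
            C_start[j * M + i] = C_accum[j][i];
        }
    }
}
```

<p>Because a fused multiply-add rounds only once per operation, the SIMD and scalar results may differ in the last bits, so a comparison should use a small tolerance rather than exact equality.</p>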

<p>After $K$ rank-1 updates, the computation of the accumulator is complete, and the result can be stored in RAM:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1">// Store the accumulator column by column:</span>
<span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">j</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="mi">6</span><span class="p">;</span> <span class="n">j</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
  <span class="n">_mm256_storeu_ps</span><span class="p">(</span><span class="o">&amp;</span><span class="n">C_start</span><span class="p">[</span><span class="n">j</span> <span class="o">*</span> <span class="n">M</span><span class="p">],</span> <span class="n">C_accum</span><span class="p">[</span><span class="n">j</span><span class="p">][</span><span class="mi">0</span><span class="p">]);</span>
  <span class="n">_mm256_storeu_ps</span><span class="p">(</span><span class="o">&amp;</span><span class="n">C_start</span><span class="p">[</span><span class="n">j</span> <span class="o">*</span> <span class="n">M</span> <span class="o">+</span> <span class="mi">8</span><span class="p">],</span> <span class="n">C_accum</span><span class="p">[</span><span class="n">j</span><span class="p">][</span><span class="mi">1</span><span class="p">]);</span>
<span class="p">}</span>
</code></pre></div></div>

<p>Let’s take a look at the generated assembly code to see if it actually contains SIMD FMA instructions and uses the <code class="language-plaintext highlighter-rouge">YMM</code> registers:</p>

<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code>gcc <span class="nt">-O3</span> <span class="nt">-mno-avx512f</span> <span class="nt">-march</span><span class="o">=</span>native matmul_kernel.c <span class="nt">-S</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>// matmul_kernel.s
...
vfmadd231ps	%ymm14, %ymm1, %ymm13
vfmadd231ps	%ymm14, %ymm0, %ymm12
vmovaps	%ymm13, 32(%rsp)
vmovaps	%ymm12, 64(%rsp)
vbroadcastss	(%rax,%r9), %ymm14
vfmadd231ps	%ymm14, %ymm1, %ymm10
vfmadd231ps	%ymm14, %ymm0, %ymm11
vmovaps	%ymm10, 96(%rsp)
vmovaps	%ymm11, 128(%rsp)
vbroadcastss	(%rax,%r9,2), %ymm14
addq	$4, %rax
vfmadd231ps	%ymm14, %ymm1, %ymm2
vfmadd231ps	%ymm14, %ymm0, %ymm3
...
</code></pre></div></div>

<h2 id="6-padding">6. Padding</h2>

<p>You may have noticed that the current implementation only works for matrix sizes where $M$ and $N$ are multiples of $m_R$ and $n_R$, respectively. Specifically, the kernel assumes that matrix $\bar{C}$ has dimensions $m_R \times n_R$, matrix $\bar{A}$ is $m_R \times K$ and matrix $\bar{B}$ is $K \times n_R$. Our goal is to generalize the kernel so that it can handle matrices $\bar{C}, \bar{A}, \bar{B}$ with dimensions $m \times n, m \times K, K \times n$, even when $m \neq m_R$ and $n \neq n_R$, as shown below:</p>

<p><img src="/assets/matmul_cpu/kernel_mask.png" alt="" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>First, when storing the accumulator, we need to ensure that elements are only stored within the matrix boundaries. If the number of overlapping columns, $n$, is smaller than $n_R$, the process is straightforward - we simply iterate over $n$ columns instead of $n_R$:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1">// n - number of overlapped columns within C boundary</span>

<span class="c1">// "j &lt; n" instead of "j &lt; 6", since n can be less than 6.</span>
<span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">j</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="n">n</span><span class="p">;</span> <span class="n">j</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
  <span class="n">_mm256_storeu_ps</span><span class="p">(</span><span class="o">&amp;</span><span class="n">C_start</span><span class="p">[</span><span class="n">j</span> <span class="o">*</span> <span class="n">M</span><span class="p">],</span> <span class="n">C_accum</span><span class="p">[</span><span class="n">j</span><span class="p">][</span><span class="mi">0</span><span class="p">]);</span>
  <span class="n">_mm256_storeu_ps</span><span class="p">(</span><span class="o">&amp;</span><span class="n">C_start</span><span class="p">[</span><span class="n">j</span> <span class="o">*</span> <span class="n">M</span> <span class="o">+</span> <span class="mi">8</span><span class="p">],</span> <span class="n">C_accum</span><span class="p">[</span><span class="n">j</span><span class="p">][</span><span class="mi">1</span><span class="p">]);</span>
<span class="p">}</span>
</code></pre></div></div>

<p>The case where the number of overlapped rows $m$ differs from $m_R$ is a bit trickier because <code class="language-plaintext highlighter-rouge">_mm256_storeu_ps</code> stores 8 elements at once. Fortunately, the <code class="language-plaintext highlighter-rouge">immintrin.h</code> header provides the <code class="language-plaintext highlighter-rouge">_mm256_maskstore_ps</code> function, which stores packed floats according to mask values. The function takes <a href="https://www.intel.com/content/www/us/en/docs/cpp-compiler/developer-guide-reference/2021-10/mm256-maskstore-ps-mm-maskstore-ps.html">three arguments</a>:</p>

<ol>
  <li><code class="language-plaintext highlighter-rouge">float *a</code></li>
  <li><code class="language-plaintext highlighter-rouge">__m256i mask</code></li>
  <li><code class="language-plaintext highlighter-rouge">__m256 b</code></li>
</ol>

<p><code class="language-plaintext highlighter-rouge">__m256i</code> is a vector datatype that holds eight 32-bit integers. Each integer in <code class="language-plaintext highlighter-rouge">mask</code> corresponds to a data element in <code class="language-plaintext highlighter-rouge">b</code>. The most significant bit (MSB) of each integer in <code class="language-plaintext highlighter-rouge">mask</code> represents the mask bit. If the mask bit is zero, the corresponding value in <code class="language-plaintext highlighter-rouge">b</code> is not stored in the memory location pointed to by <code class="language-plaintext highlighter-rouge">a</code>. For example, the MSB of unsigned integer <code class="language-plaintext highlighter-rouge">2147483648</code> (binary format <code class="language-plaintext highlighter-rouge">10000000 00000000 00000000 00000000</code>) is <code class="language-plaintext highlighter-rouge">1</code>, so the corresponding data element in <code class="language-plaintext highlighter-rouge">b</code> will be stored. On the other hand, the MSB of unsigned integer <code class="language-plaintext highlighter-rouge">2147483647</code> (binary format <code class="language-plaintext highlighter-rouge">01111111 11111111 11111111 11111111</code>) is <code class="language-plaintext highlighter-rouge">0</code>, meaning the corresponding data element in <code class="language-plaintext highlighter-rouge">b</code> will not be stored.</p>

<p>If $m \neq m_R$, we generate integer masks by left-shifting the unsigned integer <code class="language-plaintext highlighter-rouge">65535</code> (=<code class="language-plaintext highlighter-rouge">00000000 00000000 11111111 11111111</code> in binary format) by an amount that depends on the number of overlapped rows $m$. In the code snippet below, the function <code class="language-plaintext highlighter-rouge">_mm256_setr_epi32()</code> creates a <code class="language-plaintext highlighter-rouge">__m256i</code> vector from eight 32-bit integers.</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">__m256i</span> <span class="n">masks</span><span class="p">[</span><span class="mi">2</span><span class="p">];</span>
<span class="k">if</span> <span class="p">(</span><span class="n">m</span> <span class="o">!=</span> <span class="mi">16</span><span class="p">)</span> <span class="p">{</span>
  <span class="k">const</span> <span class="kt">uint32_t</span> <span class="n">bit_mask</span> <span class="o">=</span> <span class="mi">65535</span><span class="p">;</span>
  <span class="n">masks</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">_mm256_setr_epi32</span><span class="p">(</span><span class="n">bit_mask</span> <span class="o">&lt;&lt;</span> <span class="p">(</span><span class="n">m</span> <span class="o">+</span> <span class="mi">15</span><span class="p">),</span>
                               <span class="n">bit_mask</span> <span class="o">&lt;&lt;</span> <span class="p">(</span><span class="n">m</span> <span class="o">+</span> <span class="mi">14</span><span class="p">),</span>
                               <span class="n">bit_mask</span> <span class="o">&lt;&lt;</span> <span class="p">(</span><span class="n">m</span> <span class="o">+</span> <span class="mi">13</span><span class="p">),</span>
                               <span class="n">bit_mask</span> <span class="o">&lt;&lt;</span> <span class="p">(</span><span class="n">m</span> <span class="o">+</span> <span class="mi">12</span><span class="p">),</span>
                               <span class="n">bit_mask</span> <span class="o">&lt;&lt;</span> <span class="p">(</span><span class="n">m</span> <span class="o">+</span> <span class="mi">11</span><span class="p">),</span>
                               <span class="n">bit_mask</span> <span class="o">&lt;&lt;</span> <span class="p">(</span><span class="n">m</span> <span class="o">+</span> <span class="mi">10</span><span class="p">),</span>
                               <span class="n">bit_mask</span> <span class="o">&lt;&lt;</span> <span class="p">(</span><span class="n">m</span> <span class="o">+</span> <span class="mi">9</span><span class="p">),</span>
                               <span class="n">bit_mask</span> <span class="o">&lt;&lt;</span> <span class="p">(</span><span class="n">m</span> <span class="o">+</span> <span class="mi">8</span><span class="p">));</span>
  <span class="n">masks</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">_mm256_setr_epi32</span><span class="p">(</span><span class="n">bit_mask</span> <span class="o">&lt;&lt;</span> <span class="p">(</span><span class="n">m</span> <span class="o">+</span> <span class="mi">7</span><span class="p">),</span>
                               <span class="n">bit_mask</span> <span class="o">&lt;&lt;</span> <span class="p">(</span><span class="n">m</span> <span class="o">+</span> <span class="mi">6</span><span class="p">),</span>
                               <span class="n">bit_mask</span> <span class="o">&lt;&lt;</span> <span class="p">(</span><span class="n">m</span> <span class="o">+</span> <span class="mi">5</span><span class="p">),</span>
                               <span class="n">bit_mask</span> <span class="o">&lt;&lt;</span> <span class="p">(</span><span class="n">m</span> <span class="o">+</span> <span class="mi">4</span><span class="p">),</span>
                               <span class="n">bit_mask</span> <span class="o">&lt;&lt;</span> <span class="p">(</span><span class="n">m</span> <span class="o">+</span> <span class="mi">3</span><span class="p">),</span>
                               <span class="n">bit_mask</span> <span class="o">&lt;&lt;</span> <span class="p">(</span><span class="n">m</span> <span class="o">+</span> <span class="mi">2</span><span class="p">),</span>
                               <span class="n">bit_mask</span> <span class="o">&lt;&lt;</span> <span class="p">(</span><span class="n">m</span> <span class="o">+</span> <span class="mi">1</span><span class="p">),</span>
                               <span class="n">bit_mask</span> <span class="o">&lt;&lt;</span> <span class="n">m</span><span class="p">);</span>
  <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">j</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="n">n</span><span class="p">;</span> <span class="n">j</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
    <span class="n">_mm256_maskstore_ps</span><span class="p">(</span><span class="o">&amp;</span><span class="n">C_start</span><span class="p">[</span><span class="n">j</span> <span class="o">*</span> <span class="n">M</span><span class="p">],</span> <span class="n">masks</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">C_accum</span><span class="p">[</span><span class="n">j</span><span class="p">][</span><span class="mi">0</span><span class="p">]);</span>
    <span class="n">_mm256_maskstore_ps</span><span class="p">(</span><span class="o">&amp;</span><span class="n">C_start</span><span class="p">[</span><span class="n">j</span> <span class="o">*</span> <span class="n">M</span> <span class="o">+</span> <span class="mi">8</span><span class="p">],</span> <span class="n">masks</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">C_accum</span><span class="p">[</span><span class="n">j</span><span class="p">][</span><span class="mi">1</span><span class="p">]);</span>
  <span class="p">}</span>
<span class="p">}</span>
</code></pre></div></div>
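<p>To see why these shift amounts select exactly the first $m$ rows, note that row $r$ of the 16-row tile receives the value <code class="language-plaintext highlighter-rouge">65535 &lt;&lt; (m + 15 - r)</code>, whose most significant bit is set only when the shift is at least 16, that is, when $r &lt; m$. The scalar sketch below checks this for a single row; <code class="language-plaintext highlighter-rouge">row_is_stored</code> is a hypothetical helper written for illustration.</p>

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical helper: returns 1 if row r (0..15) of the tile would be
   written by the masked store when m rows (1..15) overlap matrix C.
   Mirrors the shift amounts of the snippet above: row r uses m + 15 - r. */
int row_is_stored(int m, int r) {
    assert(m >= 1 && m <= 15);
    assert(r >= 0 && r <= 15);
    const uint32_t bit_mask = 65535u;                /* bits 0..15 set */
    const uint32_t lane = bit_mask << (m + 15 - r);  /* shift is in 1..30 */
    return (int)(lane >> 31);                        /* MSB is the mask bit */
}
```

<p>For example, with $m = 5$, row 4 is shifted by 16 and yields <code class="language-plaintext highlighter-rouge">0xFFFF0000</code> (stored), while row 5 is shifted by 15 and yields <code class="language-plaintext highlighter-rouge">0x7FFF8000</code> (not stored).</p>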

<p>The compiler auto-vectorizes the sequential bit-shifting operations using a combination of <code class="language-plaintext highlighter-rouge">vpaddd</code> and <code class="language-plaintext highlighter-rouge">vpsllvd</code> instructions, making the mask computation very efficient. There is, however, an alternative method to compute the masks, as will be shown later.</p>

<p>When loading elements from matrices $\bar{A}$ and $\bar{B}$ inside the kernel, we need to check that the loads are within the matrix boundaries. One way to do this is by using <code class="language-plaintext highlighter-rouge">_mm256_maskload_ps</code> when loading elements from the matrix $\bar{A}$ and looping over $n$ elements instead of $n_R$ when loading elements from the matrix $\bar{B}$. However, this method would significantly degrade the kernel’s performance. The additional instructions required to compute the loading masks introduce overhead, and since $n$ is not a compile-time constant, the compiler cannot unroll the loop efficiently. Instead, if $m \neq m_R$, we copy the matrix $\bar{A}$ into a buffer, pad it with zeros and pass the padded matrix of size $m_R \times K$ to the kernel. We do the same for the matrix $\bar{B}$ if $n \neq n_R$. The implementation straightforwardly follows the description:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="cp">#define BLOCK_A_MAXSIZE 500000
#define BLOCK_B_MAXSIZE 200000
</span>
<span class="k">static</span> <span class="kt">float</span> <span class="n">blockA_buffer</span><span class="p">[</span><span class="n">BLOCK_A_MAXSIZE</span><span class="p">]</span> <span class="n">__attribute__</span><span class="p">((</span><span class="n">aligned</span><span class="p">(</span><span class="mi">64</span><span class="p">)));</span>
<span class="k">static</span> <span class="kt">float</span> <span class="n">blockB_buffer</span><span class="p">[</span><span class="n">BLOCK_B_MAXSIZE</span><span class="p">]</span> <span class="n">__attribute__</span><span class="p">((</span><span class="n">aligned</span><span class="p">(</span><span class="mi">64</span><span class="p">)));</span>

<span class="kt">void</span> <span class="nf">matmul_pack</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">A</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">B</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">C</span><span class="p">,</span> <span class="k">const</span> <span class="kt">int</span> <span class="n">M</span><span class="p">,</span> <span class="k">const</span> <span class="kt">int</span> <span class="n">N</span><span class="p">,</span> <span class="k">const</span> <span class="kt">int</span> <span class="n">K</span><span class="p">)</span> <span class="p">{</span>
    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">M</span><span class="p">;</span> <span class="n">i</span> <span class="o">+=</span> <span class="mi">16</span><span class="p">)</span> <span class="p">{</span>
        <span class="k">const</span> <span class="kt">int</span> <span class="n">m</span> <span class="o">=</span> <span class="n">min</span><span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="n">M</span> <span class="o">-</span> <span class="n">i</span><span class="p">);</span>
        <span class="kt">float</span><span class="o">*</span> <span class="n">blockA</span> <span class="o">=</span> <span class="o">&amp;</span><span class="n">A</span><span class="p">[</span><span class="n">i</span><span class="p">];</span>
        <span class="kt">int</span> <span class="n">blockA_ld</span> <span class="o">=</span> <span class="n">M</span><span class="p">;</span>
        <span class="k">if</span> <span class="p">(</span><span class="n">m</span> <span class="o">!=</span> <span class="mi">16</span><span class="p">)</span> <span class="p">{</span>
            <span class="n">pack_blockA</span><span class="p">(</span><span class="o">&amp;</span><span class="n">A</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">blockA_buffer</span><span class="p">,</span> <span class="n">m</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">K</span><span class="p">);</span>
            <span class="n">blockA</span> <span class="o">=</span> <span class="n">blockA_buffer</span><span class="p">;</span>
            <span class="n">blockA_ld</span> <span class="o">=</span> <span class="mi">16</span><span class="p">;</span>
        <span class="p">}</span>
        <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">j</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">;</span> <span class="n">j</span> <span class="o">+=</span> <span class="mi">6</span><span class="p">)</span> <span class="p">{</span>
            <span class="k">const</span> <span class="kt">int</span> <span class="n">n</span> <span class="o">=</span> <span class="n">min</span><span class="p">(</span><span class="mi">6</span><span class="p">,</span> <span class="n">N</span> <span class="o">-</span> <span class="n">j</span><span class="p">);</span>
            <span class="kt">float</span><span class="o">*</span> <span class="n">blockB</span> <span class="o">=</span> <span class="o">&amp;</span><span class="n">B</span><span class="p">[</span><span class="n">j</span> <span class="o">*</span> <span class="n">K</span><span class="p">];</span>
            <span class="k">if</span> <span class="p">(</span><span class="n">n</span> <span class="o">!=</span> <span class="mi">6</span><span class="p">)</span> <span class="p">{</span>
                <span class="n">pack_blockB</span><span class="p">(</span><span class="o">&amp;</span><span class="n">B</span><span class="p">[</span><span class="n">j</span> <span class="o">*</span> <span class="n">K</span><span class="p">],</span> <span class="n">blockB_buffer</span><span class="p">,</span> <span class="n">n</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">K</span><span class="p">);</span>
                <span class="n">blockB</span> <span class="o">=</span> <span class="n">blockB_buffer</span><span class="p">;</span>
            <span class="p">}</span>
            <span class="n">kernel_16x6</span><span class="p">(</span><span class="n">blockA</span><span class="p">,</span> <span class="n">blockB</span><span class="p">,</span> <span class="o">&amp;</span><span class="n">C</span><span class="p">[</span><span class="n">j</span> <span class="o">*</span> <span class="n">M</span> <span class="o">+</span> <span class="n">i</span><span class="p">],</span> <span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">blockA_ld</span><span class="p">);</span>
        <span class="p">}</span>
    <span class="p">}</span>
<span class="p">}</span>
</code></pre></div></div>
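<p>The packing routines themselves are not shown here. As a rough illustration of the idea, a zero-padding copy for a block of $\bar{A}$ could look like the sketch below. The name <code class="language-plaintext highlighter-rouge">pack_blockA_sketch</code> and its body are assumptions made for this example; the argument order simply mirrors the call in the snippet above, and the actual routine may differ.</p>

```c
/* Hypothetical sketch of zero-padded packing for a block of A.
   A:      m x K block (m < 16), column-major, leading dimension M
   packed: 16 x K buffer, column-major, leading dimension 16 */
void pack_blockA_sketch(const float* A, float* packed, int m, int M, int K) {
    for (int p = 0; p < K; p++) {
        /* Copy the m valid rows of column p */
        for (int i = 0; i < m; i++) {
            packed[p * 16 + i] = A[p * M + i];
        }
        /* Pad the remaining 16 - m rows with zeros */
        for (int i = m; i < 16; i++) {
            packed[p * 16 + i] = 0.0f;
        }
    }
}
```

<p>Because every column of the packed block is exactly 16 floats long, the kernel can load it with two unmasked 8-float loads and a leading dimension of 16, and the zero rows contribute nothing to the accumulator.</p>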

<p>For further implementation details, see <a href="https://github.com/salykova/sgemm.c/blob/main/tutorial/matmul_pad.h">matmul_pad.h</a>.</p>

<h2 id="7-cache-blocking">7. Cache Blocking</h2>

<p>Let’s revisit the computer’s memory hierarchy. Previously, we focused on the main memory (DRAM) and the CPU registers, but we skipped an important intermediary: the CPU cache system.</p>

<p><img src="/assets/matmul_cpu/mem_system.png" alt="" width="80%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>Unlike DRAM, the CPU cache is an on-chip memory designed to store frequently and/or recently accessed data from the main memory. This helps minimize data transfers between the main memory and CPU registers. Although the cache is much faster than DRAM, it has a limited storage capacity. To optimize data access, modern desktop CPUs use a multi-level cache hierarchy. This typically includes L1, L2, and L3 caches, each offering progressively larger storage but with increasing access times. L1 cache is the fastest and closest to the CPU core.</p>

<p><img src="/assets/matmul_cpu/cpu_arch.png" alt="" /></p>

<p><img src="/assets/matmul_cpu/core_arch.png" alt="" /></p>

<p style="display:block; margin-left:auto; margin-right:auto; text-align: center"><em>Intel Core i9-13900K labelled die shot. Source: <a href="https://www.youtube.com/watch?v=dX9CGRZwD-w">How are Microchips Made?</a></em></p>

<p>To improve access speed, CPUs transfer data between main memory and cache in fixed-size chunks called <strong>cache lines</strong> or <strong>cache blocks</strong>. When a cache line is loaded from main memory, it is stored as a cache entry. For example, in AMD Ryzen Zen CPUs, the cache line size is <a href="https://en.wikichip.org/wiki/amd/microarchitectures/zen_4#Memory_Hierarchy">64 bytes</a>. The cache takes advantage of data locality - how programs typically access memory. When a single floating-point number is requested from a contiguous array in memory, the cache doesn’t just fetch that one value; it loads the whole cache line that contains it, so the neighboring floating-point numbers arrive in the cache as well. This is why reading data sequentially from an array is much more efficient than randomly accessing scattered memory locations. When the CPU needs to read or write to a memory location, it first checks if the data is already in the cache. This leads to two possible scenarios:</p>

<ol>
  <li><strong>Cache Hit</strong> - If the requested memory location is found in the cache, the CPU can access it instantly, avoiding the need to fetch data from the much slower DRAM.</li>
  <li><strong>Cache Miss</strong> - If the requested data is not in the cache, the CPU retrieves it from the main memory and stores it in the cache for future access.</li>
</ol>
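<p>To make the effect of cache lines concrete, here is a minimal sketch that counts how many distinct cache lines a read pattern touches. It assumes a 64-byte line, a 4-byte <code>float</code> and an array that starts on a cache-line boundary; the function name <code>cache_lines_touched</code> is made up for this illustration and is not part of the repository.</p>

```c
#include <assert.h>
#include <stddef.h>

#define CACHE_LINE_BYTES 64 /* assumed line size, as quoted above for Zen CPUs */

/* Count how many distinct cache lines are touched when reading `count`
 * floats at indices 0, stride, 2*stride, ... (stride in elements).
 * Assumes the array starts on a cache-line boundary. */
size_t cache_lines_touched(size_t count, size_t stride) {
    assert(stride > 0);
    const size_t floats_per_line = CACHE_LINE_BYTES / sizeof(float);
    size_t lines = 0;
    size_t last_line = (size_t)-1; /* sentinel: no line visited yet */
    for (size_t n = 0; n < count; n++) {
        const size_t line = (n * stride) / floats_per_line;
        if (line != last_line) { /* indices only grow, so lines never repeat */
            lines++;
            last_line = line;
        }
    }
    return lines;
}
```

<p>Under these assumptions, reading 1024 consecutive floats touches 64 cache lines, while reading 1024 floats with a stride of 16 elements touches 1024 lines, one cache miss candidate per element.</p>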

<p>Since the cache has limited space, it must decide which data to replace when new information needs to be stored. This decision is governed by a <a href="https://en.wikipedia.org/wiki/Cache_replacement_policies">cache replacement policy</a>. Some of the most common policies include:</p>

<ol>
  <li><strong>LRU</strong> (Least Recently Used): Replaces the cache entry that has gone unused the longest.</li>
  <li><strong>LFU</strong> (Least Frequently Used): Evicts the entry that has been accessed the least often.</li>
  <li><strong>LFRU</strong> (Least Frequently Recently Used): A hybrid approach that considers both recent and overall access frequency.</li>
</ol>
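<p>As an illustration of the LRU policy, the following toy model keeps a handful of cache entries and evicts the one that was accessed longest ago. It is only a sketch: real CPUs typically use set-associative caches with approximate LRU variants, and the names <code>lru_cache</code>, <code>lru_init</code>, <code>lru_access</code> and the size <code>WAYS</code> are hypothetical.</p>

```c
#include <assert.h>

#define WAYS 4 /* hypothetical number of cache entries */

typedef struct {
    long tags[WAYS];          /* memory line held by each entry, -1 = empty */
    unsigned long used[WAYS]; /* logical time of the last access */
    unsigned long clock;      /* incremented on every access */
} lru_cache;

void lru_init(lru_cache* c) {
    for (int w = 0; w < WAYS; w++) {
        c->tags[w] = -1;
        c->used[w] = 0;
    }
    c->clock = 0;
}

/* Returns 1 on a cache hit and 0 on a miss. On a miss the least recently
 * used entry (or an empty one) is replaced by the requested line. */
int lru_access(lru_cache* c, long tag) {
    assert(tag >= 0);
    c->clock++;
    int victim = 0;
    for (int w = 0; w < WAYS; w++) {
        if (c->tags[w] == tag) {
            c->used[w] = c->clock;
            return 1;
        }
        if (c->used[w] < c->used[victim]) victim = w;
    }
    c->tags[victim] = tag;
    c->used[victim] = c->clock;
    return 0;
}
```

<p>For example, after accessing lines 0, 1, 2, 3 and then 0 again, a request for line 4 evicts line 1, because line 0 was refreshed by the repeated access.</p>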

<p>Similar to registers, once data is loaded into the cache, we want to reuse the data as much as possible to reduce main memory accesses. Given the cache’s limited capacity, storing entire input matrices $C, B, A$  in the cache isn’t feasible. Instead, we divide them into smaller blocks, load these blocks into the cache, and reuse them for rank-1 updates. This technique is often referred to as <strong>tiling</strong> or <strong>cache blocking</strong>, allowing us to handle matrices of arbitrary size effectively.</p>

<p>The single-threaded matrix multiplication with cache blocking can be visualized as shown in the image borrowed from the official <a href="https://github.com/flame/blis/blob/master/docs/Multithreading.md">BLIS repository</a>:</p>

<p><img src="/assets/matmul_cpu/blis_design.png" alt="" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>Let’s step through the diagram and discuss it.
In the outer-most loop (5th loop) we iterate over dimension $N$, dividing matrix $C$ into blocks $C_j$ of size $M \times n_c$  and matrix $B$  into blocks $B_j$ of size $K \times n_c$. The subscript $c$ in $n_c$ stands for <em>cache</em>.
In the 4th loop we iterate over dimension $K$ and divide matrix $A$ into $A_p$ of size $M \times k_c$  and $B_j$ into $B_p$ of size $k_c \times n_c$. Notice that $B_p$ has a fixed, limited size and can now be loaded into the cache. $B_p$ is packed into $\tilde{B}_p$, padded with zeros if necessary, and loaded into the L3 cache.
In the 3rd loop we iterate over dimension $M$ and divide $C_j$ into $C_i$ (there is a typo in the diagram) of size $m_c \times n_c$ and $A_p$  into $A_j$ of size $m_c \times k_c$. Matrix $A_j$ is now restricted in size and can be loaded entirely into the L2 cache. $A_j$ is packed into $\tilde{A}_j$ and padded with zeros if needed. Note how we reuse the same $\tilde{B}_p$ block from the L3 cache for different $A_j$ blocks. Both $m_c$ and $n_c$ are chosen to be multiples of $m_R$ and $n_R$, respectively.</p>

<p>In the last two loops we simply iterate over cached blocks and divide them into $m_R \times k_c$ and $k_c \times n_R$ panels. These panels are then passed to the kernel to perform rank-1 updates on the $m_R \times n_R$ sub-matrix of $C$, similarly to what we have already done in the previous chapter. Each panel of $\tilde{B}_p$ is loaded into the L1 cache and reused for multiple panels of $\tilde{A}_j$.
Keep in mind that $\tilde{A}_j$ and $\tilde{B}_p$ are packed differently. During rank-1 updates we sequentially read a panel of $\tilde{A}_j$ column by column and a panel of $\tilde{B}_p$ row by row. Thus,  each panel inside $\tilde{A}_j$ is stored in column-major order, while each panel inside $\tilde{B}_p$ is stored in row-major order.</p>
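<p>The packing layout of $\tilde{A}_j$ described above can be sketched as follows. This is an illustrative version, not the code from the repository: the name <code>pack_blockA_sketch</code> is hypothetical, $A$ is assumed to be stored in column-major order with leading dimension <code>M</code>, and <code>MR</code> is set to 16 to match the 16x6 kernel.</p>

```c
#include <assert.h>

#define MR 16 /* micro-kernel height used in this post */

/* Pack an mc x kc block of column-major A (leading dimension M) into
 * consecutive MR x kc panels. Inside a panel the MR elements of one
 * column are adjacent, so the kernel can read the panel sequentially.
 * Rows past mc in the last panel are filled with zeros. */
void pack_blockA_sketch(const float* A, float* packed, int mc, int kc, int M) {
    assert(mc > 0 && kc > 0 && M >= mc);
    for (int i = 0; i < mc; i += MR) {
        const int mr = (mc - i < MR) ? mc - i : MR; /* rows left in this panel */
        for (int p = 0; p < kc; p++) {
            for (int r = 0; r < MR; r++) {
                *packed++ = (r < mr) ? A[p * M + i + r] : 0.0f;
            }
        }
    }
}
```

<p>With this layout, the panel that starts at row <code>ir</code> begins at offset <code>ir * kc</code> in the packed buffer, since every panel occupies <code>MR * kc</code> floats. The output buffer must hold <code>kc</code> times <code>mc</code> rounded up to a multiple of <code>MR</code> floats.</p>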

<p>Different CPU models have different cache sizes. To achieve peak performance, it’s crucial to tune three key parameters: the cache block sizes $k_c$, $m_c$, and $n_c$, which are tied to the L1, L2, and L3 caches respectively. Theoretically, these parameters should be chosen so that:</p>

<ul>
  <li>$k_c \times n_c$ fills the entire L3 cache.</li>
  <li>$m_c \times k_c$ fills the entire L2 cache.</li>
  <li>$k_c \times n_R$ fills the entire L1 cache.</li>
</ul>
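<p>The three rules above translate into simple arithmetic. The sketch below derives starting values from cache capacities given in bytes, assuming FP32 elements and the 16x6 kernel; the struct <code>block_sizes</code> and the function <code>initial_block_sizes</code> are hypothetical helpers, and the cache sizes you pass in have to be looked up for your own CPU.</p>

```c
#include <assert.h>
#include <stddef.h>

#define MR 16
#define NR 6

typedef struct {
    int kc, mc, nc;
} block_sizes;

/* Derive starting values for kc, mc, nc from cache capacities in bytes. */
block_sizes initial_block_sizes(size_t l1_bytes, size_t l2_bytes, size_t l3_bytes) {
    assert(l1_bytes >= NR * sizeof(float));
    block_sizes bs;
    bs.kc = (int)(l1_bytes / (NR * sizeof(float)));            /* kc x nR fills L1 */
    bs.mc = (int)(l2_bytes / ((size_t)bs.kc * sizeof(float))); /* mc x kc fills L2 */
    bs.nc = (int)(l3_bytes / ((size_t)bs.kc * sizeof(float))); /* kc x nc fills L3 */
    bs.mc -= bs.mc % MR; /* round down to a multiple of mR */
    bs.nc -= bs.nc % NR; /* round down to a multiple of nR */
    return bs;
}
```

<p>For example, with 48 KiB of L1, 1 MiB of L2 and 32 MiB of L3 (example values only), this gives $k_c = 2048$, $m_c = 128$ and $n_c = 4092$. Treat these numbers as a first guess to tune from, not as the optimum.</p>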

<p>While these values provide a good starting point, using larger values often leads to better performance in practice. Unfortunately (or fortunately), we cannot manually place data into the cache or control which cache levels store the data; the CPU manages this automatically using cache replacement policies. Therefore, cache blocking and cache reuse must be implemented at the algorithm level through, for example, well-designed loops and strategic data access patterns.</p>

<p>The implementation <a href="https://github.com/salykova/sgemm.c/blob/main/tutorial/matmul_cache.h">matmul_cache.h</a> straightforwardly follows the algorithm depicted in the diagram:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">void</span> <span class="nf">matmul_cache</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">A</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">B</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">C</span><span class="p">,</span> <span class="k">const</span> <span class="kt">int</span> <span class="n">M</span><span class="p">,</span> <span class="k">const</span> <span class="kt">int</span> <span class="n">N</span><span class="p">,</span> <span class="k">const</span> <span class="kt">int</span> <span class="n">K</span><span class="p">)</span> <span class="p">{</span>
  <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">j</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">;</span> <span class="n">j</span> <span class="o">+=</span> <span class="n">NC</span><span class="p">)</span> <span class="p">{</span>
    <span class="k">const</span> <span class="kt">int</span> <span class="n">nc</span> <span class="o">=</span> <span class="n">min</span><span class="p">(</span><span class="n">NC</span><span class="p">,</span> <span class="n">N</span> <span class="o">-</span> <span class="n">j</span><span class="p">);</span>
    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">p</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">p</span> <span class="o">&lt;</span> <span class="n">K</span><span class="p">;</span> <span class="n">p</span> <span class="o">+=</span> <span class="n">KC</span><span class="p">)</span> <span class="p">{</span>
      <span class="k">const</span> <span class="kt">int</span> <span class="n">kc</span> <span class="o">=</span> <span class="n">min</span><span class="p">(</span><span class="n">KC</span><span class="p">,</span> <span class="n">K</span> <span class="o">-</span> <span class="n">p</span><span class="p">);</span>
      <span class="n">pack_blockB</span><span class="p">(</span><span class="o">&amp;</span><span class="n">B</span><span class="p">[</span><span class="n">j</span> <span class="o">*</span> <span class="n">K</span> <span class="o">+</span> <span class="n">p</span><span class="p">],</span> <span class="n">blockB_packed</span><span class="p">,</span> <span class="n">nc</span><span class="p">,</span> <span class="n">kc</span><span class="p">,</span> <span class="n">K</span><span class="p">);</span>
      <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">M</span><span class="p">;</span> <span class="n">i</span> <span class="o">+=</span> <span class="n">MC</span><span class="p">)</span> <span class="p">{</span>
        <span class="k">const</span> <span class="kt">int</span> <span class="n">mc</span> <span class="o">=</span> <span class="n">min</span><span class="p">(</span><span class="n">MC</span><span class="p">,</span> <span class="n">M</span> <span class="o">-</span> <span class="n">i</span><span class="p">);</span>
        <span class="n">pack_blockA</span><span class="p">(</span><span class="o">&amp;</span><span class="n">A</span><span class="p">[</span><span class="n">p</span> <span class="o">*</span> <span class="n">M</span> <span class="o">+</span> <span class="n">i</span><span class="p">],</span> <span class="n">blockA_packed</span><span class="p">,</span> <span class="n">mc</span><span class="p">,</span> <span class="n">kc</span><span class="p">,</span> <span class="n">M</span><span class="p">);</span>
        <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">jr</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">jr</span> <span class="o">&lt;</span> <span class="n">nc</span><span class="p">;</span> <span class="n">jr</span> <span class="o">+=</span> <span class="n">NR</span><span class="p">)</span> <span class="p">{</span>
          <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">ir</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">ir</span> <span class="o">&lt;</span> <span class="n">mc</span><span class="p">;</span> <span class="n">ir</span> <span class="o">+=</span> <span class="n">MR</span><span class="p">)</span> <span class="p">{</span>
            <span class="k">const</span> <span class="kt">int</span> <span class="n">mr</span> <span class="o">=</span> <span class="n">min</span><span class="p">(</span><span class="n">MR</span><span class="p">,</span> <span class="n">mc</span> <span class="o">-</span> <span class="n">ir</span><span class="p">);</span>
            <span class="k">const</span> <span class="kt">int</span> <span class="n">nr</span> <span class="o">=</span> <span class="n">min</span><span class="p">(</span><span class="n">NR</span><span class="p">,</span> <span class="n">nc</span> <span class="o">-</span> <span class="n">jr</span><span class="p">);</span>
            <span class="n">kernel_16x6</span><span class="p">(</span><span class="o">&amp;</span><span class="n">blockA_packed</span><span class="p">[</span><span class="n">ir</span> <span class="o">*</span> <span class="n">kc</span><span class="p">],</span> <span class="o">&amp;</span><span class="n">blockB_packed</span><span class="p">[</span><span class="n">jr</span> <span class="o">*</span> <span class="n">kc</span><span class="p">],</span> <span class="o">&amp;</span><span class="n">C</span><span class="p">[(</span><span class="n">j</span> <span class="o">+</span> <span class="n">jr</span><span class="p">)</span> <span class="o">*</span> <span class="n">M</span> <span class="o">+</span> <span class="p">(</span><span class="n">i</span> <span class="o">+</span> <span class="n">ir</span><span class="p">)],</span> <span class="n">mr</span><span class="p">,</span> <span class="n">nr</span><span class="p">,</span> <span class="n">kc</span><span class="p">,</span> <span class="n">M</span><span class="p">);</span>
          <span class="p">}</span>
        <span class="p">}</span>
      <span class="p">}</span>
    <span class="p">}</span>
  <span class="p">}</span>
<span class="p">}</span>
</code></pre></div></div>

<h2 id="8-kernel-micro-optimizations">8. Kernel Micro-Optimizations</h2>

<p>Instead of using arrays of <code class="language-plaintext highlighter-rouge">__m256</code> to define the accumulator $\bar{C}$ and the masks</p>
<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">__m256</span> <span class="n">C_buffer</span><span class="p">[</span><span class="mi">6</span><span class="p">][</span><span class="mi">2</span><span class="p">];</span>
<span class="n">__m256i</span> <span class="n">masks</span><span class="p">[</span><span class="mi">2</span><span class="p">];</span>
</code></pre></div></div>
<p>we explicitly unroll them</p>
<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code>    <span class="n">__m256</span> <span class="n">C00</span> <span class="o">=</span> <span class="n">_mm256_setzero_ps</span><span class="p">();</span>
    <span class="n">__m256</span> <span class="n">C10</span> <span class="o">=</span> <span class="n">_mm256_setzero_ps</span><span class="p">();</span>
    <span class="n">__m256</span> <span class="n">C01</span> <span class="o">=</span> <span class="n">_mm256_setzero_ps</span><span class="p">();</span>
    <span class="n">__m256</span> <span class="n">C11</span> <span class="o">=</span> <span class="n">_mm256_setzero_ps</span><span class="p">();</span>
    <span class="n">__m256</span> <span class="n">C02</span> <span class="o">=</span> <span class="n">_mm256_setzero_ps</span><span class="p">();</span>
    <span class="n">__m256</span> <span class="n">C12</span> <span class="o">=</span> <span class="n">_mm256_setzero_ps</span><span class="p">();</span>
    <span class="n">__m256</span> <span class="n">C03</span> <span class="o">=</span> <span class="n">_mm256_setzero_ps</span><span class="p">();</span>
    <span class="n">__m256</span> <span class="n">C13</span> <span class="o">=</span> <span class="n">_mm256_setzero_ps</span><span class="p">();</span>
    <span class="n">__m256</span> <span class="n">C04</span> <span class="o">=</span> <span class="n">_mm256_setzero_ps</span><span class="p">();</span>
    <span class="n">__m256</span> <span class="n">C14</span> <span class="o">=</span> <span class="n">_mm256_setzero_ps</span><span class="p">();</span>
    <span class="n">__m256</span> <span class="n">C05</span> <span class="o">=</span> <span class="n">_mm256_setzero_ps</span><span class="p">();</span>
    <span class="n">__m256</span> <span class="n">C15</span> <span class="o">=</span> <span class="n">_mm256_setzero_ps</span><span class="p">();</span>
    <span class="n">__m256i</span> <span class="n">packed_mask0</span><span class="p">;</span>
    <span class="n">__m256i</span> <span class="n">packed_mask1</span><span class="p">;</span>
</code></pre></div></div>
<p>With explicit variables, GCC can optimize the code better and avoid register spilling. Additionally, we use vector instructions to calculate the masks as follows:</p>
<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">static</span> <span class="kt">int8_t</span> <span class="n">mask</span><span class="p">[</span><span class="mi">32</span><span class="p">]</span>
    <span class="n">__attribute__</span><span class="p">((</span><span class="n">aligned</span><span class="p">(</span><span class="mi">64</span><span class="p">)))</span> <span class="o">=</span> <span class="p">{</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>
                                    <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">};</span>
<span class="n">packed_mask0</span> <span class="o">=</span> <span class="n">_mm256_cvtepi8_epi32</span><span class="p">(</span><span class="n">_mm_loadu_si64</span><span class="p">(</span><span class="o">&amp;</span><span class="n">mask</span><span class="p">[</span><span class="mi">16</span> <span class="o">-</span> <span class="n">mr</span><span class="p">]));</span>
<span class="n">packed_mask1</span> <span class="o">=</span> <span class="n">_mm256_cvtepi8_epi32</span><span class="p">(</span><span class="n">_mm_loadu_si64</span><span class="p">(</span><span class="o">&amp;</span><span class="n">mask</span><span class="p">[</span><span class="mi">16</span> <span class="o">-</span> <span class="n">mr</span> <span class="o">+</span> <span class="mi">8</span><span class="p">]));</span>
</code></pre></div></div>
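<p>To see why this works, it helps to model the table lookup without intrinsics. The scalar sketch below mirrors the indexing used above: the two mask registers together cover 16 lanes, and lane <code>l</code> reads the byte at <code>mask[16 - mr + l]</code>. The function <code>lane_is_active</code> is a hypothetical helper written only for this explanation.</p>

```c
#include <assert.h>
#include <stdint.h>

/* Same layout as the table above: 16 bytes of -1 followed by 16 bytes of 0. */
static const int8_t mask_table[32] = {
    -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
     0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0};

/* Scalar model of the two vector loads: lane `lane` (0..15) of the
 * 16-row tile is active when its table byte is -1, i.e. when lane < mr. */
int lane_is_active(int mr, int lane) {
    assert(mr >= 0 && mr <= 16 && lane >= 0 && lane < 16);
    return mask_table[16 - mr + lane] == -1;
}
```

<p>Shifting the start of the load by <code>16 - mr</code> bytes slides exactly <code>mr</code> bytes of <code>-1</code> into the low lanes, so a single table replaces a per-lane comparison loop.</p>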

<p>The corresponding implementation can be found at <a href="https://github.com/salykova/sgemm.c/blob/main/tutorial/matmul_micro.h">matmul_micro.h</a></p>

<h2 id="9-multithreading">9. Multithreading</h2>

<p>There are many loops that could be parallelized. To achieve high performance, we want to parallelize both the packing and the arithmetic operations. Let’s start with the arithmetic operations. The 5th, 4th, and 3rd loops around the micro-kernel iterate over the matrix dimensions in chunks of the cache block sizes $n_c$, $k_c$, $m_c$. To parallelize these loops efficiently and keep all threads busy, the number of iterations (matrix dimension / cache block size) should be at least the number of threads (generally, the more the better). In other words, the input matrix dimension should be at least number of threads $\times$ cache block size. As we discussed earlier, we also want cache blocks to fully occupy the corresponding cache levels. On modern CPUs, this second requirement results in cache block sizes of a thousand or more elements. For example, on my Ryzen 7 9700X, cache block sizes of $n_c=1535$, $m_c=1024$ attain the best performance in the single-threaded scenario. Given the 8 cores of the Ryzen 7 9700X, we need input matrices with dimensions of at least $\max(m_c, n_c) \times \text{number of cores} = 1535 \times 8 = 12280$ to be able to distribute the work over all cores.</p>

<p><img src="/assets/matmul_cpu/blis_design.png" alt="" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>In contrast, the last two loops iterate over the cache blocks, dividing them into $m_R \times k_c$ and $k_c \times n_R$ panels. Since $n_R, m_R$ are typically very small (&lt;20), these loops run for many iterations even on moderately sized matrices, which makes them ideal candidates for parallelization. Moreover, we can choose $m_c, n_c$ to be multiples of the number of cores so that the work is evenly distributed across all cores.</p>
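<p>A quick calculation shows the difference in available parallelism. The helpers below are hypothetical and only count loop iterations; they use the single-threaded block sizes quoted above ($n_c=1535$, $m_c=1024$) and the 16x6 kernel as example inputs.</p>

```c
#include <assert.h>

/* Number of loop iterations when a dimension of size `dim` is traversed
 * in steps of `block` (the last block may be partial). */
int num_blocks(int dim, int block) {
    assert(dim >= 0 && block > 0);
    return (dim + block - 1) / block;
}

/* Independent work items when the jr and ir loops are parallelized jointly. */
int inner_work_items(int mc, int nc, int mr, int nr) {
    return num_blocks(nc, nr) * num_blocks(mc, mr);
}
```

<p>For a matrix with $N = 4000$, the 5th loop has only <code>num_blocks(4000, 1535) = 3</code> iterations, fewer than the 8 cores. The two inner loops over a single $1024 \times 1535$ block, on the other hand, provide <code>inner_work_items(1024, 1535, 16, 6) = 64 * 256 = 16384</code> independent kernel calls.</p>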

<p>On my machine, parallelizing the second and first inner loops jointly with <code class="language-plaintext highlighter-rouge">collapse(2)</code> results in the best performance:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="cp">#pragma omp parallel for collapse(2) num_threads(NTHREADS)
</span>  <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">jr</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">jr</span> <span class="o">&lt;</span> <span class="n">nc</span><span class="p">;</span> <span class="n">jr</span> <span class="o">+=</span> <span class="n">NR</span><span class="p">)</span>
</code></pre></div></div>

<p>More on OpenMP <a href="https://ppc.cs.aalto.fi/ch2/openmp/">here</a>, <a href="https://ppc.cs.aalto.fi/ch3/">here</a> and <a href="https://curc.readthedocs.io/en/latest/programming/OpenMP-C.html">here</a>.</p>

<blockquote>
  <p>For many-core processors (&gt; 16 cores), consider utilizing nested parallelism and parallelizing 2-3 loops to increase the performance.</p>
</blockquote>

<p>Together with arithmetic operations, we will also parallelize the packing of both $\tilde{A}$ and $\tilde{B}$:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">void</span> <span class="n">pack_blockA</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">A</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">blockA_packed</span><span class="p">,</span> <span class="k">const</span> <span class="kt">int</span> <span class="n">mc</span><span class="p">,</span> <span class="k">const</span> <span class="kt">int</span> <span class="n">kc</span><span class="p">,</span> <span class="k">const</span> <span class="kt">int</span> <span class="n">M</span><span class="p">)</span>
<span class="cp">#pragma omp parallel for num_threads(NTHREADS)
</span>  <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">mc</span><span class="p">;</span> <span class="n">i</span> <span class="o">+=</span> <span class="n">MR</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">void</span> <span class="n">pack_blockB</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">B</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">blockB_packed</span><span class="p">,</span> <span class="k">const</span> <span class="kt">int</span> <span class="n">nc</span><span class="p">,</span> <span class="k">const</span> <span class="kt">int</span> <span class="n">kc</span><span class="p">,</span> <span class="k">const</span> <span class="kt">int</span> <span class="n">K</span><span class="p">)</span>
<span class="cp">#pragma omp parallel for num_threads(NTHREADS)
</span>  <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">j</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="n">nc</span><span class="p">;</span> <span class="n">j</span> <span class="o">+=</span> <span class="n">NR</span><span class="p">)</span>
</code></pre></div></div>

<p>Similar to the second loop (and the first loop) around the micro-kernel, the packing loops can be efficiently parallelized due to the high number of iterations and the flexibility of choosing  $m_c, n_c$. For the multi-threaded implementation the values</p>

\[m_c = m_R \times \text{number of threads} \times 5\]

\[n_c = n_R \times \text{number of threads} \times 50\]

<p>provide the best performance on my machine, leading to the final optimized multi-threaded implementation <a href="https://github.com/salykova/sgemm.c/blob/main/tutorial/matmul_parallel.h">matmul_parallel.h</a></p>]]></content><author><name>Amanzhol Salykov</name></author><summary type="html"><![CDATA[A detailed blog post on optimizing multi-threaded matrix multiplication for x86 processors to achieve OpenBLAS/MKL-like performance. Tags: High-performance GEMM on CPU, Fast GEMM on CPU, High-performance matrix multiplication on CPU, Fast Matrix Multiplication on CPU, Matrix multiplication in C, GEMM in C, Matrix multiplication acceleration.]]></summary></entry><entry><title type="html">Advanced Matrix Multiplication Optimization on Modern Multi-Core Processors</title><link href="https://salykova.github.io/gemm-cpu" rel="alternate" type="text/html" title="Advanced Matrix Multiplication Optimization on Modern Multi-Core Processors" /><published>2024-08-01T09:00:01+00:00</published><updated>2024-08-01T09:00:01+00:00</updated><id>https://salykova.github.io/gemm-cpu</id><content type="html" xml:base="https://salykova.github.io/gemm-cpu"><![CDATA[<p><strong>TL;DR</strong> The code is available at <a href="https://github.com/salykova/sgemm.c">sgemm.c</a>. This blog post walks through optimizing multi-threaded FP32 matrix multiplication on modern processors using FMA3 and AVX2 vector instructions. The implementation delivers strong performance on a variety of x86-64 CPUs, both in single-threaded and multithreaded scenarios. However, to reach peak performance, you’ll need to fine-tune hyperparameters - such as the <em>number of threads, kernel size, and tile sizes</em>. Additionally, on AVX-512 CPUs, the BLAS libraries might be notably faster due to AVX-512 instructions, which were intentionally omitted here to support a broader range of processors. Performance results for Intel Core Ultra 265 and AMD Ryzen 7 9700X are shown below.</p>

<p><strong>P.S. Please feel free to get in touch if you are interested in collaborating. My contact information is available on the homepage.</strong></p>

<p><img src="/assets/matmul_cpu/intel_core_perf.png" alt="" width="80%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p><img src="/assets/matmul_cpu/amd_ryzen_perf.png" alt="" width="80%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<h2 id="1-introduction">1. Introduction</h2>

<p>Matrix multiplication is an essential part of nearly all modern neural networks. Despite using matmul daily in PyTorch, NumPy, or JAX, I’ve never really thought about how it is designed and implemented internally to take full advantage of hardware capabilities. NumPy, for instance, relies on external <a href="https://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms">BLAS</a> (Basic Linear Algebra Subprograms) libraries. These libraries contain high-performance, optimized implementations of common linear algebra operations, such as the dot product, matrix multiplication, vector addition, and scalar multiplication. Examples of BLAS libraries include:</p>

<ol>
  <li><a href="https://en.wikipedia.org/wiki/Math_Kernel_Library">Intel MKL</a> - optimized for Intel CPUs</li>
  <li><a href="https://developer.apple.com/documentation/accelerate">Accelerate</a> - optimized for Apple CPUs</li>
  <li><a href="https://en.wikipedia.org/wiki/BLIS_(software)">BLIS</a> - open-source, multi-vendor, BLAS-like Library Instantiation Software</li>
  <li><a href="https://en.wikipedia.org/wiki/GotoBLAS">GotoBLAS</a> - open-source, multi-vendor</li>
  <li><a href="https://en.wikipedia.org/wiki/OpenBLAS">OpenBLAS</a> - open-source, multi-vendor, fork of GotoBLAS</li>
</ol>

<p>A closer look at the OpenBLAS <a href="https://github.com/OpenMathLib/OpenBLAS/blob/develop/kernel/x86_64/sgemm_kernel_8x4_haswell.c">code</a> reveals a mix of C and low-level assembly. In fact, OpenBLAS, GotoBLAS, and BLIS are written in C/FORTRAN/Assembly and contain matmul implementations manually optimized for different CPU microarchitectures.
My goal was to implement the matrix multiplication in pure C (without low-level assembly code) so that it works for any matrix size, runs on all modern x86-64 processors, and competes with existing BLAS libraries. At the same time I wanted to keep the code simple and easy to extend. After some research, I found a few great step-by-step tutorials on implementing fast matrix multiplication from scratch, covering both theory and practice:</p>

<ol>
  <li><a href="https://siboehm.com/articles/22/Fast-MMM-on-CPU">Fast Multidimensional Matrix Multiplication on CPU from Scratch</a> by Simon Boehm.</li>
  <li><a href="https://en.algorithmica.org/hpc/algorithms/matmul/">Matrix Multiplication</a> by Sergey Slotin.</li>
  <li><a href="https://en.wikipedia.org/wiki/George_Hotz">Geohot’s</a> stream <a href="https://www.youtube.com/watch?v=VgSQ1GOC86s">Can you multiply a matrix?</a></li>
</ol>

<p>I highly recommend these clear and well-explained tutorials with alternative implementations. They helped me better understand the topic and, in some sense, motivated me to write my own implementation. The reason is that all three solutions above work only for specific matrix sizes and do not reach the performance of BLAS libraries. Unsatisfied with these results, I kept researching and came across two fascinating papers: “<a href="https://www.cs.utexas.edu/~flame/pubs/GotoTOMS_final.pdf">Anatomy of High-Performance Matrix Multiplication</a>” and “<a href="https://www.cs.utexas.edu/~flame/pubs/blis3_ipdps14.pdf">Anatomy of High-Performance Many-Threaded Matrix Multiplication</a>”. The first introduces GotoBLAS, a high-performance BLAS implementation by <a href="https://en.wikipedia.org/wiki/Kazushige_Goto">Kazushige Goto</a>. The second reviews the matmul design used in the BLIS library (an extended version of GotoBLAS) and explores different parallelization strategies. Due to its superior high-level design, I had a feeling that the matmul implementation from the BLIS library could outperform existing BLAS implementations even if written in pure C and not manually fine-tuned using inline assembly. In the next chapters we’ll implement the algorithm from scratch, step by step, and compare it against OpenBLAS. Before diving into optimizations, let’s first go over how to install OpenBLAS and properly benchmark the code on a CPU.</p>

<h2 id="2-how-to-install-and-benchmark-openblas">2. How to Install and Benchmark OpenBLAS</h2>

<p>I benchmarked the code on the following machine:</p>

<ul>
  <li>CPU: AMD Ryzen 7 9700X</li>
  <li>RAM: 32GB DDR5 6000 MHz CL36</li>
  <li>OpenBLAS 0.3.26</li>
  <li>Compiler: GCC 13.3</li>
  <li>Compiler flags: <code class="language-plaintext highlighter-rouge">-O3 -march=native -mno-avx512f -fopenmp</code></li>
  <li>OS: Ubuntu 24.04.1 LTS</li>
</ul>

<p><strong>Important!</strong> To obtain reproducible and accurate results, minimize the number of active tasks, particularly when benchmarking multi-threaded code. Windows systems generally deliver lower performance compared to Linux due to a higher number of active background tasks.</p>

<p>To benchmark OpenBLAS, start by installing it according to the <a href="https://github.com/OpenMathLib/OpenBLAS/wiki/Installation-Guide">installation guide</a>. During installation, make sure to set an appropriate <code class="language-plaintext highlighter-rouge">TARGET</code> and disable AVX512 instructions for a fair comparison. For Zen4/5 processors compile OpenBLAS with:</p>

<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code>make <span class="nv">TARGET</span><span class="o">=</span>ZEN
</code></pre></div></div>

<p>Otherwise, OpenBLAS defaults to AVX512 instructions. After installation, you can run FP32 matmul using the OpenBLAS API:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="cp">#include</span> <span class="cpf">&lt;cblas.h&gt;</span><span class="cp">
</span><span class="n">cblas_sgemm</span><span class="p">(</span><span class="n">CblasColMajor</span><span class="p">,</span> <span class="n">CblasNoTrans</span><span class="p">,</span> <span class="n">CblasNoTrans</span><span class="p">,</span> <span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">A</span><span class="p">,</span> <span class="n">m</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">C</span><span class="p">,</span> <span class="n">m</span><span class="p">);</span>
</code></pre></div></div>

<p>The benchmark evaluates the custom implementation and the OpenBLAS API on square matrices, ranging from <code class="language-plaintext highlighter-rouge">m=n=k=200</code> to <code class="language-plaintext highlighter-rouge">m=n=k=10000</code> in steps of <code class="language-plaintext highlighter-rouge">200</code>. To obtain consistent and accurate results, matrix multiplication is repeated <code class="language-plaintext highlighter-rouge">n_iter</code> times, and performance is measured as median execution time.</p>

<p>To multiply two <code class="language-plaintext highlighter-rouge">float32</code> matrices - $A$ of size $M \times K$ and $B$ of size $K \times N$, for each element of the resulting matrix $C$ of size $M \times N$, we need to compute the dot product between a row of $A$ and a column of $B$. This requires $K$ (additions) + $K$ (multiplications) = $2K$ Floating Point Operations (FLOP) per element of $C$ or $2MNK$ FLOP in total. A metric often used to evaluate matmul performance is called FLOP per second or FLOP/s or FLOPS, and it can be derived from the execution time as <code class="language-plaintext highlighter-rouge">FLOPS=FLOP/exec_time=(2*m*n*k)/exec_time</code>.</p>
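<p>The measurement described above can be sketched in a few lines of C. This is only an illustration: the helper names <code class="language-plaintext highlighter-rouge">median_time</code> and <code class="language-plaintext highlighter-rouge">gflops</code> are hypothetical and are not taken from the benchmark code; the timings are assumed to be collected in seconds.</p>

```c
#include <stdlib.h>

// Comparison callback for qsort: ascending order of doubles.
static int cmp_double(const void* a, const void* b) {
  double x = *(const double*)a;
  double y = *(const double*)b;
  return (x > y) - (x < y);
}

// Median of n timing samples (in seconds). Sorts the array in place.
double median_time(double* times, int n) {
  qsort(times, (size_t)n, sizeof(double), cmp_double);
  if (n % 2 == 1) return times[n / 2];
  return 0.5 * (times[n / 2 - 1] + times[n / 2]);
}

// FLOPS = (2*m*n*k) / exec_time, scaled to GFLOPS.
// The leading 2.0 promotes the product to double, so it cannot overflow int.
double gflops(int m, int n, int k, double exec_time_s) {
  return 2.0 * m * n * k / (exec_time_s * 1e9);
}
```

<p>For instance, a <code class="language-plaintext highlighter-rouge">m=n=k=1000</code> multiplication with a median time of 2 seconds would correspond to 1 GFLOPS.</p>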

<p><img src="/assets/matmul_cpu/matmul_naive.png" alt="" style="display:block; margin-left:auto; margin-right:auto" /></p>

<h2 id="3-theoretical-limit">3. Theoretical Limit</h2>

<p>The image below shows a simplified model of the computer’s memory hierarchy (for now, ignore the layers between the registers and the main memory (RAM); we will discuss them later).</p>

<p><img src="/assets/matmul_cpu/mem_system_nc.png" alt="" width="80%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>To perform arithmetic operations on data stored in RAM (off-chip memory, slow and large capacity), the data must first be transferred to the CPU and placed in CPU registers (on-chip memory, fast and small capacity). Modern x86-64 CPUs support SIMD (Single Instruction Multiple Data) extensions, which allow multiple pieces of data to be processed in parallel. There are various SIMD extensions, but the ones relevant to our discussion are <a href="https://en.wikipedia.org/wiki/Advanced_Vector_Extensions">Advanced Vector Extensions</a> (AVX2) and <a href="https://en.wikipedia.org/wiki/FMA_instruction_set">Fused Multiply-Add</a> (FMA). Both AVX2 and FMA operate on data stored in special 256-bit <code class="language-plaintext highlighter-rouge">YMM</code> registers. Each <code class="language-plaintext highlighter-rouge">YMM</code> register can hold 8 packed single-precision (32-bit) floats. The FMA instructions perform an element-wise multiply-add operation on data stored in the <code class="language-plaintext highlighter-rouge">YMM</code> registers. The corresponding assembly instruction is called <code class="language-plaintext highlighter-rouge">VFMADD231PS</code> (PS stands for PackedSingle) and takes three vector registers (<code class="language-plaintext highlighter-rouge">YMM1</code>, <code class="language-plaintext highlighter-rouge">YMM2</code>, <code class="language-plaintext highlighter-rouge">YMM3</code>) as input to compute <code class="language-plaintext highlighter-rouge">YMM1 = YMM2 * YMM3 + YMM1</code>.</p>

<p><img src="/assets/matmul_cpu/fmadd.png" alt="" width="60%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>According to the <a href="https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html">Intel Intrinsics Guide</a> or <a href="https://uops.info/table.html">https://uops.info/table.html</a>, the throughput (TP) of the fused multiply-add instruction on my CPU is 0.5 cycles/instruction or, in other words, 2 instructions/cycle:
<img src="/assets/matmul_cpu/fmadd_uops.png" alt="" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>Theoretically, the Ryzen 7 9700X can perform 32 FLOP per cycle per core: 8 (floats in a <code class="language-plaintext highlighter-rouge">YMM</code> register) * 2 (add + mul) * 2 (1/TP). Therefore, the theoretical peak FLOPS in single-threaded mode can be roughly estimated as <code class="language-plaintext highlighter-rouge">CPU_CLOCK_SPEED * 32</code>, or <code class="language-plaintext highlighter-rouge">n_cores * CPU_CLOCK_SPEED * 32</code> in multi-threaded mode. For example, assuming a sustainable clock speed of 4.7 GHz for an 8-core 9700X processor, the theoretical peak in a multi-threaded setting would be 8 * 4.7 * 32 ≈ 1203 GFLOPS.</p>
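<p>The same estimate can be written as a small helper, which makes it easy to plug in the numbers for a different CPU. This is a rough sketch: the function name <code class="language-plaintext highlighter-rouge">peak_gflops</code> and its parameters are made up for illustration, and the result is only an upper bound under the assumed sustained clock speed.</p>

```c
// Rough theoretical peak in GFLOPS.
//   n_cores        - number of physical cores used
//   clock_ghz      - assumed sustained clock speed in GHz
//   floats_per_reg - floats per vector register (8 for a 256-bit YMM register)
//   fma_per_cycle  - FMA instructions issued per cycle (1/TP, 2 in the text)
double peak_gflops(int n_cores, double clock_ghz, int floats_per_reg,
                   int fma_per_cycle) {
  // Each FMA lane performs one multiplication and one addition.
  int flop_per_cycle = floats_per_reg * 2 * fma_per_cycle;
  return n_cores * clock_ghz * flop_per_cycle;
}
```

<p>With the values used above, <code class="language-plaintext highlighter-rouge">peak_gflops(8, 4.7, 8, 2)</code> evaluates to roughly 1203 GFLOPS, and <code class="language-plaintext highlighter-rouge">peak_gflops(1, 4.7, 8, 2)</code> to roughly 150 GFLOPS for a single thread.</p>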

<h2 id="4-naive-implementation">4. Naive Implementation</h2>

<p>In this tutorial we assume that matrices are stored in column-major order: e.g. matrix <code class="language-plaintext highlighter-rouge">A</code> of shape <code class="language-plaintext highlighter-rouge">MxN</code> is stored as a contiguous array of length <code class="language-plaintext highlighter-rouge">M*N</code> and an element <code class="language-plaintext highlighter-rouge">A[row][col]</code> is accessed via a raw C pointer as <code class="language-plaintext highlighter-rouge">ptr[col*M + row]</code>, where <code class="language-plaintext highlighter-rouge">0 &lt;= col &lt;= N-1</code> and <code class="language-plaintext highlighter-rouge">0 &lt;= row &lt;= M-1</code>.
<img src="/assets/matmul_cpu/mem_layout.png" alt="" width="80%" style="display:block; margin-left:auto; margin-right:auto" /></p>
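<p>As a small illustration of this layout, the index computation can be wrapped in a helper. The name <code class="language-plaintext highlighter-rouge">idx_colmajor</code> is hypothetical and is not used in the rest of the tutorial, which writes the index expression inline.</p>

```c
#include <stddef.h>

// Offset of element (row, col) in a column-major matrix with M rows.
// Elements of one column are adjacent in memory; consecutive columns
// are M elements apart.
static inline size_t idx_colmajor(int row, int col, int M) {
  return (size_t)col * (size_t)M + (size_t)row;
}
```

<p>For a $3 \times 2$ matrix stored as <code class="language-plaintext highlighter-rouge">{1, 2, 3, 4, 5, 6}</code>, the first column is <code class="language-plaintext highlighter-rouge">(1, 2, 3)</code> and the second is <code class="language-plaintext highlighter-rouge">(4, 5, 6)</code>, so <code class="language-plaintext highlighter-rouge">ptr[idx_colmajor(0, 1, 3)]</code> reads the value 4.</p>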

<p>The simplest implementation of $C=AB$ can be described as follows:
<img src="/assets/matmul_cpu/matmul_naive.png" alt="" style="display:block; margin-left:auto; margin-right:auto" /></p>
<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">void</span> <span class="nf">matmul_naive</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">A</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">B</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">C</span><span class="p">,</span> <span class="k">const</span> <span class="kt">int</span> <span class="n">M</span><span class="p">,</span> <span class="k">const</span> <span class="kt">int</span> <span class="n">N</span><span class="p">,</span> <span class="k">const</span> <span class="kt">int</span> <span class="n">K</span><span class="p">)</span> <span class="p">{</span>
  <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">M</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">j</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">;</span> <span class="n">j</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
      <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">p</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">p</span> <span class="o">&lt;</span> <span class="n">K</span><span class="p">;</span> <span class="n">p</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
        <span class="n">C</span><span class="p">[</span><span class="n">j</span> <span class="o">*</span> <span class="n">M</span> <span class="o">+</span> <span class="n">i</span><span class="p">]</span> <span class="o">+=</span> <span class="n">A</span><span class="p">[</span><span class="n">p</span> <span class="o">*</span> <span class="n">M</span> <span class="o">+</span> <span class="n">i</span><span class="p">]</span> <span class="o">*</span> <span class="n">B</span><span class="p">[</span><span class="n">j</span> <span class="o">*</span> <span class="n">K</span> <span class="o">+</span> <span class="n">p</span><span class="p">];</span>
      <span class="p">}</span>
    <span class="p">}</span>
  <span class="p">}</span>
<span class="p">}</span>
</code></pre></div></div>
<p>Here, we iterate over all rows (the outermost loop) and all columns (the second loop) of <code class="language-plaintext highlighter-rouge">C</code> and for each element of <code class="language-plaintext highlighter-rouge">C</code> we calculate the dot product (the innermost loop) between the corresponding row of matrix <code class="language-plaintext highlighter-rouge">A</code> and column of matrix <code class="language-plaintext highlighter-rouge">B</code>. It’s always good to start with a simple and robust algorithm that can later be used to test optimized implementations.</p>
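<p>One detail worth keeping in mind when using the naive version as a reference: the <code class="language-plaintext highlighter-rouge">+=</code> accumulation assumes that <code class="language-plaintext highlighter-rouge">C</code> is zero on entry. The sketch below shows one possible way to build a test reference around it. Both <code class="language-plaintext highlighter-rouge">matmul_ref</code> and <code class="language-plaintext highlighter-rouge">max_abs_diff</code> are hypothetical helpers introduced here for illustration.</p>

```c
#include <stddef.h>
#include <string.h>

// Same loop nest as matmul_naive above, but clears C first so the result
// does not depend on the previous contents of C.
void matmul_ref(const float* A, const float* B, float* C, int M, int N, int K) {
  memset(C, 0, sizeof(float) * (size_t)M * (size_t)N);
  for (int i = 0; i < M; i++) {
    for (int j = 0; j < N; j++) {
      for (int p = 0; p < K; p++) {
        C[j * M + i] += A[p * M + i] * B[j * K + p];
      }
    }
  }
}

// Largest absolute element-wise difference between two arrays of length len.
float max_abs_diff(const float* X, const float* Y, int len) {
  float max_diff = 0.0f;
  for (int i = 0; i < len; i++) {
    float d = X[i] - Y[i];
    if (d < 0.0f) d = -d;
    if (d > max_diff) max_diff = d;
  }
  return max_diff;
}
```

<p>An optimized implementation can then be checked by computing the same product with both functions and comparing <code class="language-plaintext highlighter-rouge">max_abs_diff</code> against a small tolerance, since the summation order (and therefore the rounding) generally differs between implementations.</p>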

<h2 id="5-kernel">5. Kernel</h2>

<p>The key idea of high-performance matrix multiplication on CPU is to develop a function that efficiently computes a sub-matrix of $C$. Then, by iterating over $C$ and applying this function to all non-overlapping sub-matrices, we can significantly speed up the entire matrix multiplication operation. To do this, we first partition the matrix $C$ of shape $M \times N$ into smaller non-overlapping sub-matrices of shape $m_R \times n_R$, with $n_R \ll N$ and $m_R \ll M$. To calculate $C=AB$, we iterate over $C$ and compute each of its non-overlapping $m_R \times n_R$ sub-matrices as shown below:</p>

<p><img src="/assets/matmul_cpu/matmul_kernel.png" alt="" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>The function that computes an $m_R \times n_R$ sub-matrix $\bar{C}$ of $C$ is called a <strong>kernel</strong> (aka. <strong>micro-kernel</strong> using BLIS notation). This function is the core of high-performance matrix multiplication. When we say a matrix multiplication algorithm is optimized for a specific CPU architecture, it usually refers to kernel optimization. For example, OpenBLAS contains <a href="https://github.com/OpenMathLib/OpenBLAS/tree/develop/kernel">kernels</a> optimized for different CPU microarchitectures.</p>

<p>Let’s take a closer look at the kernel. To compute an $m_R \times n_R$ sub-matrix $\bar{C}$ of $C$, we need to multiply corresponding $m_R \times K$ sub-matrix $\bar{A}$ of $A$ with $K \times n_R$ sub-matrix $\bar{B}$ of $B$ as shown in the figure below:</p>

<p><img src="/assets/matmul_cpu/kernel.png" alt="" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>If we were to do this in a naive manner using the dot product, we would need to fetch $2K$ elements from RAM to calculate a single element of $\bar{C}$ or $2K m_R n_R$ elements in total to compute $\bar{C}$. There is, however, an alternative strategy that can reduce the number of fetched elements.</p>

<p>First, we initialize the matrix $\bar{C}$ with zeros and store it in registers. Since both $n_R$ and $m_R$ are small, the entire matrix fits within the registers. Here, the subscript $R$ in $n_R$ and $m_R$ denotes “register”. Next, we iterate over the dimension $K$, and in each iteration, we:</p>

<ol>
  <li>load 1 column of $\bar{A}$ and 1 row of $\bar{B}$ from RAM into the registers. Again, note that both the row and column vectors are limited in size and can be stored in the registers.</li>
  <li>compute the outer product between the two vectors and add the result of the outer product to the matrix $\bar{C}$.</li>
</ol>

<p><img src="/assets/matmul_cpu/kernel_rank.png" alt="" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>After $K$ iterations, the computation of the matrix $\bar{C}$ is completed and it can be stored into RAM. $\bar{C}$ is often referred to as the <em>accumulator</em>, because it accumulates the outer products along the dimension $K$. A single accumulation step of the outer product between two vectors is also known as <strong>rank-1 update</strong>.</p>

<blockquote>
  <p>Outer product between a column vector and a row vector.
<img src="/assets/matmul_cpu/outer_product.png" alt="" style="display:block; margin-left:auto; margin-right:auto" /></p>
</blockquote>
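<p>Before moving to SIMD, the rank-1 update scheme can be sketched in plain scalar C. This is only an illustration of the loop structure, not the kernel used later: the name <code class="language-plaintext highlighter-rouge">kernel_rank1_scalar</code> and the leading-dimension parameters <code class="language-plaintext highlighter-rouge">lda</code> and <code class="language-plaintext highlighter-rouge">ldb</code> are assumptions made for this example, and the accumulator is a plain array rather than registers.</p>

```c
// Scalar sketch of computing C_bar = A_bar * B_bar via K rank-1 updates.
//   A_bar: mr x K, column-major, leading dimension lda
//   B_bar: K x nr, column-major, leading dimension ldb
//   C_bar: mr x nr accumulator, column-major, stored contiguously
void kernel_rank1_scalar(const float* A_bar, const float* B_bar, float* C_bar,
                         int mr, int nr, int K, int lda, int ldb) {
  // Initialize the accumulator with zeros.
  for (int i = 0; i < mr * nr; i++) C_bar[i] = 0.0f;

  for (int p = 0; p < K; p++) {
    // p-th column of A_bar: mr contiguous elements.
    const float* a_col = &A_bar[p * lda];
    for (int j = 0; j < nr; j++) {
      // j-th element of the p-th row of B_bar.
      float b = B_bar[j * ldb + p];
      // Add (column of A_bar) * b to the j-th column of the accumulator.
      for (int i = 0; i < mr; i++) {
        C_bar[j * mr + i] += a_col[i] * b;
      }
    }
  }
}
```

<p>Each iteration of the outer loop reads <code class="language-plaintext highlighter-rouge">mr</code> elements of $\bar{A}$ and <code class="language-plaintext highlighter-rouge">nr</code> elements of $\bar{B}$, and performs <code class="language-plaintext highlighter-rouge">mr * nr</code> multiply-adds with them.</p>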

<p>In total, we fetch $(m_R + n_R)K$ elements from RAM into registers. Compared to the naive approach, this reduces the number of memory accesses by a factor of</p>

\[\frac{2m_Rn_RK}{(m_R + n_R)K} = \frac{2m_Rn_R}{m_R + n_R}\]

<p>This factor is maximized when both $m_R$, $n_R$ are large and equal. However, the values of $m_R$ and $n_R$ are typically constrained by the available register memory.</p>
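<p>As a quick worked example (the two kernel shapes are picked only for illustration), the reduction factor for a $16 \times 6$ and an $8 \times 8$ kernel evaluates to</p>

\[\frac{2 \cdot 16 \cdot 6}{16 + 6} = \frac{192}{22} \approx 8.7, \qquad \frac{2 \cdot 8 \cdot 8}{8 + 8} = \frac{128}{16} = 8\]

<p>so a larger non-square kernel can still reuse loaded data better than a smaller square one.</p>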

<p>Now, let’s discuss in detail how the outer product and accumulation can be efficiently implemented using SIMD FMA instructions. Unfortunately, there are no SIMD instructions that compute the outer product in a single step. Therefore, we need to decompose the outer product into simpler operations. The figure below illustrates the process:</p>

<p><img src="/assets/matmul_cpu/kernel_registers.png" alt="" width="80%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>Here, we compute the outer product between a column vector of size $m_R$ and a row vector of size $n_R$ to update an accumulator $\bar{C}$ of size $m_R \times n_R$. The accumulator is stored in the <code class="language-plaintext highlighter-rouge">YMM</code> registers, with each column of the accumulator spanning one or multiple <code class="language-plaintext highlighter-rouge">YMM</code> registers. The column vector is also stored in the <code class="language-plaintext highlighter-rouge">YMM</code> registers (highlighted as yellow). Since each <code class="language-plaintext highlighter-rouge">YMM</code> register holds 8 floats, the dimension $m_R$ must be divisible by 8. The accumulator is updated column by column. During the first iteration we broadcast the first element of the row vector to a vector of size $m_R$ and place it in the <code class="language-plaintext highlighter-rouge">YMM</code> registers (highlighted as green). Then, we element-wise multiply the column vector with the broadcasted vector and accumulate the result to the first column of the accumulator $\bar{C}$ using FMA instruction. We repeat this process for the remaining elements of the row vector to update the corresponding columns of the accumulator. After $n_R$ iterations, the rank-1 update of the accumulator is completed.</p>

<p>The last thing we need to discuss before implementing the kernel in C is how to choose the kernel size i.e. $m_R$ and $n_R$. CPUs with AVX support have <strong>16 YMM registers</strong>. From our previous discussion, we know that we need $(m_R/8) \cdot n_R$ registers to store the accumulator $\bar{C}$, $m_R/8$ registers to store the column vector and 1 register (because we can reuse the same register for all FMA operations) for the broadcasted vector. We want $m_R$ and $n_R$ to be as large as possible while satisfying the following conditions:</p>

<ul>
  <li>$\Big(\cfrac{m_R}{8} \cdot n_R + \cfrac{m_R}{8} + 1\Big) \leq 16$</li>
  <li>$m_R$ is a multiple of 8</li>
</ul>
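<p>The two conditions are easy to check programmatically when experimenting with kernel shapes. The helpers below are a minimal sketch of that check; the names <code class="language-plaintext highlighter-rouge">ymm_regs_needed</code> and <code class="language-plaintext highlighter-rouge">kernel_fits</code> are hypothetical. They only count registers under the scheme described above and say nothing about which shape is actually fastest.</p>

```c
// Number of YMM registers an mr x nr kernel needs under the scheme above:
// (mr/8)*nr for the accumulator, mr/8 for the column vector and
// 1 for the broadcasted element. Returns -1 if mr is not a multiple of 8.
int ymm_regs_needed(int mr, int nr) {
  if (mr <= 0 || nr <= 0 || mr % 8 != 0) return -1;
  int col_regs = mr / 8;
  return col_regs * nr + col_regs + 1;
}

// Returns 1 if the kernel fits into the 16 available YMM registers, else 0.
int kernel_fits(int mr, int nr) {
  int regs = ymm_regs_needed(mr, nr);
  return regs > 0 && regs <= 16;
}
```

<p>For example, a $16 \times 6$ kernel needs 15 registers and an $8 \times 14$ kernel needs exactly 16, while $16 \times 7$ and $8 \times 15$ would each need 17 and therefore do not fit.</p>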

<p>In theory we want $m_R = n_R$ to minimize the number of fetched elements. However, in practice, a non-square kernel with $m_R = 16, n_R = 6$ showed the best performance on my CPU. Therefore, we will implement this kernel in the next section. Feel free to experiment with other kernel sizes, such as $8 \times 8, 8 \times 12$, $8 \times 13$, $8 \times 14$, $32 \times 2$ and compare their performance on your CPU.</p>

<p>Let’s implement the algorithm discussed above using the $16 \times 6$ kernel. The code of this implementation can be found at <a href="https://github.com/salykova/sgemm.c/blob/main/tutorial/matmul_kernel.h">matmul_kernel.c</a>. To use SIMD intrinsics in C we first need to include the <code class="language-plaintext highlighter-rouge">immintrin.h</code> header:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="cp">#include</span> <span class="cpf">&lt;immintrin.h&gt;</span><span class="cp">
</span></code></pre></div></div>

<p>The implementation of the algorithm is straightforward: we iterate over matrix $C$ and apply the kernel function to each of its non-overlapping $16 \times 6$ sub-matrices $\bar{C}$.</p>
<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">void</span> <span class="nf">matmul_kernel</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">A</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">B</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">C</span><span class="p">,</span> <span class="k">const</span> <span class="kt">int</span> <span class="n">M</span><span class="p">,</span> <span class="k">const</span> <span class="kt">int</span> <span class="n">N</span><span class="p">,</span> <span class="k">const</span> <span class="kt">int</span> <span class="n">K</span><span class="p">)</span> <span class="p">{</span>
  <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">M</span><span class="p">;</span> <span class="n">i</span> <span class="o">+=</span> <span class="mi">16</span><span class="p">)</span> <span class="p">{</span>
    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">j</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">;</span> <span class="n">j</span> <span class="o">+=</span> <span class="mi">6</span><span class="p">)</span> <span class="p">{</span>
        <span class="n">kernel_16x6</span><span class="p">(</span><span class="o">&amp;</span><span class="n">A</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="o">&amp;</span><span class="n">B</span><span class="p">[</span><span class="n">j</span> <span class="o">*</span> <span class="n">K</span><span class="p">],</span> <span class="o">&amp;</span><span class="n">C</span><span class="p">[</span><span class="n">j</span> <span class="o">*</span> <span class="n">M</span> <span class="o">+</span> <span class="n">i</span><span class="p">],</span> <span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">K</span><span class="p">);</span>
    <span class="p">}</span>
  <span class="p">}</span>
<span class="p">}</span>
</code></pre></div></div>

<p>The kernel function is declared as follows:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">void</span> <span class="nf">kernel_16x6</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">A_start</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">B_start</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">C_start</span><span class="p">,</span> <span class="kt">int</span> <span class="n">M</span><span class="p">,</span> <span class="kt">int</span> <span class="n">N</span><span class="p">,</span> <span class="kt">int</span> <span class="n">K</span><span class="p">);</span>
</code></pre></div></div>

<p>The function takes as input pointers to the starting positions of $\bar{A}, \bar{B}$, and $\bar{C}$ along with the matrix problem size. It then computes $16 \times 6$ sub-matrix $\bar{C}$ of $C$ according to $\bar{C} = \bar{A} \bar{B}$.</p>

<p>Inside the kernel function, first, we declare the variables stored in the <code class="language-plaintext highlighter-rouge">YMM</code> registers:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">__m256</span> <span class="n">C_accum</span><span class="p">[</span><span class="mi">6</span><span class="p">][</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="p">{};</span> <span class="c1">// zero-initialized</span>
<span class="n">__m256</span> <span class="n">b_packFloat8</span><span class="p">;</span>
<span class="n">__m256</span> <span class="n">a0_packFloat8</span><span class="p">;</span>
<span class="n">__m256</span> <span class="n">a1_packFloat8</span><span class="p">;</span>
</code></pre></div></div>

<p>A variable of type <code class="language-plaintext highlighter-rouge">__m256</code> is a 256-bit vector that represents the contents of a <code class="language-plaintext highlighter-rouge">YMM</code> register, which holds eight 32-bit floating-point values. <code class="language-plaintext highlighter-rouge">C_accum</code> is the accumulator stored in the <code class="language-plaintext highlighter-rouge">YMM</code> registers. The variable <code class="language-plaintext highlighter-rouge">b_packFloat8</code> contains a broadcasted element from a row vector of $\bar{B}$, while <code class="language-plaintext highlighter-rouge">a0_packFloat8</code> and <code class="language-plaintext highlighter-rouge">a1_packFloat8</code> represent a column vector of $\bar{A}$. Since the column vector contains 16 floats, it requires two <code class="language-plaintext highlighter-rouge">YMM</code> registers for storage.</p>

<p>SIMD intrinsics are well documented and can be found in the <a href="https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html">Intel Intrinsics Guide</a>. For example, here is the entry for <a href="https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_loadu_ps&amp;ig_expand=4100">_mm256_loadu_ps</a>:</p>

<p><img src="/assets/matmul_cpu/mm256_loadu.png" alt="" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>The kernel iterates over the dimension $K$ and in each iteration performs a rank-1 update of the accumulator:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">p</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">p</span> <span class="o">&lt;</span> <span class="n">K</span><span class="p">;</span> <span class="n">p</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
  <span class="c1">// Load column vector of size 16</span>
  <span class="c1">// {</span>
  <span class="n">a0_packFloat8</span> <span class="o">=</span> <span class="n">_mm256_loadu_ps</span><span class="p">(</span><span class="o">&amp;</span><span class="n">A_start</span><span class="p">[</span><span class="n">p</span> <span class="o">*</span> <span class="n">M</span><span class="p">]);</span>
  <span class="n">a1_packFloat8</span> <span class="o">=</span> <span class="n">_mm256_loadu_ps</span><span class="p">(</span><span class="o">&amp;</span><span class="n">A_start</span><span class="p">[</span><span class="n">p</span> <span class="o">*</span> <span class="n">M</span> <span class="o">+</span> <span class="mi">8</span><span class="p">]);</span>
  <span class="c1">// }</span>
  <span class="c1">// Broadcast scalar element to vector of size 8</span>
  <span class="c1">// {</span>
  <span class="n">b_packFloat8</span> <span class="o">=</span> <span class="n">_mm256_broadcast_ss</span><span class="p">(</span><span class="o">&amp;</span><span class="n">B_start</span><span class="p">[</span><span class="n">p</span><span class="p">]);</span>
  <span class="c1">// }</span>
  <span class="c1">// Update the first column of the accumulator</span>
  <span class="c1">// {</span>
  <span class="n">C_accum</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">_mm256_fmadd_ps</span><span class="p">(</span><span class="n">a0_packFloat8</span><span class="p">,</span> <span class="n">b_packFloat8</span><span class="p">,</span> <span class="n">C_accum</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]);</span>
  <span class="n">C_accum</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">_mm256_fmadd_ps</span><span class="p">(</span><span class="n">a1_packFloat8</span><span class="p">,</span> <span class="n">b_packFloat8</span><span class="p">,</span> <span class="n">C_accum</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">1</span><span class="p">]);</span>
  <span class="c1">// }</span>
  <span class="p">...</span>
  <span class="p">...</span>
  <span class="p">...</span>
  <span class="n">b_packFloat8</span> <span class="o">=</span> <span class="n">_mm256_broadcast_ss</span><span class="p">(</span><span class="o">&amp;</span><span class="n">B_start</span><span class="p">[</span><span class="mi">5</span> <span class="o">*</span> <span class="n">K</span> <span class="o">+</span> <span class="n">p</span><span class="p">]);</span>
  <span class="c1">// update the last column of the accumulator</span>
  <span class="c1">// {</span>
  <span class="n">C_accum</span><span class="p">[</span><span class="mi">5</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">_mm256_fmadd_ps</span><span class="p">(</span><span class="n">a0_packFloat8</span><span class="p">,</span> <span class="n">b_packFloat8</span><span class="p">,</span> <span class="n">C_accum</span><span class="p">[</span><span class="mi">5</span><span class="p">][</span><span class="mi">0</span><span class="p">]);</span>
  <span class="n">C_accum</span><span class="p">[</span><span class="mi">5</span><span class="p">][</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">_mm256_fmadd_ps</span><span class="p">(</span><span class="n">a1_packFloat8</span><span class="p">,</span> <span class="n">b_packFloat8</span><span class="p">,</span> <span class="n">C_accum</span><span class="p">[</span><span class="mi">5</span><span class="p">][</span><span class="mi">1</span><span class="p">]);</span>
  <span class="c1">// }</span>
<span class="p">}</span>
</code></pre></div></div>

<p>After $K$ rank-1 updates, the computation of the accumulator is complete, and the result can be stored in RAM:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1">// Store the accumulator column by column:</span>
<span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">j</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="mi">6</span><span class="p">;</span> <span class="n">j</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
  <span class="n">_mm256_storeu_ps</span><span class="p">(</span><span class="o">&amp;</span><span class="n">C_start</span><span class="p">[</span><span class="n">j</span> <span class="o">*</span> <span class="n">M</span><span class="p">],</span> <span class="n">C_accum</span><span class="p">[</span><span class="n">j</span><span class="p">][</span><span class="mi">0</span><span class="p">]);</span>
  <span class="n">_mm256_storeu_ps</span><span class="p">(</span><span class="o">&amp;</span><span class="n">C_start</span><span class="p">[</span><span class="n">j</span> <span class="o">*</span> <span class="n">M</span> <span class="o">+</span> <span class="mi">8</span><span class="p">],</span> <span class="n">C_accum</span><span class="p">[</span><span class="n">j</span><span class="p">][</span><span class="mi">1</span><span class="p">]);</span>
<span class="p">}</span>
</code></pre></div></div>

<p>Let’s take a look at the generated assembly code to see if it actually contains SIMD FMA instructions and uses the <code class="language-plaintext highlighter-rouge">YMM</code> registers:</p>

<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code>gcc <span class="nt">-O3</span> <span class="nt">-mno-avx512f</span> <span class="nt">-march</span><span class="o">=</span>native matmul_kernel.c <span class="nt">-S</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>// matmul_kernel.s
...
vfmadd231ps	%ymm14, %ymm1, %ymm13
vfmadd231ps	%ymm14, %ymm0, %ymm12
vmovaps	%ymm13, 32(%rsp)
vmovaps	%ymm12, 64(%rsp)
vbroadcastss	(%rax,%r9), %ymm14
vfmadd231ps	%ymm14, %ymm1, %ymm10
vfmadd231ps	%ymm14, %ymm0, %ymm11
vmovaps	%ymm10, 96(%rsp)
vmovaps	%ymm11, 128(%rsp)
vbroadcastss	(%rax,%r9,2), %ymm14
addq	$4, %rax
vfmadd231ps	%ymm14, %ymm1, %ymm2
vfmadd231ps	%ymm14, %ymm0, %ymm3
...
</code></pre></div></div>

<h2 id="6-padding">6. Padding</h2>

<p>You may have noticed that the current implementation only works for matrix sizes where $M$ and $N$ are multiples of $m_R$ and $n_R$, respectively. Specifically, the kernel assumes that matrix $\bar{C}$ has dimensions $m_R \times n_R$, matrix $\bar{A}$ is $m_R \times K$ and matrix $\bar{B}$ is $K \times n_R$. Our goal is to generalize the kernel so that it can handle matrices $\bar{C}, \bar{A}, \bar{B}$ with dimensions $m \times n, m \times K, K \times n$, even when $m \neq m_R$ and $n \neq n_R$, as shown below:</p>

<p><img src="/assets/matmul_cpu/kernel_mask.png" alt="" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>First, when storing the accumulator, we need to ensure that elements are only stored within the matrix boundaries. If the number of columns that lie within the boundary of $C$, $n$, is smaller than $n_R$, the process is straightforward: we simply iterate over $n$ columns instead of $n_R$:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1">// n - number of overlapped columns within C boundary</span>

<span class="c1">// "j &lt; n" instead of "j &lt; 6", since n can be less than 6.</span>
<span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">j</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="n">n</span><span class="p">;</span> <span class="n">j</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
  <span class="n">_mm256_storeu_ps</span><span class="p">(</span><span class="o">&amp;</span><span class="n">C_start</span><span class="p">[</span><span class="n">j</span> <span class="o">*</span> <span class="n">M</span><span class="p">],</span> <span class="n">C_accum</span><span class="p">[</span><span class="n">j</span><span class="p">][</span><span class="mi">0</span><span class="p">]);</span>
  <span class="n">_mm256_storeu_ps</span><span class="p">(</span><span class="o">&amp;</span><span class="n">C_start</span><span class="p">[</span><span class="n">j</span> <span class="o">*</span> <span class="n">M</span> <span class="o">+</span> <span class="mi">8</span><span class="p">],</span> <span class="n">C_accum</span><span class="p">[</span><span class="n">j</span><span class="p">][</span><span class="mi">1</span><span class="p">]);</span>
<span class="p">}</span>
</code></pre></div></div>

<p>The case where the number of overlapped rows $m$ differs from $m_R$ is a bit trickier because <code class="language-plaintext highlighter-rouge">_mm256_storeu_ps</code> stores 8 elements at once. Fortunately, the <code class="language-plaintext highlighter-rouge">immintrin.h</code> header provides the <code class="language-plaintext highlighter-rouge">_mm256_maskstore_ps</code> function, which stores packed floats according to mask values. The function takes <a href="https://www.intel.com/content/www/us/en/docs/cpp-compiler/developer-guide-reference/2021-10/mm256-maskstore-ps-mm-maskstore-ps.html">three arguments</a>:</p>

<ol>
  <li><code class="language-plaintext highlighter-rouge">float *a</code></li>
  <li><code class="language-plaintext highlighter-rouge">__m256i mask</code></li>
  <li><code class="language-plaintext highlighter-rouge">__m256 b</code></li>
</ol>

<p><code class="language-plaintext highlighter-rouge">__m256i</code> is a vector datatype that holds eight 32-bit integers. Each integer in <code class="language-plaintext highlighter-rouge">mask</code> corresponds to a data element in <code class="language-plaintext highlighter-rouge">b</code>. The most significant bit (MSB) of each integer in <code class="language-plaintext highlighter-rouge">mask</code> represents the mask bit. If the mask bit is zero, the corresponding value in <code class="language-plaintext highlighter-rouge">b</code> is not stored in the memory location pointed to by <code class="language-plaintext highlighter-rouge">a</code>. For example, the MSB of unsigned integer <code class="language-plaintext highlighter-rouge">2147483648</code> (binary format <code class="language-plaintext highlighter-rouge">10000000 00000000 00000000 00000000</code>) is <code class="language-plaintext highlighter-rouge">1</code>, so the corresponding data element in <code class="language-plaintext highlighter-rouge">b</code> will be stored. On the other hand, the MSB of unsigned integer <code class="language-plaintext highlighter-rouge">2147483647</code> (binary format <code class="language-plaintext highlighter-rouge">01111111 11111111 11111111 11111111</code>) is <code class="language-plaintext highlighter-rouge">0</code>, meaning the corresponding data element in <code class="language-plaintext highlighter-rouge">b</code> will not be stored.</p>

<p>If $m \neq m_R$, we generate integer masks by left-shifting the unsigned integer <code class="language-plaintext highlighter-rouge">65535</code> (=<code class="language-plaintext highlighter-rouge">00000000 00000000 11111111 11111111</code> in binary format) by an amount that depends on the number of overlapped rows $m$ and on the position of the element within the vector. In the code snippet below, the function <code class="language-plaintext highlighter-rouge">_mm256_setr_epi32()</code> creates a <code class="language-plaintext highlighter-rouge">__m256i</code> vector from eight 32-bit integers.</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">__m256i</span> <span class="n">masks</span><span class="p">[</span><span class="mi">2</span><span class="p">];</span>
<span class="k">if</span> <span class="p">(</span><span class="n">m</span> <span class="o">!=</span> <span class="mi">16</span><span class="p">)</span> <span class="p">{</span>
  <span class="k">const</span> <span class="kt">uint32_t</span> <span class="n">bit_mask</span> <span class="o">=</span> <span class="mi">65535</span><span class="p">;</span>
  <span class="n">masks</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">_mm256_setr_epi32</span><span class="p">(</span><span class="n">bit_mask</span> <span class="o">&lt;&lt;</span> <span class="p">(</span><span class="n">m</span> <span class="o">+</span> <span class="mi">15</span><span class="p">),</span>
                               <span class="n">bit_mask</span> <span class="o">&lt;&lt;</span> <span class="p">(</span><span class="n">m</span> <span class="o">+</span> <span class="mi">14</span><span class="p">),</span>
                               <span class="n">bit_mask</span> <span class="o">&lt;&lt;</span> <span class="p">(</span><span class="n">m</span> <span class="o">+</span> <span class="mi">13</span><span class="p">),</span>
                               <span class="n">bit_mask</span> <span class="o">&lt;&lt;</span> <span class="p">(</span><span class="n">m</span> <span class="o">+</span> <span class="mi">12</span><span class="p">),</span>
                               <span class="n">bit_mask</span> <span class="o">&lt;&lt;</span> <span class="p">(</span><span class="n">m</span> <span class="o">+</span> <span class="mi">11</span><span class="p">),</span>
                               <span class="n">bit_mask</span> <span class="o">&lt;&lt;</span> <span class="p">(</span><span class="n">m</span> <span class="o">+</span> <span class="mi">10</span><span class="p">),</span>
                               <span class="n">bit_mask</span> <span class="o">&lt;&lt;</span> <span class="p">(</span><span class="n">m</span> <span class="o">+</span> <span class="mi">9</span><span class="p">),</span>
                               <span class="n">bit_mask</span> <span class="o">&lt;&lt;</span> <span class="p">(</span><span class="n">m</span> <span class="o">+</span> <span class="mi">8</span><span class="p">));</span>
  <span class="n">masks</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">_mm256_setr_epi32</span><span class="p">(</span><span class="n">bit_mask</span> <span class="o">&lt;&lt;</span> <span class="p">(</span><span class="n">m</span> <span class="o">+</span> <span class="mi">7</span><span class="p">),</span>
                               <span class="n">bit_mask</span> <span class="o">&lt;&lt;</span> <span class="p">(</span><span class="n">m</span> <span class="o">+</span> <span class="mi">6</span><span class="p">),</span>
                               <span class="n">bit_mask</span> <span class="o">&lt;&lt;</span> <span class="p">(</span><span class="n">m</span> <span class="o">+</span> <span class="mi">5</span><span class="p">),</span>
                               <span class="n">bit_mask</span> <span class="o">&lt;&lt;</span> <span class="p">(</span><span class="n">m</span> <span class="o">+</span> <span class="mi">4</span><span class="p">),</span>
                               <span class="n">bit_mask</span> <span class="o">&lt;&lt;</span> <span class="p">(</span><span class="n">m</span> <span class="o">+</span> <span class="mi">3</span><span class="p">),</span>
                               <span class="n">bit_mask</span> <span class="o">&lt;&lt;</span> <span class="p">(</span><span class="n">m</span> <span class="o">+</span> <span class="mi">2</span><span class="p">),</span>
                               <span class="n">bit_mask</span> <span class="o">&lt;&lt;</span> <span class="p">(</span><span class="n">m</span> <span class="o">+</span> <span class="mi">1</span><span class="p">),</span>
                               <span class="n">bit_mask</span> <span class="o">&lt;&lt;</span> <span class="n">m</span><span class="p">);</span>
  <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">j</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="n">n</span><span class="p">;</span> <span class="n">j</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
    <span class="n">_mm256_maskstore_ps</span><span class="p">(</span><span class="o">&amp;</span><span class="n">C_start</span><span class="p">[</span><span class="n">j</span> <span class="o">*</span> <span class="n">M</span><span class="p">],</span> <span class="n">masks</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">C_accum</span><span class="p">[</span><span class="n">j</span><span class="p">][</span><span class="mi">0</span><span class="p">]);</span>
    <span class="n">_mm256_maskstore_ps</span><span class="p">(</span><span class="o">&amp;</span><span class="n">C_start</span><span class="p">[</span><span class="n">j</span> <span class="o">*</span> <span class="n">M</span> <span class="o">+</span> <span class="mi">8</span><span class="p">],</span> <span class="n">masks</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">C_accum</span><span class="p">[</span><span class="n">j</span><span class="p">][</span><span class="mi">1</span><span class="p">]);</span>
  <span class="p">}</span>
<span class="p">}</span>
</code></pre></div></div>

<p>The compiler auto-vectorizes the sequential bit-shifting operations using a combination of <code class="language-plaintext highlighter-rouge">vpaddd</code> and <code class="language-plaintext highlighter-rouge">vpsllvd</code> instructions, making the mask computation very efficient. There is, however, an alternative method to compute the masks, as will be shown later.</p>
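<p>To see why shifting <code class="language-plaintext highlighter-rouge">65535</code> yields the right masks, it can help to model the computation in scalar code. The sketch below is an illustration only, not part of the kernel: the helper names <code class="language-plaintext highlighter-rouge">lane_mask</code>, <code class="language-plaintext highlighter-rouge">lane_is_stored</code> and <code class="language-plaintext highlighter-rouge">stored_lanes</code> are hypothetical, and lanes 0 to 15 are assumed to correspond to the 16 rows covered by <code class="language-plaintext highlighter-rouge">masks[0]</code> and <code class="language-plaintext highlighter-rouge">masks[1]</code>.</p>

```c
#include <assert.h>
#include <stdint.h>

// Scalar model of the mask computation: returns the 32-bit mask word for
// lane `lane` (0..15) when only the first `m` rows (0 <= m <= 15) are valid.
static uint32_t lane_mask(int m, int lane) {
    assert(m >= 0 && m <= 15 && lane >= 0 && lane <= 15);
    const uint32_t bit_mask = 65535u;    // low 16 bits set
    return bit_mask << (m + 15 - lane);  // shift amount is at most 30
}

// The masked store only inspects the most significant bit of each word.
static int lane_is_stored(int m, int lane) {
    return (int)(lane_mask(m, lane) >> 31);
}

// Counts how many of the 16 lanes would be written for a given m.
static int stored_lanes(int m) {
    int count = 0;
    for (int lane = 0; lane < 16; lane++) {
        count += lane_is_stored(m, lane);
    }
    return count;
}
```

<p>For lane $i$ the shift amount is $m + 15 - i$, so the MSB ends up set exactly when $i &lt; m$. This is why the first $m$ rows are written and the rest are skipped.</p>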

<p>When loading elements from matrices $\bar{A}$ and $\bar{B}$ inside the kernel, we need to check that the loads are within the matrix boundaries. One way to do this is by using <code class="language-plaintext highlighter-rouge">_mm256_maskload_ps</code> when loading elements from the matrix $\bar{A}$ and looping over $n$ elements instead of $n_R$ when loading elements from the matrix $\bar{B}$. However, this method would significantly degrade the kernel’s performance. The additional instructions required to compute the loading masks introduce overhead, and since $n$ is not a compile-time constant, the compiler cannot unroll the loop efficiently. Instead, if $m \neq m_R$, we copy the matrix $\bar{A}$ into a buffer, pad it with zeros and pass the padded matrix of size $m_R \times K$ to the kernel. We do the same for the matrix $\bar{B}$ if $n \neq n_R$. The implementation straightforwardly follows the description:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="cp">#define BLOCK_A_MAXSIZE 500000
#define BLOCK_B_MAXSIZE 200000
</span>
<span class="k">static</span> <span class="kt">float</span> <span class="n">blockA_buffer</span><span class="p">[</span><span class="n">BLOCK_A_MAXSIZE</span><span class="p">]</span> <span class="n">__attribute__</span><span class="p">((</span><span class="n">aligned</span><span class="p">(</span><span class="mi">64</span><span class="p">)));</span>
<span class="k">static</span> <span class="kt">float</span> <span class="n">blockB_buffer</span><span class="p">[</span><span class="n">BLOCK_B_MAXSIZE</span><span class="p">]</span> <span class="n">__attribute__</span><span class="p">((</span><span class="n">aligned</span><span class="p">(</span><span class="mi">64</span><span class="p">)));</span>

<span class="kt">void</span> <span class="nf">matmul_pack</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">A</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">B</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">C</span><span class="p">,</span> <span class="k">const</span> <span class="kt">int</span> <span class="n">M</span><span class="p">,</span> <span class="k">const</span> <span class="kt">int</span> <span class="n">N</span><span class="p">,</span> <span class="k">const</span> <span class="kt">int</span> <span class="n">K</span><span class="p">)</span> <span class="p">{</span>
    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">M</span><span class="p">;</span> <span class="n">i</span> <span class="o">+=</span> <span class="mi">16</span><span class="p">)</span> <span class="p">{</span>
        <span class="k">const</span> <span class="kt">int</span> <span class="n">m</span> <span class="o">=</span> <span class="n">min</span><span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="n">M</span> <span class="o">-</span> <span class="n">i</span><span class="p">);</span>
        <span class="kt">float</span><span class="o">*</span> <span class="n">blockA</span> <span class="o">=</span> <span class="o">&amp;</span><span class="n">A</span><span class="p">[</span><span class="n">i</span><span class="p">];</span>
        <span class="kt">int</span> <span class="n">blockA_ld</span> <span class="o">=</span> <span class="n">M</span><span class="p">;</span>
        <span class="k">if</span> <span class="p">(</span><span class="n">m</span> <span class="o">!=</span> <span class="mi">16</span><span class="p">)</span> <span class="p">{</span>
            <span class="n">pack_blockA</span><span class="p">(</span><span class="o">&amp;</span><span class="n">A</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">blockA_buffer</span><span class="p">,</span> <span class="n">m</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">K</span><span class="p">);</span>
            <span class="n">blockA</span> <span class="o">=</span> <span class="n">blockA_buffer</span><span class="p">;</span>
            <span class="n">blockA_ld</span> <span class="o">=</span> <span class="mi">16</span><span class="p">;</span>
        <span class="p">}</span>
        <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">j</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">;</span> <span class="n">j</span> <span class="o">+=</span> <span class="mi">6</span><span class="p">)</span> <span class="p">{</span>
            <span class="k">const</span> <span class="kt">int</span> <span class="n">n</span> <span class="o">=</span> <span class="n">min</span><span class="p">(</span><span class="mi">6</span><span class="p">,</span> <span class="n">N</span> <span class="o">-</span> <span class="n">j</span><span class="p">);</span>
            <span class="kt">float</span><span class="o">*</span> <span class="n">blockB</span> <span class="o">=</span> <span class="o">&amp;</span><span class="n">B</span><span class="p">[</span><span class="n">j</span> <span class="o">*</span> <span class="n">K</span><span class="p">];</span>
            <span class="k">if</span> <span class="p">(</span><span class="n">n</span> <span class="o">!=</span> <span class="mi">6</span><span class="p">)</span> <span class="p">{</span>
                <span class="n">pack_blockB</span><span class="p">(</span><span class="o">&amp;</span><span class="n">B</span><span class="p">[</span><span class="n">j</span> <span class="o">*</span> <span class="n">K</span><span class="p">],</span> <span class="n">blockB_buffer</span><span class="p">,</span> <span class="n">n</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">K</span><span class="p">);</span>
                <span class="n">blockB</span> <span class="o">=</span> <span class="n">blockB_buffer</span><span class="p">;</span>
            <span class="p">}</span>
            <span class="n">kernel_16x6</span><span class="p">(</span><span class="n">blockA</span><span class="p">,</span> <span class="n">blockB</span><span class="p">,</span> <span class="o">&amp;</span><span class="n">C</span><span class="p">[</span><span class="n">j</span> <span class="o">*</span> <span class="n">M</span> <span class="o">+</span> <span class="n">i</span><span class="p">],</span> <span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">blockA_ld</span><span class="p">);</span>
        <span class="p">}</span>
    <span class="p">}</span>
<span class="p">}</span>
</code></pre></div></div>

<p>For further implementation details, please check <a href="https://github.com/salykova/sgemm.c/blob/main/tutorial/matmul_pad.h">matmul_pad.h</a>.</p>
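<p>The zero-padding step for $\bar{A}$ can be sketched roughly as follows. This is a minimal illustration under stated assumptions, not the code from <code class="language-plaintext highlighter-rouge">matmul_pad.h</code>: the helper name <code class="language-plaintext highlighter-rouge">pad_panel_A</code> and its parameter order are hypothetical, and the matrix is assumed to be stored in column-major order with leading dimension <code class="language-plaintext highlighter-rouge">lda</code>.</p>

```c
#include <assert.h>
#include <string.h>

#define MR 16  // kernel height, as in the text

// Hypothetical helper (not the author's pack_blockA): copies an m x K
// column-major panel with leading dimension lda into an MR x K buffer and
// fills rows m..MR-1 of every column with zeros.
static void pad_panel_A(const float *A, float *buf, int m, int lda, int K) {
    assert(m >= 0 && m <= MR && lda >= m);
    for (int p = 0; p < K; p++) {
        // Copy the m valid rows of column p.
        memcpy(&buf[p * MR], &A[p * lda], (size_t)m * sizeof(float));
        // Zero the remaining MR - m rows (all-zero bytes are 0.0f in IEEE 754).
        memset(&buf[p * MR + m], 0, (size_t)(MR - m) * sizeof(float));
    }
}
```

<p>Because the padded rows are zero, they contribute nothing to the accumulator, and the masked store ensures they are never written back to $C$. The padded buffer has a leading dimension of 16, which matches the <code class="language-plaintext highlighter-rouge">blockA_ld = 16</code> assignment in the code above.</p>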

<h2 id="7-cache-blocking">7. Cache Blocking</h2>

<p>Let’s revisit the computer’s memory hierarchy. Previously, we focused on the main memory (DRAM) and the CPU registers, but we skipped an important intermediary: the CPU cache system.</p>

<p><img src="/assets/matmul_cpu/mem_system.png" alt="" width="80%" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>Unlike DRAM, the CPU cache is an on-chip memory designed to store frequently and/or recently accessed data from the main memory. This helps minimize data transfers between the main memory and CPU registers. Although the cache is much faster than DRAM, it has a limited storage capacity. To optimize data access, modern desktop CPUs use a multi-level cache hierarchy. This typically includes L1, L2, and L3 caches, each offering progressively larger storage but with increasing access times. L1 cache is the fastest and closest to the CPU core.</p>

<p><img src="/assets/matmul_cpu/cpu_arch.png" alt="" /></p>

<p><img src="/assets/matmul_cpu/core_arch.png" alt="" /></p>

<p style="display:block; margin-left:auto; margin-right:auto; text-align: center"><em>Intel Core i9-13900K labelled die shot. Source: <a href="https://www.youtube.com/watch?v=dX9CGRZwD-w">How are Microchips Made?</a></em></p>

<p>To improve access speed, CPUs transfer data between main memory and cache in fixed-size chunks called <strong>cache lines</strong> or <strong>cache blocks</strong>. When a cache line is loaded from main memory, it is stored as a cache entry. For example, in AMD Ryzen Zen CPUs, the cache line size is <a href="https://en.wikichip.org/wiki/amd/microarchitectures/zen_4#Memory_Hierarchy">64 bytes</a>. The cache takes advantage of data locality - how programs typically access memory. When a single floating-point number is requested from a contiguous array in memory, the cache doesn’t fetch just that one value; it loads the whole cache line containing it, so the neighboring floating-point numbers arrive in the cache as well. This is why reading data sequentially from an array is much more efficient than randomly accessing scattered memory locations. When the CPU needs to read or write to a memory location, it first checks if the data is already in the cache. This leads to two possible scenarios:</p>

<ol>
  <li><strong>Cache Hit</strong> - If the requested memory location is found in the cache, the CPU can access it instantly, avoiding the need to fetch data from the much slower DRAM.</li>
  <li><strong>Cache Miss</strong> - If the requested data is not in the cache, the CPU retrieves it from the main memory and stores it in the cache for future access.</li>
</ol>
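<p>The effect of cache lines on sequential versus strided access can be illustrated with a toy model. The sketch below assumes a 64-byte line, an array that starts on a line boundary, and room for only a single line; the function name <code class="language-plaintext highlighter-rouge">line_loads</code> is hypothetical, and real hardware (with prefetchers and many lines per cache) behaves in a more complicated way.</p>

```c
#include <assert.h>
#include <stddef.h>

#define CACHE_LINE_BYTES 64  // assumed line size, as in the Zen example

// Counts how often an access lands in a different cache line than the
// previous access when touching n floats spaced `stride` elements apart.
// Toy model: room for one line only, no prefetching.
static size_t line_loads(size_t n, size_t stride) {
    const size_t floats_per_line = CACHE_LINE_BYTES / sizeof(float);
    size_t loads = 0;
    size_t current_line = (size_t)-1;  // sentinel: no line loaded yet
    for (size_t i = 0; i < n; i++) {
        const size_t line = (i * stride) / floats_per_line;
        if (line != current_line) {
            current_line = line;
            loads++;
        }
    }
    return loads;
}
```

<p>In this model, reading 1024 consecutive floats touches 64 lines, while reading 1024 floats with a stride of 16 elements touches a new line on every access.</p>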

<p>Since the cache has limited space, it must decide which data to replace when new information needs to be stored. This decision is governed by a <a href="https://en.wikipedia.org/wiki/Cache_replacement_policies">cache replacement policy</a>. Some of the most common policies include:</p>

<ol>
  <li><strong>LRU</strong> (Least Recently Used): Replaces the cache entry that has gone unused the longest.</li>
  <li><strong>LFU</strong> (Least Frequently Used): Evicts the entry that has been accessed the least often.</li>
  <li><strong>LFRU</strong> (Least Frequently Recently Used): A hybrid approach that considers both recent and overall access frequency.</li>
</ol>
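<p>As an illustration of the first policy, here is a minimal sketch of LRU replacement for a tiny, fully associative cache. The capacity of four lines and the names <code class="language-plaintext highlighter-rouge">lru_cache</code> and <code class="language-plaintext highlighter-rouge">lru_access</code> are assumptions made for this example; real CPUs typically use set-associative caches and approximations of LRU.</p>

```c
#include <assert.h>

#define WAYS 4  // hypothetical capacity: four cache lines

// Toy fully associative cache with LRU replacement. `lines` holds the line
// ids ordered from most recently used (index 0) to least recently used.
typedef struct {
    long lines[WAYS];
    int used;
} lru_cache;

// Returns 1 on a hit and 0 on a miss; either way `line` becomes the most
// recently used entry. On a miss with a full cache, the last entry (the
// least recently used one) is dropped.
static int lru_access(lru_cache *c, long line) {
    int pos = 0;
    while (pos < c->used && c->lines[pos] != line) {
        pos++;
    }
    const int hit = pos < c->used;
    if (!hit) {
        if (c->used < WAYS) {
            c->used++;
        }
        pos = c->used - 1;  // slot that gets overwritten
    }
    for (int i = pos; i > 0; i--) {
        c->lines[i] = c->lines[i - 1];  // shift more recent entries back
    }
    c->lines[0] = line;
    return hit;
}
```

<p>Starting from an empty cache, accessing lines 1, 2, 3, 4 gives four misses. Accessing line 1 again is a hit and makes it the most recent entry, so a subsequent access to line 5 evicts line 2 rather than line 1.</p>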

<p>Similar to registers, once data is loaded into the cache, we want to reuse the data as much as possible to reduce main memory accesses. Given the cache’s limited capacity, storing entire input matrices $C, B, A$  in the cache isn’t feasible. Instead, we divide them into smaller blocks, load these blocks into the cache, and reuse them for rank-1 updates. This technique is often referred to as <strong>tiling</strong> or <strong>cache blocking</strong>, allowing us to handle matrices of arbitrary size effectively.</p>

<p>The single-threaded matrix multiplication with cache blocking can be visualized as shown in the image borrowed from the official <a href="https://github.com/flame/blis/blob/master/docs/Multithreading.md">BLIS repository</a>:</p>

<p><img src="/assets/matmul_cpu/blis_design.png" alt="" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>Let’s step through the diagram and discuss it.
In the outer-most loop (5th loop) we iterate over dimension $N$, dividing matrix $C$ into blocks $C_j$ of size $M \times n_c$ and matrix $B$ into blocks $B_j$ of size $K \times n_c$. The subscript $c$ in $n_c$ stands for <em>cache</em>.
In the 4th loop we iterate over dimension $K$ and divide matrix $A$ into $A_p$ of size $M \times k_c$ and $B_j$ into $B_p$ of size $k_c \times n_c$. Notice that $B_p$ has a fixed, limited size and can now be loaded into the cache. $B_p$ is packed into $\tilde{B}_p$, padded with zeros if necessary, and loaded into the L3 cache.
In the 3rd loop we iterate over dimension $M$ and divide $C_j$ into $C_i$ (there is a typo in the diagram) of size $m_c \times n_c$ and $A_p$ into $A_j$ of size $m_c \times k_c$. Matrix $A_j$ is now restricted in size and can be loaded entirely into the L2 cache. $A_j$ is packed into $\tilde{A}_j$ and padded with zeros if needed. Note how we reuse the same $\tilde{B}_p$ block from the L3 cache for different $A_j$ blocks. $m_c$ and $n_c$ are chosen to be multiples of $m_R$ and $n_R$, respectively.</p>

<p>In the last two loops we simply iterate over cached blocks and divide them into $m_R \times k_c$ and $k_c \times n_R$ panels. These panels are then passed to the kernel to perform rank-1 updates on the $m_R \times n_R$ sub-matrix of $C$, similarly to what we have already done in the previous chapter. Each panel of $\tilde{B}_p$ is loaded into the L1 cache and reused for multiple panels of $\tilde{A}_j$.
Keep in mind that $\tilde{A}_j$ and $\tilde{B}_p$ are packed differently. During rank-1 updates we sequentially read a panel of $\tilde{A}_j$ column by column and a panel of $\tilde{B}_p$ row by row. Thus,  each panel inside $\tilde{A}_j$ is stored in column-major order, while each panel inside $\tilde{B}_p$ is stored in row-major order.</p>
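<p>The row-major panel layout of $\tilde{B}_p$ described above can be sketched as follows. This is an illustrative version under assumptions, not the author's <code class="language-plaintext highlighter-rouge">pack_blockB</code>: the name <code class="language-plaintext highlighter-rouge">pack_B_panels</code> is hypothetical, $B$ is assumed to be column-major with leading dimension <code class="language-plaintext highlighter-rouge">ldb</code>, and the destination buffer is assumed to hold $\lceil n_c / n_R \rceil \cdot n_R \cdot k_c$ floats.</p>

```c
#include <assert.h>

#define NR 6  // kernel width, as in the text

// Hypothetical sketch (not the author's pack_blockB): packs a kc x nc block
// of a column-major matrix B (leading dimension ldb) into panels of NR
// columns. Inside a panel, the NR values of one row are adjacent, so the
// kernel can read the panel row by row. Columns past nc are zero-filled.
static void pack_B_panels(const float *B, float *packed, int nc, int kc, int ldb) {
    assert(nc >= 0 && kc >= 0 && ldb >= kc);
    for (int j = 0; j < nc; j += NR) {
        float *panel = &packed[j * kc];  // each panel holds NR * kc floats
        for (int p = 0; p < kc; p++) {
            for (int jj = 0; jj < NR; jj++) {
                const int col = j + jj;
                panel[p * NR + jj] = (col < nc) ? B[col * ldb + p] : 0.0f;
            }
        }
    }
}
```

<p>With this layout, the panel that starts at column <code class="language-plaintext highlighter-rouge">jr</code> begins at offset <code class="language-plaintext highlighter-rouge">jr * kc</code> in the packed buffer, and each rank-1 update reads $n_R$ consecutive floats.</p>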

<p>Different CPU models have different cache sizes. To achieve peak performance, it’s crucial to tune three key parameters: the block sizes associated with the L1, L2, and L3 caches (represented by $k_c$, $m_c$, and $n_c$ respectively). Theoretically, these parameters should be chosen so that:</p>

<ul>
  <li>$k_c \times n_c$ fills the entire L3 cache.</li>
  <li>$m_c \times k_c$ fills the entire L2 cache.</li>
  <li>$k_c \times n_R$ fills the entire L1 cache.</li>
</ul>

<p>While these values provide a good starting point, using larger values often leads to better performance in practice. Unfortunately (or fortunately), we cannot manually place data into the cache or control which cache levels store the data; the CPU manages this automatically using cache replacement policies. Therefore, cache blocking and cache reuse must be implemented at the algorithm level through, for example, well-designed loops and strategic data access patterns.</p>
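<p>As a worked example of the three rules, the sketch below derives starting values from cache capacities for single-precision data. The cache sizes used in the example (32 KiB L1, 1 MiB L2, 32 MiB L3) are hypothetical, as are the names <code class="language-plaintext highlighter-rouge">block_sizes</code> and <code class="language-plaintext highlighter-rouge">initial_block_sizes</code>; rounding $m_c$ and $n_c$ down to multiples of $m_R$ and $n_R$ is one possible choice.</p>

```c
#include <assert.h>
#include <stddef.h>

// Largest multiple of `step` that does not exceed `x`.
static size_t round_down(size_t x, size_t step) {
    assert(step > 0);
    return x - x % step;
}

typedef struct {
    size_t kc, mc, nc;
} block_sizes;

// Starting-point block sizes from cache capacities in bytes, following the
// three rules above for single-precision (4-byte) elements.
static block_sizes initial_block_sizes(size_t l1, size_t l2, size_t l3,
                                       size_t mr, size_t nr) {
    block_sizes b;
    b.kc = l1 / (nr * sizeof(float));                    // kc x nR fills L1
    assert(b.kc > 0);
    b.mc = round_down(l2 / (b.kc * sizeof(float)), mr);  // mc x kc fills L2
    b.nc = round_down(l3 / (b.kc * sizeof(float)), nr);  // kc x nc fills L3
    return b;
}
```

<p>For the assumed cache sizes with $m_R = 16$ and $n_R = 6$, this gives $k_c = 1365$, $m_c = 192$ and $n_c = 6144$. These numbers are only a first guess to be refined by benchmarking.</p>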

<p>The implementation <a href="https://github.com/salykova/sgemm.c/blob/main/tutorial/matmul_cache.h">matmul_cache.h</a> straightforwardly follows the algorithm depicted in the diagram:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">void</span> <span class="nf">matmul_cache</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">A</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">B</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">C</span><span class="p">,</span> <span class="k">const</span> <span class="kt">int</span> <span class="n">M</span><span class="p">,</span> <span class="k">const</span> <span class="kt">int</span> <span class="n">N</span><span class="p">,</span> <span class="k">const</span> <span class="kt">int</span> <span class="n">K</span><span class="p">)</span> <span class="p">{</span>
  <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">j</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">;</span> <span class="n">j</span> <span class="o">+=</span> <span class="n">NC</span><span class="p">)</span> <span class="p">{</span>
    <span class="k">const</span> <span class="kt">int</span> <span class="n">nc</span> <span class="o">=</span> <span class="n">min</span><span class="p">(</span><span class="n">NC</span><span class="p">,</span> <span class="n">N</span> <span class="o">-</span> <span class="n">j</span><span class="p">);</span>
    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">p</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">p</span> <span class="o">&lt;</span> <span class="n">K</span><span class="p">;</span> <span class="n">p</span> <span class="o">+=</span> <span class="n">KC</span><span class="p">)</span> <span class="p">{</span>
      <span class="k">const</span> <span class="kt">int</span> <span class="n">kc</span> <span class="o">=</span> <span class="n">min</span><span class="p">(</span><span class="n">KC</span><span class="p">,</span> <span class="n">K</span> <span class="o">-</span> <span class="n">p</span><span class="p">);</span>
      <span class="n">pack_blockB</span><span class="p">(</span><span class="o">&amp;</span><span class="n">B</span><span class="p">[</span><span class="n">j</span> <span class="o">*</span> <span class="n">K</span> <span class="o">+</span> <span class="n">p</span><span class="p">],</span> <span class="n">blockB_packed</span><span class="p">,</span> <span class="n">nc</span><span class="p">,</span> <span class="n">kc</span><span class="p">,</span> <span class="n">K</span><span class="p">);</span>
      <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">M</span><span class="p">;</span> <span class="n">i</span> <span class="o">+=</span> <span class="n">MC</span><span class="p">)</span> <span class="p">{</span>
        <span class="k">const</span> <span class="kt">int</span> <span class="n">mc</span> <span class="o">=</span> <span class="n">min</span><span class="p">(</span><span class="n">MC</span><span class="p">,</span> <span class="n">M</span> <span class="o">-</span> <span class="n">i</span><span class="p">);</span>
        <span class="n">pack_blockA</span><span class="p">(</span><span class="o">&amp;</span><span class="n">A</span><span class="p">[</span><span class="n">p</span> <span class="o">*</span> <span class="n">M</span> <span class="o">+</span> <span class="n">i</span><span class="p">],</span> <span class="n">blockA_packed</span><span class="p">,</span> <span class="n">mc</span><span class="p">,</span> <span class="n">kc</span><span class="p">,</span> <span class="n">M</span><span class="p">);</span>
        <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">jr</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">jr</span> <span class="o">&lt;</span> <span class="n">nc</span><span class="p">;</span> <span class="n">jr</span> <span class="o">+=</span> <span class="n">NR</span><span class="p">)</span> <span class="p">{</span>
          <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">ir</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">ir</span> <span class="o">&lt;</span> <span class="n">mc</span><span class="p">;</span> <span class="n">ir</span> <span class="o">+=</span> <span class="n">MR</span><span class="p">)</span> <span class="p">{</span>
            <span class="k">const</span> <span class="kt">int</span> <span class="n">mr</span> <span class="o">=</span> <span class="n">min</span><span class="p">(</span><span class="n">MR</span><span class="p">,</span> <span class="n">mc</span> <span class="o">-</span> <span class="n">ir</span><span class="p">);</span>
            <span class="k">const</span> <span class="kt">int</span> <span class="n">nr</span> <span class="o">=</span> <span class="n">min</span><span class="p">(</span><span class="n">NR</span><span class="p">,</span> <span class="n">nc</span> <span class="o">-</span> <span class="n">jr</span><span class="p">);</span>
            <span class="n">kernel_16x6</span><span class="p">(</span><span class="o">&amp;</span><span class="n">blockA_packed</span><span class="p">[</span><span class="n">ir</span> <span class="o">*</span> <span class="n">kc</span><span class="p">],</span> <span class="o">&amp;</span><span class="n">blockB_packed</span><span class="p">[</span><span class="n">jr</span> <span class="o">*</span> <span class="n">kc</span><span class="p">],</span> <span class="o">&amp;</span><span class="n">C</span><span class="p">[(</span><span class="n">j</span> <span class="o">+</span> <span class="n">jr</span><span class="p">)</span> <span class="o">*</span> <span class="n">M</span> <span class="o">+</span> <span class="p">(</span><span class="n">i</span> <span class="o">+</span> <span class="n">ir</span><span class="p">)],</span> <span class="n">mr</span><span class="p">,</span> <span class="n">nr</span><span class="p">,</span> <span class="n">kc</span><span class="p">,</span> <span class="n">M</span><span class="p">);</span>
          <span class="p">}</span>
        <span class="p">}</span>
      <span class="p">}</span>
    <span class="p">}</span>
  <span class="p">}</span>
<span class="p">}</span>
</code></pre></div></div>

<h2 id="8-kernel-micro-optimizations">8. Kernel Micro-Optimizations</h2>

<p>Instead of using arrays of <code class="language-plaintext highlighter-rouge">__m256</code> to define the accumulator $\bar{C}$ and the masks</p>
<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">__m256</span> <span class="n">C_buffer</span><span class="p">[</span><span class="mi">6</span><span class="p">][</span><span class="mi">2</span><span class="p">];</span>
<span class="n">__m256i</span> <span class="n">masks</span><span class="p">[</span><span class="mi">2</span><span class="p">];</span>
</code></pre></div></div>
<p>we explicitly unroll them:</p>
<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code>    <span class="n">__m256</span> <span class="n">C00</span> <span class="o">=</span> <span class="n">_mm256_setzero_ps</span><span class="p">();</span>
    <span class="n">__m256</span> <span class="n">C10</span> <span class="o">=</span> <span class="n">_mm256_setzero_ps</span><span class="p">();</span>
    <span class="n">__m256</span> <span class="n">C01</span> <span class="o">=</span> <span class="n">_mm256_setzero_ps</span><span class="p">();</span>
    <span class="n">__m256</span> <span class="n">C11</span> <span class="o">=</span> <span class="n">_mm256_setzero_ps</span><span class="p">();</span>
    <span class="n">__m256</span> <span class="n">C02</span> <span class="o">=</span> <span class="n">_mm256_setzero_ps</span><span class="p">();</span>
    <span class="n">__m256</span> <span class="n">C12</span> <span class="o">=</span> <span class="n">_mm256_setzero_ps</span><span class="p">();</span>
    <span class="n">__m256</span> <span class="n">C03</span> <span class="o">=</span> <span class="n">_mm256_setzero_ps</span><span class="p">();</span>
    <span class="n">__m256</span> <span class="n">C13</span> <span class="o">=</span> <span class="n">_mm256_setzero_ps</span><span class="p">();</span>
    <span class="n">__m256</span> <span class="n">C04</span> <span class="o">=</span> <span class="n">_mm256_setzero_ps</span><span class="p">();</span>
    <span class="n">__m256</span> <span class="n">C14</span> <span class="o">=</span> <span class="n">_mm256_setzero_ps</span><span class="p">();</span>
    <span class="n">__m256</span> <span class="n">C05</span> <span class="o">=</span> <span class="n">_mm256_setzero_ps</span><span class="p">();</span>
    <span class="n">__m256</span> <span class="n">C15</span> <span class="o">=</span> <span class="n">_mm256_setzero_ps</span><span class="p">();</span>
    <span class="n">__m256i</span> <span class="n">packed_mask0</span><span class="p">;</span>
    <span class="n">__m256i</span> <span class="n">packed_mask1</span><span class="p">;</span>
</code></pre></div></div>
<p>With named variables instead of arrays, GCC can optimize the code better and avoid register spilling. Additionally, we use vector instructions to calculate the masks as follows:</p>
<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">static</span> <span class="kt">int8_t</span> <span class="n">mask</span><span class="p">[</span><span class="mi">32</span><span class="p">]</span>
    <span class="n">__attribute__</span><span class="p">((</span><span class="n">aligned</span><span class="p">(</span><span class="mi">64</span><span class="p">)))</span> <span class="o">=</span> <span class="p">{</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>
                                    <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">,</span>  <span class="mi">0</span><span class="p">};</span>
<span class="n">packed_mask0</span> <span class="o">=</span> <span class="n">_mm256_cvtepi8_epi32</span><span class="p">(</span><span class="n">_mm_loadu_si64</span><span class="p">(</span><span class="o">&amp;</span><span class="n">mask</span><span class="p">[</span><span class="mi">16</span> <span class="o">-</span> <span class="n">mr</span><span class="p">]));</span>
<span class="n">packed_mask1</span> <span class="o">=</span> <span class="n">_mm256_cvtepi8_epi32</span><span class="p">(</span><span class="n">_mm_loadu_si64</span><span class="p">(</span><span class="o">&amp;</span><span class="n">mask</span><span class="p">[</span><span class="mi">16</span> <span class="o">-</span> <span class="n">mr</span> <span class="o">+</span> <span class="mi">8</span><span class="p">]));</span>
</code></pre></div></div>
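
<p>To see why the offset <code class="language-plaintext highlighter-rouge">16 - mr</code> produces the right mask, here is a minimal scalar sketch of the same sliding-window idea. The helper <code class="language-plaintext highlighter-rouge">build_row_mask</code> and the name <code class="language-plaintext highlighter-rouge">mask_table</code> are hypothetical and only serve as an illustration. The sketch assumes $0 \le m_r \le 16$ and models the two 8-lane vector loads as a single 16-lane loop.</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code>#include &lt;stdint.h&gt;

/* Same table as above: 16 bytes of -1 followed by 16 bytes of 0. */
static const int8_t mask_table[32] = {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                                       0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0};

/* Scalar model of the two vector loads (hypothetical helper, not part of the
   repository). Assumes 0 &lt;= mr &lt;= 16. lane_mask[i] becomes -1 for rows
   i &lt; mr and 0 otherwise. */
void build_row_mask(int mr, int32_t lane_mask[16]) {
  /* Start the 16-byte window mr bytes before the first zero byte. */
  const int8_t* window = &amp;mask_table[16 - mr];
  for (int i = 0; i &lt; 16; i++) {
    /* Widen each byte with sign extension, as _mm256_cvtepi8_epi32 does per lane. */
    lane_mask[i] = (int32_t)window[i];
  }
}
</code></pre></div></div>

<p>For <code class="language-plaintext highlighter-rouge">mr = 5</code>, lanes 0 to 4 are -1 and lanes 5 to 15 are 0, so only the first five rows of the tile are selected.</p>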

<p>The corresponding implementation can be found at <a href="https://github.com/salykova/sgemm.c/blob/main/tutorial/matmul_micro.h">matmul_micro.h</a>.</p>

<h2 id="9-multithreading">9. Multithreading</h2>

<p>Many of the loops can potentially be parallelized. To achieve high performance, we want to parallelize both the packing and the arithmetic operations. Let’s start with the arithmetic operations. The 5th, 4th and 3rd loops around the micro-kernel iterate over the matrix dimensions in chunks of the cache block sizes $n_c$, $k_c$, $m_c$. To parallelize these loops efficiently and keep all threads busy, the number of iterations (matrix dimension / cache block size) should be at least equal to the number of threads (generally, the more the better). In other words, the input matrix dimension should be at least number of threads $\times$ cache block size. As we discussed earlier, we also want the cache blocks to fully occupy the corresponding cache levels. On modern CPUs, this second requirement results in cache block sizes of a thousand or more elements. For example, on my Ryzen 9700X, the cache block sizes $n_c=1535$, $m_c=1024$ attain the best performance in the single-threaded scenario. Given the 8 cores of the Ryzen 9700X, we need input matrices with dimensions of at least $\max(m_c, n_c) \times \text{number of cores} = 1535 \times 8 = 12280$ to be able to distribute the work over all cores.</p>
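
<p>As a rough illustration with a hypothetical input size (an assumption, not a measurement), take a matrix dimension of $N = 4096$ with $n_c = 1535$. The 5th loop then has only</p>

\[\left\lceil \frac{N}{n_c} \right\rceil = \left\lceil \frac{4096}{1535} \right\rceil = 3\]

<p>iterations, so at most 3 of the 8 cores would receive work if only this loop were parallelized.</p>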

<p><img src="/assets/matmul_cpu/blis_design.png" alt="" style="display:block; margin-left:auto; margin-right:auto" /></p>

<p>In contrast, the last two loops (the 2nd and 1st loops around the micro-kernel) iterate over the cache blocks, dividing them into blocks of size $m_R$ and $n_R$. Since $m_R, n_R$ are typically very small (&lt;20), these loops have many iterations even for moderately sized matrices, which makes them ideal candidates for parallelization. Moreover, we can choose $m_c, n_c$ to be multiples of the number of cores so that the work is evenly distributed across all cores.</p>

<p>On my machine, parallelizing the 2nd and 1st loops around the micro-kernel jointly with <code class="language-plaintext highlighter-rouge">collapse(2)</code> results in the best performance:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="cp">#pragma omp parallel for collapse(2) num_threads(NTHREADS)
</span>  <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">jr</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">jr</span> <span class="o">&lt;</span> <span class="n">nc</span><span class="p">;</span> <span class="n">jr</span> <span class="o">+=</span> <span class="n">NR</span><span class="p">)</span>
</code></pre></div></div>
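
<p>What <code class="language-plaintext highlighter-rouge">collapse(2)</code> does can be sketched in isolation: the two tile loops are fused into one iteration space of $\lceil n_c / n_R \rceil \times \lceil m_c / m_R \rceil$ tiles, which OpenMP then splits across the threads. In the sketch below, <code class="language-plaintext highlighter-rouge">process_tile</code> is a hypothetical stand-in for the micro-kernel call, <code class="language-plaintext highlighter-rouge">sweep_tiles</code> is an illustrative name, and the <code class="language-plaintext highlighter-rouge">num_threads</code> clause is omitted for brevity.</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code>#define MR 16
#define NR 6

/* Hypothetical stand-in for the micro-kernel: counts how often a tile is visited. */
static void process_tile(int* tile_done, int tiles_per_col, int jr, int ir) {
  tile_done[(jr / NR) * tiles_per_col + (ir / MR)] += 1;
}

/* Visits every MR x NR tile of an mc x nc block exactly once.
   tile_done must hold ceil(nc / NR) * ceil(mc / MR) zero-initialized ints. */
void sweep_tiles(int* tile_done, int mc, int nc) {
  const int tiles_per_col = (mc + MR - 1) / MR;
  /* Both loops are fused into one iteration space and shared among threads.
     Each iteration writes a different element, so there is no data race. */
#pragma omp parallel for collapse(2)
  for (int jr = 0; jr &lt; nc; jr += NR) {
    for (int ir = 0; ir &lt; mc; ir += MR) {
      process_tile(tile_done, tiles_per_col, jr, ir);
    }
  }
}
</code></pre></div></div>

<p>With <code class="language-plaintext highlighter-rouge">mc = 64</code> and <code class="language-plaintext highlighter-rouge">nc = 96</code>, the fused loop has $16 \times 4 = 64$ iterations to distribute, compared to 16 if only the outer loop were parallelized. When compiled without OpenMP support the pragma is ignored and the loops run serially with the same result.</p>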

<p>More on OpenMP <a href="https://ppc.cs.aalto.fi/ch2/openmp/">here</a>, <a href="https://ppc.cs.aalto.fi/ch3/">here</a> and <a href="https://curc.readthedocs.io/en/latest/programming/OpenMP-C.html">here</a>.</p>

<blockquote>
  <p>For many-core processors (&gt; 16 cores), consider using nested parallelism and parallelizing 2-3 loops to improve performance.</p>
</blockquote>

<p>Together with arithmetic operations, we will also parallelize the packing of both $\tilde{A}$ and $\tilde{B}$:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">void</span> <span class="n">pack_blockA</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">A</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">blockA_packed</span><span class="p">,</span> <span class="k">const</span> <span class="kt">int</span> <span class="n">mc</span><span class="p">,</span> <span class="k">const</span> <span class="kt">int</span> <span class="n">kc</span><span class="p">,</span> <span class="k">const</span> <span class="kt">int</span> <span class="n">M</span><span class="p">)</span>
<span class="cp">#pragma omp parallel for num_threads(NTHREADS)
</span>  <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">mc</span><span class="p">;</span> <span class="n">i</span> <span class="o">+=</span> <span class="n">MR</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">void</span> <span class="n">pack_blockB</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">B</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">blockB_packed</span><span class="p">,</span> <span class="k">const</span> <span class="kt">int</span> <span class="n">nc</span><span class="p">,</span> <span class="k">const</span> <span class="kt">int</span> <span class="n">kc</span><span class="p">,</span> <span class="k">const</span> <span class="kt">int</span> <span class="n">K</span><span class="p">)</span>
<span class="cp">#pragma omp parallel for num_threads(NTHREADS)
</span>  <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">j</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="n">nc</span><span class="p">;</span> <span class="n">j</span> <span class="o">+=</span> <span class="n">NR</span><span class="p">)</span>
</code></pre></div></div>
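
<p>The packing loops are safe to parallelize because each iteration writes to its own, non-overlapping slice of the packed buffer. The following is a minimal sketch for $\tilde{A}$ only, assuming column-major storage of $A$ with leading dimension <code class="language-plaintext highlighter-rouge">M</code> and the panel layout implied by the indexing <code class="language-plaintext highlighter-rouge">blockA_packed[ir * kc]</code> used above. The names <code class="language-plaintext highlighter-rouge">pack_panelA_sketch</code>, <code class="language-plaintext highlighter-rouge">pack_blockA_sketch</code> and <code class="language-plaintext highlighter-rouge">min_int</code> are hypothetical, and the actual implementation may differ in detail.</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code>#define MR 16

static int min_int(int a, int b) { return a &lt; b ? a : b; }

/* Hypothetical helper: copies one mr x kc panel of column-major A into a
   contiguous MR x kc panel and zero-pads the rows mr..MR-1. */
static void pack_panelA_sketch(const float* A, float* panel, int mr, int kc, int M) {
  for (int p = 0; p &lt; kc; p++) {
    for (int i = 0; i &lt; mr; i++) panel[p * MR + i] = A[p * M + i];
    for (int i = mr; i &lt; MR; i++) panel[p * MR + i] = 0.0f;
  }
}

/* Packs an mc x kc block of A panel by panel. The panel starting at row i is
   written to blockA_packed[i * kc .. i * kc + MR * kc - 1], so different
   iterations never touch the same memory. */
void pack_blockA_sketch(const float* A, float* blockA_packed, int mc, int kc, int M) {
#pragma omp parallel for
  for (int i = 0; i &lt; mc; i += MR) {
    const int mr = min_int(MR, mc - i);
    pack_panelA_sketch(&amp;A[i], &amp;blockA_packed[i * kc], mr, kc, M);
  }
}
</code></pre></div></div>

<p>The packed buffer must provide room for $\lceil m_c / m_R \rceil \times m_R \times k_c$ floats, since the last panel is padded to the full $m_R$ rows.</p>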

<p>Similar to the 2nd and 1st loops around the micro-kernel, the packing loops can be parallelized efficiently thanks to their high number of iterations and the flexibility in choosing $m_c, n_c$. For the multi-threaded implementation, the values</p>

\[m_c = m_R \times \text{number of threads} \times 5\]

\[n_c = n_R \times \text{number of threads} \times 50\]

<p>provide the best performance on my machine, leading to the final optimized multi-threaded implementation <a href="https://github.com/salykova/sgemm.c/blob/main/tutorial/matmul_parallel.h">matmul_parallel.h</a></p>]]></content><author><name>Amanzhol Salykov</name></author><summary type="html"><![CDATA[A detailed blog post on optimizing multi-threaded matrix multiplication for x86 processors to achieve OpenBLAS/MKL-like performance. Tags: High-performance GEMM on CPU, Fast GEMM on CPU, High-performance matrix multiplication on CPU, Fast Matrix Multiplication on CPU, Matrix multiplication in C, GEMM in C, Matrix multiplication acceleration.]]></summary></entry></feed>