
Data-tiling and ukernels for i16xi4->i32 group-quantized matmuls #15158

@bjacob


This continues from existing issues about this kind of workload (e.g. #14951). The workload to be optimized is the one in @Max191's gist: https://gist.github.com/Max191/8ae398f4697edd31bd07a613b46dcbbd

It can be described as:

  • A first `linalg.generic` that effectively amounts to some transposes and a `linalg.batch_matmul`, with element types i16xi4->i32.
  • A second `linalg.generic` that takes the output of the first one, dequantizes it to f32 with some arithmetic, and add-reduces over what was the batch dimension in the first generic.
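For reference, the two generics described above can be sketched in plain Python. This is a minimal sketch, not the IR from the gist: it assumes the transposes have already been applied, that the batch dimension indexes quantization groups along the reduction dimension, and that dequantization is a single f32 scale per (group, output column). The helper names and the `scales` layout are hypothetical.

```python
# Hypothetical reference semantics for the workload (names are illustrative).
# Assumptions: transposes already applied; batch dim b indexes quantization
# groups along K; one f32 scale per (group, output column).


def batch_matmul_i32(lhs, rhs):
    """First generic: acc[b][m][n] = sum_k lhs[b][m][k] * rhs[b][k][n].

    lhs holds i16 values, rhs holds i4 values. A real kernel accumulates in
    i32; Python ints do not overflow, so i32 wraparound is not modeled here.
    """
    B, M, K, N = len(lhs), len(lhs[0]), len(rhs[0]), len(rhs[0][0])
    return [[[sum(lhs[b][m][k] * rhs[b][k][n] for k in range(K))
              for n in range(N)]
             for m in range(M)]
            for b in range(B)]


def dequantize_and_reduce(acc, scales):
    """Second generic (assumed form): out[m][n] = sum_b scales[b][n] * float(acc[b][m][n])."""
    B, M, N = len(acc), len(acc[0]), len(acc[0][0])
    out = [[0.0] * N for _ in range(M)]
    for b in range(B):  # add-reduction over the former batch dimension
        for m in range(M):
            for n in range(N):
                out[m][n] += scales[b][n] * float(acc[b][m][n])
    return out


def group_quantized_matmul(lhs, rhs, scales):
    """batch_matmul (i16 x i4 -> i32), then dequantize to f32 and reduce over groups."""
    return dequantize_and_reduce(batch_matmul_i32(lhs, rhs), scales)


lhs = [[[1, 2]], [[3, 4]]]        # B=2 groups, M=1, K=2 per group
rhs = [[[1], [-1]], [[2], [7]]]   # B=2, K=2, N=1
print(group_quantized_matmul(lhs, rhs, [[0.5], [0.25]]))  # [[8.0]]
```

The point of the split is that the integer accumulation stays exact per group, and only the much smaller per-group results get converted to f32.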

So far, @Max191 has pushed the performance of this workload by improving the default (non-data-tiled, non-ukernel) codegen as much as possible, adding VectorContractCustomKernels as needed.

This issue is about a V2 of that work, switching to data-tiling and ukernels.
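As background on what data-tiling buys, here is a minimal sketch of the semantics of `linalg.mmt4d`: both operands are packed into tiles (the RHS in transposed form, the "t" in mmt4d), and the matmul then iterates over tiles. The tile sizes `M0`, `N0`, `K0` and the helper names are illustrative assumptions; real tile sizes are target-dependent, and this sketch assumes every dimension is divisible by its tile size (no padding).

```python
# Sketch of data-tiling: pack to 4-D tiles, run mmt4d, unpack.
# Hypothetical helpers; assumes M % M0 == K % K0 == N % N0 == 0.


def pack_lhs(a, M0, K0):
    """a[M][K] -> a4[M1][K1][M0][K0]."""
    M, K = len(a), len(a[0])
    return [[[[a[m1 * M0 + m0][k1 * K0 + k0] for k0 in range(K0)]
              for m0 in range(M0)]
             for k1 in range(K // K0)]
            for m1 in range(M // M0)]


def pack_rhs_transposed(b, K0, N0):
    """b[K][N] -> b4[N1][K1][N0][K0]: each RHS tile is stored transposed."""
    K, N = len(b), len(b[0])
    return [[[[b[k1 * K0 + k0][n1 * N0 + n0] for k0 in range(K0)]
              for n0 in range(N0)]
             for k1 in range(K // K0)]
            for n1 in range(N // N0)]


def mmt4d(a4, b4):
    """c4[m1][n1][m0][n0] = sum over k1, k0 of a4[m1][k1][m0][k0] * b4[n1][k1][n0][k0]."""
    M1, K1, M0, K0 = len(a4), len(a4[0]), len(a4[0][0]), len(a4[0][0][0])
    N1, N0 = len(b4), len(b4[0][0])
    c4 = [[[[0] * N0 for _ in range(M0)] for _ in range(N1)] for _ in range(M1)]
    for m1 in range(M1):
        for n1 in range(N1):
            tile = c4[m1][n1]
            # The per-tile loop nest below (over one M0xN0 accumulator tile)
            # is roughly what an optimized ukernel tile function replaces.
            for k1 in range(K1):
                lhs_tile, rhs_tile = a4[m1][k1], b4[n1][k1]
                for m0 in range(M0):
                    for n0 in range(N0):
                        for k0 in range(K0):
                            tile[m0][n0] += lhs_tile[m0][k0] * rhs_tile[n0][k0]
    return c4


def unpack(c4):
    """c4[M1][N1][M0][N0] -> c[M][N]."""
    M1, N1, M0, N0 = len(c4), len(c4[0]), len(c4[0][0]), len(c4[0][0][0])
    return [[c4[m // M0][n // N0][m % M0][n % N0] for n in range(N1 * N0)]
            for m in range(M1 * M0)]


a = [[1, 2, 3, 4], [5, 6, 7, 8]]            # M=2, K=4
b = [[1, 0], [0, 1], [1, 1], [2, -1]]       # K=4, N=2
c = unpack(mmt4d(pack_lhs(a, M0=2, K0=2), pack_rhs_transposed(b, K0=2, N0=1)))
print(c)  # [[12, 1], [28, 5]]
```

Packing the RHS transposed makes both tiles contiguous along K0, which is what lets a tile function stream both operands with unit stride.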

The work could be broken down into these pieces:

```mermaid
graph TD
a1["Update EncodingAttr to have a type tuple
instead of encoding that in enums (#15182)"]

a3_0["@NatashaKnk's ExpandVectors pattern (#15273) expanding matvec ops to matmul"]

a3_1["Add missing `batch_vecmat` linalg named op"]

a3_2["Update ExpandVectors to support `batch_vecmat`"]

a3_3["Lift that first `linalg.generic` to `arith.ext*`,
`linalg.batch_matmul` and transposes (#15339)"]

a4_0["SetEncoding for `arith.ext + batch_matmul`"]

a4_0_1["MaterializeEncoding for `arith.ext + batch_matmul`"]

a4_1["Verify that `batch_mmt4d` lowers to a loop of `mmt4d`"]

a4_2["LowerToUkernel for `arith.ext + mmt4d`"]

a5_0["ukernels: stop abusing signless as signed"]

a5_1["ukernels: sub-8bit support, and
generic mmt4d support for s16u4 and s16s16"]

a5_2["ukernels: optimized s16u4 and s16s16
tile functions for x86 and arm"]

a6_0["MaterializeEncoding refactor to
enable custom vecmat tiles
instead of just truncating
a generic matmul tile"]

a6["ukernels: optimized `s16i4`
tile function for vecmat with
custom vecmat-tuned tile"]

a7["Add data-tiling tile-selection logic for `i16xi4`"]

a8["Ensure that fusions work as intended"]

a8_1["Ensure that const-eval works as intended"]

a9["Benchmark and study e2e"]

a1-->a7-->a9
a3_0-->a3_2-->a3_3-->a4_0-->a4_0_1-->a4_1-->a4_2-->a9
a4_0_1-->a8-->a9
a4_0_1-->a8_1-->a9
a3_1-->a3_2
a5_0-->a5_1-->a5_2-->a6-->a7
a6_0-->a6

subgraph Legend
  notstartedLegend["Not Started"]
  inprogressLegend["In Progress"]
  doneLegend["Done"]
end

classDef notstarted fill:#ddd,color:#000
classDef inprogress fill:#ff6,color:#000
classDef done fill:#9f9,color:#000

class notstartedLegend notstarted
class inprogressLegend,a5_2,a8,a8_1,a9 inprogress
class doneLegend,a1,a3_0,a3_1,a3_2,a3_3,a5_0,a5_1,a4_0,a6_0,a4_2,a4_0_1,a4_1,a6,a7 done

style Legend fill:#fff, stroke:#000
```
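Two of the ukernel items in the graph hinge on the fact that MLIR integer types such as `i4` and `i16` are signless: whether a value is signed or unsigned is decided by the op that consumes it (for example `arith.extsi` versus `arith.extui`), which is why the ukernels need explicit variants such as `s16u4` and `s16s16` rather than treating signless as signed. The sketch below illustrates what sub-8-bit support has to handle: two 4-bit values stored per byte, decoded by either sign-extension or zero-extension. Storing the first element in the low nibble is an assumption for illustration, not necessarily the layout IREE uses, and all function names are hypothetical.

```python
# Hypothetical illustration of sub-byte (i4) decoding and signedness.


def extui4(nibble):
    """Zero-extend a 4-bit value (like arith.extui on i4): result in 0..15."""
    return nibble & 0xF


def extsi4(nibble):
    """Sign-extend a 4-bit value (like arith.extsi on i4): result in -8..7."""
    nibble &= 0xF
    return nibble - 16 if nibble & 0x8 else nibble


def unpack_i4(data, signed):
    """Decode a bytes object holding two 4-bit values per byte.

    Assumption (illustrative only): element 2*i is in the low nibble of
    data[i] and element 2*i+1 in the high nibble.
    """
    ext = extsi4 if signed else extui4
    out = []
    for byte in data:
        out.append(ext(byte & 0xF))
        out.append(ext(byte >> 4))
    return out


def dot_i16_i4(lhs_i16, packed_rhs, rhs_signed):
    """Dot product of i16 values with packed i4 values (i32 accumulation in a real kernel)."""
    rhs = unpack_i4(packed_rhs, rhs_signed)
    if len(lhs_i16) != len(rhs):
        raise ValueError("length mismatch between i16 and unpacked i4 operands")
    return sum(x * y for x, y in zip(lhs_i16, rhs))


print(unpack_i4(bytes([0xF1]), signed=False))  # [1, 15]
print(unpack_i4(bytes([0xF1]), signed=True))   # [1, -1]
```

The same stored nibble `0xF` contributes 15 in the `s16u4` case and -1 in the signed case, so a kernel that silently assumes one interpretation computes a different result.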
