Closed
Description
Continuing from existing issues about this kind of workload (e.g. #14951). The workload to be optimized is as given in @Max191's gist here: https://gist.github.com/Max191/8ae398f4697edd31bd07a613b46dcbbd
It can be described as:
- A first `linalg.generic` effectively amounts to some transposes and a `linalg.batch_matmul`, with element types: `i16xi4->i32`.
- A second `linalg.generic` takes the output of the first `generic`, performs some arithmetic dequantizing to `f32`, and does add-reduction on what was the batch dimension in the first `generic`.
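As a rough reference model of those two ops, here is a minimal pure-Python sketch. It is not taken from the gist: the function names are hypothetical, the transposes are omitted, the 4-bit operand is treated as unsigned (following the `s16u4` naming in the task list), and the dequantization is assumed to be a single per-batch `f32` scale, which may be simpler than the actual arithmetic.

```python
# Hypothetical reference model of the workload; names and the per-batch
# scale are assumptions for illustration, not the gist's actual code.

def batch_matmul_i16_u4(lhs, rhs):
    """lhs[b][m][k] holds i16 values, rhs[b][k][n] holds 4-bit values.

    Returns acc[b][m][n], the integer (i32-style) accumulation over k.
    """
    B, M, K, N = len(lhs), len(lhs[0]), len(rhs[0]), len(rhs[0][0])
    acc = [[[0] * N for _ in range(M)] for _ in range(B)]
    for b in range(B):
        for m in range(M):
            for n in range(N):
                acc[b][m][n] = sum(lhs[b][m][k] * rhs[b][k][n] for k in range(K))
    return acc


def dequantize_and_reduce(acc, scales):
    """Convert each batch slice to float, scale it, and add-reduce over b."""
    B, M, N = len(acc), len(acc[0]), len(acc[0][0])
    return [
        [sum(float(acc[b][m][n]) * scales[b] for b in range(B)) for n in range(N)]
        for m in range(M)
    ]


lhs = [[[1, 2]], [[3, 4]]]        # B=2, M=1, K=2
rhs = [[[1], [2]], [[3], [4]]]    # B=2, K=2, N=1
acc = batch_matmul_i16_u4(lhs, rhs)
result = dequantize_and_reduce(acc, [0.5, 2.0])
```

The point of the sketch is only the shape of the computation: the batch dimension of the first op becomes the reduction dimension of the second.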
@Max191 has pushed the performance of this so far by improving default non-data-tiled non-ukernel codegen as much as possible, adding VectorContractCustomKernels as needed.
This issue is about doing a V2 of this, switching to data-tiling and ukernels.
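For readers unfamiliar with the term, data-tiling here means packing the matmul operands into a tiled 4D layout so that the innermost computation is a small fixed-size tile that a ukernel can implement. The sketch below models that in pure Python under stated assumptions: the helper names `pack` and `mmt4d` are illustrative, the shapes divide evenly by the tile sizes (no padding), and the tile sizes are arbitrary rather than anything IREE would actually select.

```python
# Illustrative model of data-tiling plus an mmt4d-style contraction.
# Helper names and tile sizes are assumptions, not IREE's implementation.

def pack(matrix, tile_rows, tile_cols):
    """Pack a 2D list into a [R1][C1][R0][C0] tiled layout (no padding)."""
    rows, cols = len(matrix), len(matrix[0])
    assert rows % tile_rows == 0 and cols % tile_cols == 0
    return [
        [
            [
                [matrix[r1 * tile_rows + r0][c1 * tile_cols + c0]
                 for c0 in range(tile_cols)]
                for r0 in range(tile_rows)
            ]
            for c1 in range(cols // tile_cols)
        ]
        for r1 in range(rows // tile_rows)
    ]


def mmt4d(lhs, rhs):
    """lhs is [M1][K1][M0][K0]; rhs is the packed *transposed* RHS, [N1][K1][N0][K0].

    Returns acc laid out as [M1][N1][M0][N0].
    """
    M1, K1, M0, K0 = len(lhs), len(lhs[0]), len(lhs[0][0]), len(lhs[0][0][0])
    N1, N0 = len(rhs), len(rhs[0][0])
    acc = [[[[0] * N0 for _ in range(M0)] for _ in range(N1)] for _ in range(M1)]
    for m1 in range(M1):
        for n1 in range(N1):
            for k1 in range(K1):
                # The three inner loops are the per-tile work that a
                # ukernel tile function would implement.
                for m0 in range(M0):
                    for n0 in range(N0):
                        for k0 in range(K0):
                            acc[m1][n1][m0][n0] += (
                                lhs[m1][k1][m0][k0] * rhs[n1][k1][n0][k0]
                            )
    return acc


a = [[1, 2], [3, 4]]
b_transposed = [[5, 7], [6, 8]]   # transpose of [[5, 6], [7, 8]]
tiled = mmt4d(pack(a, 1, 2), pack(b_transposed, 1, 2))
```

With `M0 = 1`, the inner tile degenerates to a vector-times-matrix product, which is the case the vecmat-specific tiles in the task list are meant to handle well.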
The work could break down into these pieces:
graph TD
a1["Update EncodingAttr to have a type tuple
instead of encoding that in enums (#15182)"]
a3_0["@NatashaKnk's ExpandVectors pattern (#15273) expanding matvec ops to matmul"]
a3_1["Add missing `batch_vecmat` linalg named op"]
a3_2["Update ExpandVectors to support `batch_vecmat`"]
a3_3["Lift that first `linalg.generic` to `arith.ext*`,
`linalg.batch_matmul` and transposes (#15339)"]
a4_0["SetEncoding for `arith.ext + batch_matmul`"]
a4_0_1["MaterializeEncoding for `arith.ext + batch_matmul`"]
a4_1["Verify that `batch_mmt4d` lowers to a loop of `mmt4d`"]
a4_2["LowerToUkernel for `arith.ext + mmt4d`"]
a5_0["ukernels: stop abusing signless as signed"]
a5_1["ukernels: sub-8bit support, and
generic mmt4d support for s16u4 and s16s16"]
a5_2["ukernels: optimized s16u4 and s16s16
tile functions for x86 and arm"]
a6_0["MaterializeEncoding refactor to
enable custom vecmat tiles
instead of just truncating
a generic matmul tile"]
a6["ukernels: optimized `s16i4`
tile function for vecmat with
custom vecmat-tuned tile"]
a7["Add data-tiling tile-selection logic for `i16xi4`"]
a8["Ensure that fusions work as intended"]
a8_1["Ensure that const-eval works as intended"]
a9["Benchmark and study e2e"]
a1-->a7-->a9
a3_0-->a3_2-->a3_3-->a4_0-->a4_0_1-->a4_1-->a4_2-->a9
a4_0_1-->a8-->a9
a4_0_1-->a8_1-->a9
a3_1-->a3_2
a5_0-->a5_1-->a5_2-->a6-->a7
a6_0-->a6
subgraph Legend
notstartedLegend["Not Started"]
inprogressLegend["In Progress"]
doneLegend["Done"]
end
classDef notstarted fill:#ddd,text:#000
classDef inprogress fill:#ff6,text:#000
classDef done fill:#9f9,text:#000
class notstartedLegend notstarted
class inprogressLegend,a5_2,a8,a8_1,a9 inprogress
class doneLegend,a1,a3_0,a3_1,a3_2,a3_3,a5_0,a5_1,a4_0,a6_0,a4_2,a4_0_1,a4_1,a6,a7 done
style Legend fill:#fff, stroke:#000