ggml: Add initial MXFP6 CPU implementation #22671

Open
michaelw9999 wants to merge 3 commits into ggml-org:master from michaelw9999:mxfp6-cpu

@michaelw9999

@michaelw9999 michaelw9999 commented May 4, 2026

Contributor

This PR brings an initial implementation of basic MXFP6-E2M3 support, for CPU only.
MXFP6-E3M2 is another FP6 variant; support for that can be added later.
Native CUDA/Blackwell support is working and is intended to follow after this PR. It is possible to create native AMD ROCm versions in the future, as well as versions for other backends.
The first model is on HF here: Qwen3.5-4B-MXFP6-GGUF. More are ready and will be uploaded soon.
Related discussion here: #22498
This PR does not include any quantizer besides that used for reference and testing.

Why add MXFP6 into llama.cpp now?

Details
    • Presently nobody is using MXFP6, but that is because there is no real support for it, so nobody can. Essentially no models exist in MXFP6 on HF, nobody can make them easily, and nobody can run them fast. vLLM supports MXFP6 in a limited way (not Blackwell fast), and AMD provides quantization support with Quark and hardware-accelerated MXFP6 on some of its latest datacenter GPUs. More are likely coming in the future.
      However, there is essentially no fast or optimized implementation anywhere, so it has not been feasible to use the format (until now). NVIDIA Blackwell GPUs have native hardware block-scaling acceleration, but as far as I know nobody has ever released kernels with native hardware support.
    • ModelOpt support for MXFP6 is still limited and described as "simulated". I have created a heavily modified llama-quantizer that optimizes and auto-tunes for NVFP4, and have adapted it for MXFP6 as well. There is no need to convert from HF or wait for models that other platforms may never make; we can make as many as we want just for llama.cpp. Using a scale search and tuning optimizations, the accuracy/quality of MXFP6 is excellent and much better than NVFP4 while being almost as fast, and this is before further optimization, improvements, or any llama.cpp community help. To improve MXFP6 quality, the quantizer can create optional per-tensor scales just like NVFP4. The existing llama.cpp machinery already incorporates them into the graph without a single new line of code. There is no speed difference, and my testing shows it improves the quality of MXFP6 even further.
    • AMD's studies showed the best results from combining FP4 and FP6. With full NVFP4 support now in llama.cpp, I've compared MXFP4+MXFP6 vs NVFP4+MXFP6 by choosing which layer gets which type (with some tensors in BF16, F32, Q8, etc). The quantizer chooses automatically based on the error and the imatrix and determines which type is better suited, but the choice can also be forced. Well-quantized NVFP4 combined with MXFP6 in select layers brings significant opportunities for designing a model's layout: moving just a small portion of a mostly NVFP4 model to MXFP6 causes almost no reduction in speed or increase in size, but can significantly improve quality. Conversely, on a mostly MXFP6 version, moving just a few layers into NVFP4 can bring increased speed with only a marginal loss of quality. Further testing will be needed to find exactly the right balance. This flexibility opens up an opportunity for selective tuning based on GPU VRAM targets. For example, for Qwen3.6-35B-A3B-GGUF, Q6_K is 28.5GB and has only slightly better quality, while an MXFP6 model is approximately 27GB and is faster on prefill, even with very early, not yet well-tuned kernels.

Qwen3.5-4B PPL/KLD/Speed Report:

Details

CPU final PPL vs BF16 had a difference of -0.012933
(CPU ): Final estimate: PPL = 9.9361 +/- 0.07125
(CUDA):Mean PPL(Q): 9.904247 ± 0.070767
BF16: Mean PPL(Q): 9.949033 ± 0.071570

The negative ppl score is odd but likely because of tuning and running ppl on the wiki2 dataset.

CUDA Final Ppl Delta vs BF16: 0.044786 (Note: MXFP6xMXFP6 loss during MMQ vs MXFP6xQ8)
~~CUDA Final Ppl Delta vs CPU: 0.031853~~
Full KL divergence was only run on CUDA for speed reasons. CUDA is outside the scope of this PR but is the primary intended target, so the results are shown for reference on what to expect. Q8 and Q6_K are still better in quality, but MXFP6 wins on prefill speed (especially combined with NVFP4), and there is still much more room to improve both.

| Config | pp512 (t/s) | tg128 (t/s) | PPL | Mean PPL ratio | Mean KLD | Max KLD | Top p | RMS Δp |
| ------ | ----------: | ----------: | --: | -------------: | -------: | ------: | ----: | -----: |
| Q8 | 16802 | 200 | 9.956220 | 1.001236 | 0.001959 | 11.359487 | 97.554% | 1.299% |
| MXFP6 | 17822 | 180 | 9.904247 | 0.996009 | 0.021651 | 17.561033 | 92.635% | 4.026% |
| MXFP6 (188 MXFP6 layers, 13 NVFP4) | 21166 | 242 | 10.161346 | 1.021864 | 0.045214 | 21.575535 | 88.191% | 5.787% |
| NVFP4 (same imatrix/tuner as above, plus 4over6) | 21312 | 245 | 10.287908 | 1.034592 | 0.082396 | 21.185059 | 86.785% | 7.725% |
| Q6_K | 14341 | 227 | 9.992565 | 1.004891 | 0.005779 | 10.447377 | 96.218% | 2.194% |
| NVFP4 (converted HF/ModelOpt) | 20239 | 224 | 10.838460 | 1.089958 | 0.104422 | 21.996130 | 85.184% | 8.815% |
| Q4_K | 16137 | 283 | 10.395488 | 1.045411 | 0.046251 | 19.578995 | 90.415% | 5.772% |

As the data above shows, there is a lot of promise in how we can use MXFP6. There is nearly unlimited, still unexplored flexibility in how to optimize MXFP6 alone or with NVFP4 or Q8 to balance speed against quality. Let's stay focused for now, but I have already combined NVFP4/MXFP6 with MXFP8, and using FP8 activations with the .f4f6f8 mma can further improve speed.

Newest Full Qwen3.5-4B kld log:

Details
====== Perplexity statistics ======
Mean PPL(Q)                   :   9.998305 ±   0.071790
Mean PPL(base)                :   9.943929 ±   0.071459
Cor(ln(PPL(Q)), ln(PPL(base))):  99.73%
Mean ln(PPL(Q)/PPL(base))     :   0.005453 ±   0.000528
Mean PPL(Q)/PPL(base)         :   1.005468 ±   0.000531
Mean PPL(Q)-PPL(base)         :   0.054377 ±   0.005275

====== KL divergence statistics ======
Mean    KLD:   0.015374 ±   0.000300
Maximum KLD:  19.809378
99.9%   KLD:   0.629969
99.0%   KLD:   0.094831
95.0%   KLD:   0.037901
90.0%   KLD:   0.026535
Median  KLD:   0.008529
10.0%   KLD:   0.000474
 5.0%   KLD:   0.000135
 1.0%   KLD:   0.000017
 0.1%   KLD:   0.000001
Minimum KLD:  -0.000077

====== Token probability statistics ======
Mean    Δp: -0.174 ± 0.009 %
Maximum Δp: 94.303%
99.9%   Δp: 20.150%
99.0%   Δp:  8.865%
95.0%   Δp:  4.176%
90.0%   Δp:  2.423%
75.0%   Δp:  0.462%
Median  Δp: -0.003%
25.0%   Δp: -0.753%
10.0%   Δp: -3.020%
 5.0%   Δp: -4.931%
 1.0%   Δp: -9.745%
 0.1%   Δp: -24.106%
Minimum Δp: -99.902%
RMS Δp    :  3.437 ± 0.044 %
Same top p: 93.487 ± 0.064 %
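
As a quick consistency check on the perplexity block above (assuming the reported ratio is the exponential of the mean log-ratio, which the numbers appear to agree with), the figures relate as follows; the last digit of the difference differs from the log only by rounding of the printed means:

```latex
\frac{\mathrm{PPL}(Q)}{\mathrm{PPL}(\mathrm{base})} \approx \exp(0.005453) \approx 1.005468,
\qquad
\mathrm{PPL}(Q) - \mathrm{PPL}(\mathrm{base}) \approx 9.998305 - 9.943929 = 0.054376
```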

Block design with Repack:

Details

The on disk GGUF layout is as follows:

// Both MXFP6_E2M3 and TBD MXFP6_E3M2 share the same block
#define QK_MXFP6 32
#define QK_MXFP6_PACKED_BYTES 24
typedef struct {
    uint8_t e;                                      // UE8M0 scale for one K32 block
    uint8_t qs[QK_MXFP6_PACKED_BYTES];              // packed 32x6-bit E2M3 codes
} block_mxfp6;
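
To make the block concrete, here is a minimal C sketch of dequantizing one such block. It is not the PR's code. It assumes E2M3 as defined in the OCP MX format (1 sign bit, 2 exponent bits with bias 1, 3 mantissa bits, no Inf/NaN, maximum 7.5) and a UE8M0 scale equal to 2^(e-127). The bit order inside `qs` (code `i` starting at bit `6*i`, little-endian) is purely an assumption, and the helper names `ue8m0_to_float`, `e2m3_to_float`, `mxfp6_get_code` and `dequantize_block_mxfp6` are hypothetical.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

#define QK_MXFP6 32
#define QK_MXFP6_PACKED_BYTES 24

typedef struct {
    uint8_t e;                          // UE8M0 scale for one K32 block
    uint8_t qs[QK_MXFP6_PACKED_BYTES];  // packed 32x6-bit E2M3 codes
} block_mxfp6;

// UE8M0: an 8-bit biased exponent, value 2^(e - 127); 0xFF encodes NaN.
static float ue8m0_to_float(uint8_t e) {
    uint32_t bits;
    if (e == 0xFF)   bits = 0x7FC00000u;        // NaN
    else if (e == 0) bits = 0x00400000u;        // 2^-127 is subnormal in fp32
    else             bits = (uint32_t) e << 23; // same bias as fp32
    float f;
    memcpy(&f, &bits, sizeof f);
    return f;
}

// E2M3: sign(1) | exponent(2, bias 1) | mantissa(3).
static float e2m3_to_float(uint8_t code) {
    const int sign = (code >> 5) & 1;
    const int exp  = (code >> 3) & 3;
    const int man  =  code       & 7;
    const float v = (exp == 0)
        ? man / 8.0f                                      // subnormal: 2^0 * m/8
        : (1.0f + man / 8.0f) * (float) (1 << exp) * 0.5f; // 2^(exp-1) * (1 + m/8)
    return sign ? -v : v;
}

// ASSUMED packing: code i occupies bits [6*i, 6*i + 5] of the byte stream.
static uint8_t mxfp6_get_code(const uint8_t * qs, int i) {
    assert(i >= 0 && i < QK_MXFP6);
    const int bit = 6 * i, byte = bit >> 3, shift = bit & 7;
    uint32_t w = qs[byte];
    if (byte + 1 < QK_MXFP6_PACKED_BYTES) {
        w |= (uint32_t) qs[byte + 1] << 8;
    }
    return (uint8_t) ((w >> shift) & 0x3F);
}

static void dequantize_block_mxfp6(const block_mxfp6 * b, float out[QK_MXFP6]) {
    const float d = ue8m0_to_float(b->e);
    for (int i = 0; i < QK_MXFP6; i++) {
        out[i] = d * e2m3_to_float(mxfp6_get_code(b->qs, i));
    }
}
```

Under these assumptions one block costs 25 bytes for 32 weights, i.e. 6.25 bits per weight.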

You might think this seems inefficient, especially given the 8-bit containers that Blackwell needs, but there are a few reasons it was kept this way. Explaining them requires describing the CUDA block design (outside the scope of this PR, but relevant since that is the intended use case).
~~The CUDA layout is dynamic and is decided at loadtime with a fast repack. It is using the same proposed NVFP4/Blackwell "repack-mma into Blackwell layout tiles" as shown elsewhere.
So for the gguf layout, block_mxfp6 is kept row major, and as small as possible. It can be repacked the same way on load to other backends and hardware designs in the future without needing to ever change the disk layout again; the CUDA layout can also be improved upon and changed over time as new hardware or superior optimizations come up with something better. ~~
~~For MXFP6, a 3 lane 416 byte tile was tested and determined to be the most optimal for now:

  1. It saves packing 0s into VRAM or on disk.
  2. It remains almost as fast as a 16B AoSoA 8-bit container would be.
  3. The repacked tile form is much faster than staging block_mxfp6 into tiles during inference
    E.g., for Qwen3.6-35B-A3B-MXFP6:
    A perfect AoSoA-aligned 544-byte tile needs a 34609 MiB CUDA buffer, which does not fit on a 32606 MiB card and has to offload. With the 416-byte tile layout, this is 26,467 MiB on CUDA and 379 MiB on CPU. Testing shows staging to a tile would take 2.255 ns/mma vs 1.485 ns/mma, so the repacked tile takes about 34% less time per MMA. It would add complexity, but a potential future option would be to determine the available VRAM at load time and then choose which tile layout to use depending on the model size.
    So the tile version is:~~
struct  __align__(16) block_mxfp6_e2m3_blackwell_frag {
    uint32_t regs[32][3];
    uint8_t  scales[32];
};

struct  __align__(16) block_mxfp6_e2m3_blackwell {
    block_mxfp6_e2m3_blackwell_frag tiles[QK_MXFP6_E2M3_FRAGS];
};

struct  __align__(16) block_mxfp6_e2m3_blackwell_tensor {
    float         weight_scale;
    float         input_scale;  // Needs comparison check for benefit but otherwise padded
    const float * weight_scales; // For MOE per expert
    const float * input_scales;
    block_mxfp6_e2m3_blackwell tiles[];
};

How will this work when the Blackwell MMA PTX needs an 8-bit container?
To avoid "2 bits taking up VRAM" with 0s, the 0s are inserted right at the MMA:

tx[0] = unpack4(w0);
tx[1] = unpack4((w0 >> 24) | (w1 << 8));
tx[2] = unpack4((w1 >> 16) | (w2 << 16));
tx[3] = unpack4(w2 >> 8);

So the required padding is added in registers just in time, while the data otherwise stays in VRAM in the compact 3-lane form. The optional tensor scales are applied as derived tensors directly in the vecdot, protecting precision and quality.
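
The `unpack4` helper is not shown in the PR text. The following host-side C sketch shows one way it could work, assuming it widens the four 6-bit codes held in the low 24 bits of its argument into four 8-bit containers with the two top bits zero. Where the code sits inside each byte is an assumption here (the hardware may require a different alignment), and `expand_3lane` and `pack16` are hypothetical names used only to exercise the shifts quoted above.

```c
#include <assert.h>
#include <stdint.h>

// Widen four 6-bit codes (low 24 bits of v) into four bytes,
// each holding one code in bits 0..5 with bits 6..7 zero (ASSUMED placement).
static uint32_t unpack4(uint32_t v) {
    return  ( v        & 0x3Fu)
         | (((v >>  6) & 0x3Fu) <<  8)
         | (((v >> 12) & 0x3Fu) << 16)
         | (((v >> 18) & 0x3Fu) << 24);
}

// Expand 16 codes stored compactly in three 32-bit words (96 bits)
// into four padded registers, using the same shifts as in the text.
static void expand_3lane(const uint32_t w[3], uint32_t tx[4]) {
    tx[0] = unpack4(w[0]);
    tx[1] = unpack4((w[0] >> 24) | (w[1] <<  8));
    tx[2] = unpack4((w[1] >> 16) | (w[2] << 16));
    tx[3] = unpack4(w[2] >> 8);
}

// Hypothetical inverse used for testing: pack 16 six-bit codes,
// little-endian, code i starting at bit 6*i of the 96-bit stream.
static void pack16(const uint8_t codes[16], uint32_t w[3]) {
    w[0] = w[1] = w[2] = 0;
    for (int i = 0; i < 16; i++) {
        assert(codes[i] < 64);
        const int bit  = 6 * i;
        const int word = bit >> 5;
        const uint64_t v = (uint64_t) codes[i] << (bit & 31);
        w[word] |= (uint32_t) v;
        if (word < 2) {
            w[word + 1] |= (uint32_t) (v >> 32); // bits that spill into the next word
        }
    }
}
```

A round trip through `pack16` and `expand_3lane` should return every code in its own byte, which is the property the just-in-time padding relies on.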

Detailed experimental results on layout choices and speed:

mxfp6_native_fp6_check packed=0x1f01283f roundtrip=7.500 status=ok
mxfp6_micro_result cc=120 iters=1 max_abs=0 status=ok

//current layout (with full llama.cpp production)
mxfp6_layout_bench variant=3lane416                      tile_bytes=416 ns_per_mma=1.996

//check bounds on every tile
mxfp6_layout_bench variant=3lane416_check_full           tile_bytes=416 ns_per_mma=2.005

//only check partial tails
mxfp6_layout_bench variant=3lane416_check_tail8          tile_bytes=416 ns_per_mma=1.769

//Padded stock nv_fp6 layouts in 8bit container
// this is faster but takes up more vram (theoretical)
mxfp6_layout_bench variant=aosoa640                          tile_bytes=640 ns_per_mma=1.844

//8 groups of 48 bytes as slab
mxfp6_layout_bench variant=slab48_416                        tile_bytes=416 ns_per_mma=2.036

// current layout (theoretical in microkernel only)
mxfp6_layout_bench variant=3lane416                           tile_bytes=416 ns_per_mma=1.485

// fully compact layout, MMA has to unpack/shuffle (theoretical min)
mxfp6_layout_bench variant=compact400                        tile_bytes=400 ns_per_mma=1.917

// Staging using smem
mxfp6_layout_bench variant=compact400_shared_lane416      tile_bytes=400 ns_per_mma=2.255

//MXFP6 for activations
mixed_layout_bench variant=mx6xmx6_3lane416              tile_bytes=416 ns_per_mma=1.919
mixed_layout_bench variant=mx6xmx6_prod_lane416_check_full   tile_bytes=416 ns_per_mma=2.529
mixed_layout_bench variant=mx6xmx6_prod_lane416_check_tail8  tile_bytes=416 ns_per_mma=1.752
mixed_layout_bench variant=mx6xmx6_slab48_416                tile_bytes=416 ns_per_mma=1.959
mixed_layout_bench variant=mx6xmx6_3lane416                   tile_bytes=416 ns_per_mma=2.127

//Reuse across 4 MMA ops, still trying to get this one working
mixed_layout_bench variant=mx6xmx6_lane416_reg_reuse4        tile_bytes=416 ns_per_mma=0.443

//Q8 Activations
mixed_layout_bench variant=mx6xq8_bw_vdr                     tile_bytes=452 ns_per_mma=2.893
mixed_layout_bench variant=mx6xq8_to_e4m3_mma_raw            tile_bytes=704 ns_per_mma=1.952
// Folding d into q8 conversion 
mixed_layout_bench variant=mx6xq8_to_e4m3_mma_fold_d         tile_bytes=704 ns_per_mma=1.696
// Using E8M0 scale from Q8
mixed_layout_bench variant=mx6xq8_to_e4m3_mma_e8scale_resid  tile_bytes=704 ns_per_mma=2.161
// Truncate the q8 scale
mixed_layout_bench variant=mx6xq8_to_e4m3_mma_trunc          tile_bytes=704 ns_per_mma=1.979
//Approximate with bitcast
mixed_layout_bench variant=mx6xq8_as_e4m3_mma_bitcast        tile_bytes=704 ns_per_mma=2.241

// tbd fp8 work
mixed_layout_bench variant=mx6xq8_to_e5m2_mma_raw            tile_bytes=704 ns_per_mma=2.138
mixed_layout_bench variant=mx6xq8_to_e5m2_mma_fold_d         tile_bytes=704 ns_per_mma=1.864
mixed_layout_bench variant=mx6xq8_to_e5m2_mma_e8scale      tile_bytes=704 ns_per_mma=2.002
mixed_layout_bench variant=mx6xq8_to_e5m2_mma_trunc          tile_bytes=704 ns_per_mma=2.510
mixed_layout_bench variant=mx6xq8_as_e5m2_mma_bitcast        tile_bytes=704 ns_per_mma=2.083

mixed_layout_bench variant=mx6xmx6_compact_load              tile_bytes=400 ns_per_mma=2.545
mixed_layout_bench variant=mx6xmx6_compact400_reg_reuse4     tile_bytes=400 ns_per_mma=0.394

mixed_layout_bench variant=nv4xnv4_over_k64                tile_bytes=2560 ns_per_mma=1.717
mixed_layout_bench variant=mx4xmx6_mxf8f6f4_k32              tile_bytes=544 ns_per_mma=1.834
mixed_layout_bench variant=mx6xmx4_mxf8f6f4_k32              tile_bytes=416 ns_per_mma=1.591

mixed_layout_bench variant=mixedk_100pct_mx6                  tile_bytes=416 ns_per_mma=0.424
mixed_layout_bench variant=mixedk_100pct_nv4                  tile_bytes=2560 ns_per_mma=1.064
mixed_layout_bench variant=mixedk_50nv4_50mx6                 tile_bytes=2976 ns_per_mma=0.800
mixed_layout_bench variant=mixedk_50nv4_25mx4mx6_25mx6        tile_bytes=3520 ns_per_mma=0.701
mixed_layout_bench variant=mixedk_50nv4_50_nv4mx6_pow2   tile_bytes=2560 ns_per_mma=0.842

Faster Tile Layout Now in GGUF, No Repack Needed

Details

Refactored version now maintains the tile layout directly on disk. Repack is not necessary; the layout is already interleaved and quantized directly into tiles ready to go for MMA. The CPU version reads from the tile layout. A small adapter keeps this working with GGML functions expecting blocks/rows. Extensive testing thus far found this to be the fastest layout on CUDA, and no traditional `load_tiles`, `load_ldmatrix`, `cpasync` etc. is needed; the tile goes directly into registers.

#define QK_MXFP6 64
#define QK_MXFP6_SUB 32
#define QK_MXFP6_PACKED_BYTES 24
#define MXFP6_TILE_ROWS 16
#define MXFP6_TILE_FRAGS 2
#define MXFP6_TILE_LANES 32
#define MXFP6_TILE_PAYLOADS 3
#define MXFP6_TILE_BYTES 832
#define MXFP6_ROW_BYTES (MXFP6_TILE_BYTES / MXFP6_TILE_ROWS)

typedef struct GGML_ALIGN(16) {
    uint32_t lane[MXFP6_TILE_LANES][MXFP6_TILE_PAYLOADS];
    uint8_t  scale[MXFP6_TILE_LANES];
} tile_mxfp6_frag;

typedef struct GGML_ALIGN(16) {
    tile_mxfp6_frag frag[MXFP6_TILE_FRAGS];
} tile_mxfp6;

struct ggml_tensor;

typedef struct {
    const struct ggml_tensor * tensor;
    const void * tile;
    int64_t row;
    int64_t channel;
} ggml_tile_to_row_ref;

typedef struct GGML_ALIGN(16) {
    float weight_scale;
    float input_scale;
    const float * weight_scales;
    const float * input_scales;
#if !defined(__cplusplus)
    tile_mxfp6 tiles[];
#endif
} tensor_mxfp6;
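
The sizes implied by these defines can be checked at compile time. The sketch below is only an illustration: it uses C11 `_Alignas(16)` as a stand-in for `GGML_ALIGN(16)`, and it assumes that one tile covers `MXFP6_TILE_ROWS` x `QK_MXFP6` weights, which is inferred from the define names rather than stated in the PR. The helper `mxfp6_tile_bits_per_weight_x100` is hypothetical.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#define QK_MXFP6            64
#define MXFP6_TILE_ROWS     16
#define MXFP6_TILE_FRAGS    2
#define MXFP6_TILE_LANES    32
#define MXFP6_TILE_PAYLOADS 3
#define MXFP6_TILE_BYTES    832
#define MXFP6_ROW_BYTES     (MXFP6_TILE_BYTES / MXFP6_TILE_ROWS)

typedef struct {
    _Alignas(16) uint32_t lane[MXFP6_TILE_LANES][MXFP6_TILE_PAYLOADS]; // 32 * 3 * 4 = 384 bytes
    uint8_t scale[MXFP6_TILE_LANES];                                   // 32 bytes
} tile_mxfp6_frag;

typedef struct {
    tile_mxfp6_frag frag[MXFP6_TILE_FRAGS];
} tile_mxfp6;

// One fragment is 384 + 32 = 416 bytes; two fragments make the 832-byte tile.
static_assert(sizeof(tile_mxfp6_frag) == 416, "fragment should be 416 bytes");
static_assert(sizeof(tile_mxfp6) == MXFP6_TILE_BYTES, "tile should be 832 bytes");
static_assert(MXFP6_ROW_BYTES == 52, "832 / 16 rows = 52 bytes per row");

// Storage cost in bits per weight, times 100 to stay in integer arithmetic.
// ASSUMPTION: a tile holds MXFP6_TILE_ROWS x QK_MXFP6 = 1024 weights.
static int mxfp6_tile_bits_per_weight_x100(void) {
    const int weights = MXFP6_TILE_ROWS * QK_MXFP6;
    return MXFP6_TILE_BYTES * 8 * 100 / weights;
}
```

Under that assumption the tile form costs 6.5 bits per weight, compared with 6.25 for the earlier 25-byte block of 32 weights.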

Passed Test Results

Details

Test-backend-ops:

./test-backend-ops -p "mxfp6" -b CPU
... previous tests ...
ch_dims=[1,1]): OK
MUL_MAT_VEC_FUSION(type=mxfp6_e2m3,glu_op=1,m=1,n=32,k=256,use_id=1,n_mats=16,n_used=8,b=1,with_bias=1,with_gate=1,batch_dims=[4,2]): OK
MUL_MAT_VEC_FUSION(type=mxfp6_e2m3,glu_op=1,m=1,n=32,k=256,use_id=1,n_mats=16,n_used=8,b=1,with_bias=1,with_gate=1,batch_dims=[1,1]): OK
192/192 tests passed
Backend CPU: OK
2/2 backends passed
OK

CI Results:

Ran both ci-cpu and ci-cuda to check the model prior to posting.
100% tests passed, 0 tests failed out of 44
....etc...
real 0m35.022s
user 0m42.320s
sys 0m8.429s
Label Time Summary:
main = 23.61 sec*proc (42 tests)
..... PASS

AI Usage Disclosure: Yes. AI assistance helped create, optimize, and analyze data collection over several weeks; this would likely have been impossible to do as one person (including the full Blackwell implementation). Each section of code has been refactored, carefully scrutinized, and edited by hand.

@github-actions github-actions bot added the testing, python, and ggml labels May 4, 2026
@jeffbolznv

Contributor

I'd prefer a layout with a larger block size that doesn't require repacking.

@michaelw9999

This comment was marked as outdated.

@michaelw9999

Contributor Author

Hi @jeffbolznv

Thanks for the suggestion. I went back at it and now have a slightly faster implementation that does not use any repack. It still keeps my preloaded tile idea, which is the fastest, and the tiles can go right into the GGUF without issue. 50+ variations of CUDA trials running all day found that the best tile was simply double the layout of the previous tile (so now 832B). I've rewritten the CPU path to use that directly too. On CUDA, with the bigger block and some new tweaks, Qwen3.5-4B gained about 5% on both pp512 and tg128, and the model loads ~1.5s faster without the repack: about 18,200 pp512 and 191 tg128. Qwen3.6-35B-A3B is about 7500/170.
I'll post the updated commit and upload the new GGUF to HF after cleanup and build checks, and after verifying that the ppl/kld runs all come back correct.

@Djip007

Djip007 commented May 5, 2026

Contributor

I'd prefer a layout with a larger block size that doesn't require repacking.

I don't think you can have high perf without repacking. And the block size is really tied to the hardware and tensor sizes.
So for me: keep the gguf as simple as possible, and let the backend do its optimised repacking.

@michaelw9999

Contributor Author

@Djip007 interestingly enough, the newer "tile-in-gguf" layout is also faster on CPU:
16.5/pp512 and 6.0/tg128. The original was about half that. CPU speed was not really the primary focus, however.

@Djip007

Djip007 commented May 5, 2026

Contributor

16.5/pp512 and 6.0/tg128

I don't know what CPU you have, or whether this is for Qwen3.5-4b, but to me that looks really slow.

_mm512_dpwssd_epi32 can use dot2 on signed int16, so the base block is A[M/M0][K/K0][M0=16][K0=2]
others use dot4 on int8 (mixed signed/unsigned, so more complicated...), where the block is A[M/M0][K/K0][M0=16][K0=4]

These are for AVX512...
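
The blocked notation above can be read as a four-dimensional array stored row-major, so that M0 rows times K0 adjacent k values sit contiguously (with M0=16 and K0=2 that is 32 int16 values, one 512-bit register). A small C sketch of that index mapping follows; `blocked_index` and `repack_blocked_i16` are hypothetical names, and the sketch assumes M is a multiple of M0 and K a multiple of K0.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

// Offset of logical element (m, k) of an M x K matrix stored as
// A[M/M0][K/K0][M0][K0], row-major over those four dimensions.
static size_t blocked_index(size_t m, size_t k, size_t K, size_t M0, size_t K0) {
    assert(K % K0 == 0);
    const size_t mb = m / M0, mi = m % M0; // row block, row inside the block
    const size_t kb = k / K0, ki = k % K0; // k block, k inside the block
    return ((mb * (K / K0) + kb) * M0 + mi) * K0 + ki;
}

// Repack a row-major int16 matrix into the blocked layout.
static void repack_blocked_i16(const int16_t * src, int16_t * dst,
                               size_t M, size_t K, size_t M0, size_t K0) {
    assert(M % M0 == 0 && K % K0 == 0);
    for (size_t m = 0; m < M; m++) {
        for (size_t k = 0; k < K; k++) {
            dst[blocked_index(m, k, K, M0, K0)] = src[m * K + k];
        }
    }
}
```

For the int8 dot4 case the same function applies with K0=4.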

@michaelw9999

Contributor Author

Pushed a new refactored version. No more block layout and no more repack into tiles.
Now it is storing the fast tile layout directly in the gguf, and the CPU kernel consumes the tile version directly (it's faster, too). Re-quantized the model again and have better results, too:

Mean PPL(Q)                   :   9.998305 ±   0.071790
Mean PPL(base)                :   9.943929 ±   0.071459
Cor(ln(PPL(Q)), ln(PPL(base))):  99.73%
Mean ln(PPL(Q)/PPL(base))     :   0.005453 ±   0.000528
Mean PPL(Q)/PPL(base)         :   1.005468 ±   0.000531
Mean PPL(Q)-PPL(base)         :   0.054377 ±   0.005275

====== KL divergence statistics ======
Mean    KLD:   0.015374 ±   0.000300
Maximum KLD:  19.809378
Median  KLD:   0.008529

RMS Δp    :  3.437 ± 0.044 %
Same top p: 93.487 ± 0.064 %

The weight and input scale are included in the tile header as well, so there's no more code plumbing needed to load up separate tensors or deal with any derived tensor pathway. They are just easily factored in directly in the vecdot.

@Djip007

Djip007 commented May 7, 2026

Contributor

nice!

did you manage to store the

    float         weight_scale;
    float         input_scale;
    const float * weight_scales;
    const float * input_scales;

in the gguf file? or have an idea for that?
at the beginning/end of the tensor data?

(I think of what we can do for FP8 ... )

@michaelw9999

Contributor Author

nice!

did you manage to store the

    float         weight_scale;
    float         input_scale;
    const float * weight_scales;
    const float * input_scales;

in the gguf file? or have an idea for that? at the beginning/end of the tensor data?

(I think of what we can do for FP8 ... )

Yes.
This is stored in the gguf file and is part of the outer tile. It is faster this way. The physical GGUF layout is a derived CUDA tiled serialization of the MXFP6 data.
The quantizer writes those header fields directly into the tensor payload, so I stopped it from emitting separate MXFP6 .scale or .input_scale tensors, which would just require more code.
Yes, this goes for MXFP8 too. I already have MXFP8 and FP8 working; I've been writing the code in parallel, but I don't want to flood this PR, so we do one thing at a time. MXFP6 activations are using FP8, with just a very small amount of code to bypass Q8 and get to f4f6f8, which is faster than MXFP6 x Q8. But let's do one thing at a time :)

@Djip007

Djip007 commented May 7, 2026

Contributor

Yes, this goes for MXFP8 too

I was thinking of native E4M3 models, but this can be good too. (I had some CPU/RDNA3 kernels for that, but did not finish (or start) the gguf file. I "only" had some on-load conversion from a BF16 model to packed E4M3 / E3M4, with different scaling schemes.)
The only thing I did not find (or understand) is how the gguf data is constructed/read in this PR.

But let's do one thing at a time :)

Yes! 👍

@michaelw9999

Contributor Author

The only thing I did not find (or understand) is how the gguf data is constructed/read in this PR.

Hope this (generated image) helps..
image

@Djip007

Djip007 commented May 7, 2026

Contributor

For the weight_scales / input_scales:
Are you thinking of storing them at the end of the tiles?

@michaelw9999

Contributor Author

No, they are actually at the front/top; the header tensor_mxfp6 contains tiles[]:

struct tensor_mxfp6 {
    float         weight_scale;
    float         input_scale;
    const float * weight_scales; // NULL pointer in file
    const float * input_scales;  // NULL pointer in file
    tile_mxfp6    tiles[];
};

tile_mxfp6 is the container type for these tiles[], underneath tensor_mxfp6:

tensor_mxfp6
   (tensor weight/input scales here)
   tile_mxfp6 tiles[]
      tile_mxfp6[0]
         frag[0]: tile_mxfp6_frag is the container type
         frag[1]
      tile_mxfp6[1]
      tile_mxfp6[2]   

One tile_mxfp6_frag:

   tile_mxfp6_frag:
      lane[32][3]  32 lanes,  each has 3 mxfp6 payloads (weights)
      scale[32]   32 MX scales (the subblock scales, E8M0)

@michaelw9999

michaelw9999 commented May 7, 2026

Contributor Author

To clarify just a bit more about why this is here:

    const float * weight_scales; // NULL pointer in file
    const float * input_scales;  // NULL pointer in file

These are placeholders in the GGUF and are not actively used in the CPU version, so they appear to do nothing for now. They get filled in only on the CUDA version. MUL_MAT_ID selects the channel_x expert scale from inside the kernel, so it does weight_scales[channel_x] in the vecdot instead of launching a separate get_rows + mul. That both increases speed and preserves accuracy, since the weight scale is applied without losing any precision (e.g., as a derived tensor). The CPU side is fine without patching because it keeps the existing separate scale nodes in the graph. On CUDA, the graph .scale node is nullified so the scales are applied directly rather than late:

  1. first attach expert scale tensor into weight->src[0]
  2. set to null the graph .scale node
  3. matmul_id gets correct expert channel id
  4. vecdot does weight_scales[channel_x] inside the kernel
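
The steps above can be sketched in plain C. This is only an illustration of the selection logic, not the CUDA kernel: `mxfp6_header_sketch`, `select_weight_scale` and `scaled_dot` are hypothetical names, float weights stand in for dequantized MXFP6 values, and the sketch assumes the per-expert pointer is non-NULL only once it has been patched in.

```c
#include <assert.h>
#include <stddef.h>

typedef struct {
    float         weight_scale;  // per-tensor scale stored in the file
    float         input_scale;
    const float * weight_scales; // per-expert scales; NULL in the file
    const float * input_scales;
} mxfp6_header_sketch;

// Pick the per-expert scale when it has been attached,
// otherwise fall back to the per-tensor scale.
static float select_weight_scale(const mxfp6_header_sketch * h, int channel_x) {
    assert(channel_x >= 0);
    return h->weight_scales != NULL ? h->weight_scales[channel_x] : h->weight_scale;
}

// Toy dot product that applies the scale inside the same routine,
// instead of a separate get_rows + mul pass over the result.
static float scaled_dot(const mxfp6_header_sketch * h, int channel_x,
                        const float * w, const float * x, int n) {
    float acc = 0.0f;
    for (int i = 0; i < n; i++) {
        acc += w[i] * x[i];
    }
    return acc * select_weight_scale(h, channel_x);
}
```

Folding the scale into the dot product avoids a second pass over the output and a separate kernel launch, which is the benefit described above.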

@Djip007

Djip007 commented May 7, 2026

Contributor

keeps the existing separate scale nodes in the graph

Oh. That is what I do not like (and gg used not to either), because it means the model graph depends on the quantized type used (i.e. all models have to be updated).
So I'll look at whether we can add it at the end of the data(?), with for example weight_scales / input_scales stored as a delta position from data/tiles.

but it is more for #22042 or some other Discussions

One point: be careful with this header on an mmap'd file... I don't know if you can write to it (for CPU tensors or similar).

@michaelw9999

Contributor Author

keeps the existing separate scale nodes in the graph

Oh. That is what I do not like (and gg used not to either), because it means the model graph depends on the quantized type used (i.e. all models have to be updated). So I'll look at whether we can add it at the end of the data(?), with for example weight_scales / input_scales stored as a delta position from data/tiles.

but it is more for #22042 or some other Discussions

One point: be careful with this header on an mmap'd file... I don't know if you can write to it (for CPU tensors or similar).

We can certainly make the CPU side do the same as CUDA; it already handles the mmap issue and does much of the math (and, as you suggest, we could perhaps improve this further by coming up with the best GGUF layout). Perhaps I need to clean up my CUDA MXFP6 side and post that code to a fork for reference, but this is how I handle it there:

static bool ggml_cuda_set_tensor_mxfp6(ggml_tensor * tensor, const void * data, size_t offset, size_t size, int device) {

    ggml_tensor * storage = tensor->view_src ? tensor->view_src : tensor;
    const size_t logical_size = ggml_nbytes(storage);
    const size_t packed_size = ggml_cuda_mxfp6_e2m3_tensor_alloc_size(storage);
    GGML_ASSERT(logical_size == packed_size);
    const size_t storage_offset = tensor->view_src ? tensor->view_offs : 0;

    CUDA_CHECK(cudaMemcpyAsync((char *) storage->data + storage_offset + offset, data, size, cudaMemcpyHostToDevice, cudaStreamPerThread));
    if (storage_offset == 0 && offset == 0 && size >= MXFP6_HEADER_OFFSET) {
        tile_mxfp6 header;
        memcpy(&header, data, MXFP6_HEADER_OFFSET);
        ggml_cuda_mxfp6_e2m3_patch_tensor_header(storage, &header);
        CUDA_CHECK(cudaMemcpyAsync(storage->data, &header, MXFP6_HEADER_OFFSET, cudaMemcpyHostToDevice, cudaStreamPerThread));
    }
    CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
    return true;
}

static inline void ggml_cuda_mxfp6_e2m3_patch_tensor_header(const ggml_tensor * tensor, tile_mxfp6_e2m3_blackwell_tensor * dst) {
    const ggml_tensor * weight_scale_t = tensor->src[0];
    const ggml_tensor * input_scale_t  = tensor->src[1];

    dst->weight_scales = weight_scale_t != nullptr && !ggml_is_scalar(weight_scale_t) &&
        weight_scale_t->data != nullptr && weight_scale_t->buffer != nullptr &&
        !ggml_backend_buffer_is_host(weight_scale_t->buffer) ? (const float *) weight_scale_t->data : nullptr;
    dst->input_scales = input_scale_t != nullptr && !ggml_is_scalar(input_scale_t) &&
        input_scale_t->data != nullptr && input_scale_t->buffer != nullptr &&
        !ggml_backend_buffer_is_host(input_scale_t->buffer) ? (const float *) input_scale_t->data : nullptr;
}

auto get_mapped_scalar_f32 = [&ml](ggml_tensor * tensor, float & value) {
    if (tensor == nullptr || !ggml_is_scalar(tensor) || tensor->type != GGML_TYPE_F32) {
        return false;
    }
    const auto * weight = ml.get_weight(ggml_get_name(tensor));
    if (weight == nullptr || weight->idx >= ml.mappings.size()) {
        return false;
    }
    const auto & mapping = ml.mappings.at(weight->idx);
    memcpy(&value, (const uint8_t *) mapping->addr() + weight->offs, sizeof(value));
    return true;
};

}

@Djip007

Djip007 commented May 7, 2026

Contributor
    const ggml_tensor * weight_scale_t = tensor->src[0];
    const ggml_tensor * input_scale_t  = tensor->src[1];

Is the scale defined as a source of the weight tensor?
That may be the most interesting part for the ggml runtime/backend!
This would solve many questions (like the type, sizes ...)

@michaelw9999

michaelw9999 commented May 7, 2026

Contributor Author
    const ggml_tensor * weight_scale_t = tensor->src[0];
    const ggml_tensor * input_scale_t  = tensor->src[1];

Is the scale defined as a source of the weight tensor? That may be the most interesting part for the ggml runtime/backend! This would solve many questions (like the type, sizes ...)

Yes!

weight->src[0] = weight_scale;
weight->src[1] = input_scale;

and

const ggml_tensor * src0 = dst->src[0];        // weight tensor
const ggml_tensor * scale = src0->src[0];      // weight scale

This does not affect any of the other models, so nothing else needs to change (apart from possible improvements for future types).

    auto attach_native_blackwell_scales = [](ggml_tensor * weight, ggml_tensor *& weight_scale, ggml_tensor * input_scale, bool use_cuda_native_scales) {
        if (weight == nullptr || (weight->type != GGML_TYPE_NVFP4 && weight->type != GGML_TYPE_MXFP6_E2M3)) {
            return;
        }

        // Files produced by llama-quantize store plain native blocks without an
        // auxiliary per-tensor scale. Treat that as an identity.
        if (weight_scale == nullptr) {
            const float identity_scale = 1.0f;
            memcpy(&weight->op_params[0], &identity_scale, sizeof(identity_scale));
        }

        weight->src[0] = weight_scale; // matmul reads attached scales before CUDA hides graph *_s
        weight->src[1] = input_scale;
        if (use_cuda_native_scales && weight_scale != nullptr) {
            weight_scale = nullptr; // CUDA applies attached scales inside matmul, so graph skips late ggml_mul
        }
    };

Correction: this was when using repack. With the new layout, the newer CUDA version no longer needs op_params as a placeholder.

@Djip007

Djip007 commented May 7, 2026

Contributor

Yes!

weight->src[0] = weight_scale;
weight->src[1] = input_scale;

and

const ggml_tensor * src0 = dst->src[0];        // weight tensor
const ggml_tensor * scale = src0->src[0];      // weight scale

That's what I missed!
(https://github.com/ggml-org/llama.cpp/blob/master/src/llama-model.cpp#L1279)

@michaelw9999 michaelw9999 force-pushed the mxfp6-cpu branch 5 times, most recently from 62dd50c to 86d1b6e Compare May 16, 2026 21:46
@michaelw9999

michaelw9999 commented May 16, 2026

Contributor Author

Uploaded new MXFP6 models onto HF:
Qwen3.6-35B-A3B-MXFP6
Qwen3.6-27B-MXFP6

To demonstrate how the no-repack tile layout stored in the GGUF can work without big GGML changes, while keeping the scales in the same tile as the CPU version, I posted the CUDA implementation here (not for PR and not clean enough, but as a POC):
mxfp6-cuda

This converts q8_1 tiles to FP8 for MMVQ, so it does a native Blackwell MXFP6 x FP8 MMA. This makes MXFP6 almost 20% faster than NVFP4 for tg on this model while keeping the extra quality.

  Device 0: NVIDIA GeForce RTX 5090, compute capability 12.0, VMM: yes, VRAM: 32606 MiB
| model                          |       size |     params | backend    | ngl |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | --------------: | -------------------: |
| qwen35moe 35B.A3B MXFP6 - E2M3 |  26.46 GiB |    34.66 B | CUDA       |  99 |           pp512 |      8094.43 ± 49.53 |
| qwen35moe 35B.A3B MXFP6 - E2M3 |  26.46 GiB |    34.66 B | CUDA       |  99 |           tg128 |        188.10 ± 3.20 |

  Device 0: NVIDIA GeForce RTX 5090, compute capability 12.0, VMM: yes, VRAM: 32606 MiB
| model                          |       size |     params | backend    | ngl |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | --------------: | -------------------: |
| qwen35moe 35B.A3B NVFP4        |  21.48 GiB |    34.66 B | CUDA       |  99 |           pp512 |      8220.18 ± 57.89 |
| qwen35moe 35B.A3B NVFP4        |  21.48 GiB |    34.66 B | CUDA       |  99 |           tg128 |        159.53 ± 0.82 |

Hoping for any feedback on making this even better!

The ppl/kld delta on Qwen3.6-27b vs NVFP4 is really striking! MXFP6 vs BF16:

Mean PPL(Q)                   :   6.918624 ±   0.045510
Mean PPL(base)                :   6.900856 ±   0.045374
Cor(ln(PPL(Q)), ln(PPL(base))):  99.52%
Mean ln(PPL(Q)/PPL(base))     :   0.002571 ±   0.000645
Mean PPL(Q)/PPL(base)         :   1.002575 ±   0.000646
Mean PPL(Q)-PPL(base)         :   0.017768 ±   0.00445
Mean    KLD:   0.018925 ±   0.000647
RMS Δp    :  3.662 ± 0.060 %
Same top p: 95.023 ± 0.057 %

and for NVFP4:

====== Perplexity statistics ======
Mean PPL(Q)                   :   7.321749 ±   0.049291
Mean PPL(base)                :   6.900856 ±   0.045374
Cor(ln(PPL(Q)), ln(PPL(base))):  98.00%
Mean ln(PPL(Q)/PPL(base))     :   0.059204 ±   0.001341
Mean PPL(Q)/PPL(base)         :   1.060991 ±   0.001423
Mean PPL(Q)-PPL(base)         :   0.420893 ±   0.010245
Mean    KLD:   0.079692 ±   0.001052
RMS Δp    :  7.764 ± 0.063 %
Same top p: 89.020 ± 0.081 %

@michaelw9999 michaelw9999 force-pushed the mxfp6-cpu branch 2 times, most recently from b6f701a to 5789082 Compare May 25, 2026 22:55
@kominsoo


Thanks for the excellent MXFP6 work. I was impressed by this PR and the related discussion, so I tried an experimental ROCm implementation
for RDNA4 with an AMD Radeon AI PRO R9700 32GB GPU, with help from Codex.

This is not PR-ready yet, but I wanted to share some early numbers in case they are useful.

Environment:

  • GPU: AMD Radeon AI PRO R9700, gfx1201, 32GB
  • Backend: ROCm / HIP
  • Model: Qwen3.6 27B GGUF variants
  • llama-bench: -ngl 99 -r 5
  • KV: f16/f16

Speed results:

| quant | size | pp128 | pp512 | pp1024 | tg128 |
| :--- | ---: | ---: | ---: | ---: | ---: |
| MXFP6_E2M3, compact/original layout | ~21GB | ~618 tok/s | ~830 tok/s | ~566 tok/s | ~24.67 tok/s |
| MXFP6_E2M3, RDNA4 repacked F26 layout | ~21GB | ~680 tok/s | ~910 tok/s | ~589 tok/s | ~23.72 tok/s |
| Q6_K | ~21GB | ~501 tok/s | ~664 tok/s | ~484 tok/s | ~22.81 tok/s |
| UD-Q5_K_XL | ~19GB | ~492 tok/s | ~773 tok/s | ~509 tok/s | ~24.15 tok/s |
| Q4_K_M | ~16GB | ~693 tok/s | ~959 tok/s | ~621 tok/s | ~26.39 tok/s |

Short WikiText-2 perplexity check, ctx 512, chunks 64:

| quant | PPL |
| :--- | ---: |
| MXFP6_E2M3 | 6.5708 +/- 0.12733 |
| UD-Q5_K_XL | 6.5689 +/- 0.12722 |

A few observations from the RDNA4 side:

  • MXFP6 is already faster than Q6_K for both PP and TG on this setup.
  • PP/GEMM benefits from an RDNA4-friendly repacked layout.
  • TG/matvec did not improve with the RDNA4 repack; the original compact layout was still faster for TG.
  • Against UD-Q5_K_XL, MXFP6 has similar short PPL and slightly better TG/PP, but the advantage is not large yet.
  • Q4_K_M remains very strong for TG and size.

My current impression is that a hardware-neutral GGUF MXFP6_E2M3 layout plus backend-specific load-time repacking is probably the right
direction. For RDNA4, the main open problem seems to be TG/matvec optimization rather than PP/GEMM.

I will keep experimenting before publishing code, but I wanted to share these early ROCm/RDNA4 data points.

@michaelw9999

Contributor Author

Thanks @kominsoo, I appreciate the kind words. A few questions:

  • Did you include the weight and input tensor scales? They aren't part of the MXFP6 standard, but I still found them useful, with no real effect on speed.
  • Did you take a look at the CUDA version here?
  • And how did you implement TG? My MXFP6 trick for fast tg was to use FP8, but with just a fast conversion from Q8 rather than putting in all the plumbing. The other trick that gave a lot of speedup was keeping everything 6-bit until the last second and inserting "just in time" 0s to get 8-bit values when sending data off to the GPU (required by CUDA). Carrying the 0s around any earlier slowed things down. I'm not sure how this works on ROCm, though.
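
As a rough illustration of the "keep it 6-bit, pad late" idea, the sketch below packs four 6-bit codes into three bytes and expands them back to four zero-padded bytes only when they are needed. The little-endian bit order and the names `pack4x6` and `unpack4x6` are assumptions for this example. The actual tile layout in the CUDA branch, and where the zero bits must sit for FP8, may differ.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Pack four 6-bit codes into 3 bytes (assumed little-endian bit order).
inline std::array<uint8_t, 3> pack4x6(const std::array<uint8_t, 4> & c) {
    for (uint8_t v : c) {
        assert(v <= 0x3F);  // each code must fit in 6 bits
    }
    const uint32_t bits = uint32_t(c[0])
                        | (uint32_t(c[1]) << 6)
                        | (uint32_t(c[2]) << 12)
                        | (uint32_t(c[3]) << 18);
    return {{ uint8_t(bits & 0xFF), uint8_t((bits >> 8) & 0xFF), uint8_t((bits >> 16) & 0xFF) }};
}

// "Just in time" expansion: 3 packed bytes become 4 bytes whose top two bits are zero.
inline std::array<uint8_t, 4> unpack4x6(const std::array<uint8_t, 3> & p) {
    const uint32_t bits = uint32_t(p[0]) | (uint32_t(p[1]) << 8) | (uint32_t(p[2]) << 16);
    return {{ uint8_t(bits & 0x3F),
              uint8_t((bits >> 6) & 0x3F),
              uint8_t((bits >> 12) & 0x3F),
              uint8_t((bits >> 18) & 0x3F) }};
}
```

Storing and moving the 3-byte form saves a quarter of the memory traffic compared with carrying the padded bytes around, which is the effect described in the last bullet.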

I am working on a post for the discussion forum about the new NVFP4/MXFP6 quantizer, which uses a new technique; you can check it out in a bit. It needs the tensor scale to work best. You can give it a try with advanced-gguf-quantizer. I think the best final use case will be combining NVFP4 and MXFP6 so they balance each other out as needed.

@kominsoo


Thanks, those are exactly the points I need to check next. I did use your mxfp6-cuda branch as reference, but I want to verify whether my
ROCm path is applying the tensor scales correctly before commenting further. I’ll test that and report back once I have cleaner numbers.

@kominsoo

kominsoo commented May 30, 2026


Following up now that I've verified actual generation output, not just bench numbers — and I owe a correction: the ~24.67 tok/s TG for the compact/original
layout in my earlier table was invalid due to an MMVQ addressing bug. After fixing it, the corrected and optimized comparative results are below.

Q: Tensor scales?
Yes. I preserve the 32-byte tensor_mxfp6 header (the weight_scale) and apply tensor_scale × per-block E8M0 inside both the vecdot and MMQ kernels. Dropping this header originally produced garbage output; once it was preserved, the Wikitext-2 PPL returned to normal levels.
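
For readers unfamiliar with the format, the product described above can be sketched per element as follows. This assumes the OCP MX encodings: E2M3 with 1 sign, 2 exponent and 3 mantissa bits and an exponent bias of 1, and an E8M0 scale equal to 2^(e - 127) with 0xFF reserved for NaN. The function names are illustrative. How the tensor scale is stored in the 32-byte header is not shown in this thread, so it is passed in here as a plain float.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Decode one FP6 E2M3 code held in the low 6 bits of a byte.
inline float e2m3_to_float(uint8_t code) {
    assert((code & 0xC0) == 0);  // only 6 bits are meaningful
    const int sign = (code >> 5) & 1;
    const int exp  = (code >> 3) & 3;
    const int man  = code & 7;
    const float mag = (exp == 0)
        ? man / 8.0f                               // subnormal: m/8
        : std::ldexp(1.0f + man / 8.0f, exp - 1);  // normal: 2^(exp-1) * (1 + m/8)
    return sign ? -mag : mag;
}

// Decode an E8M0 shared block scale: a pure power of two.
inline float e8m0_to_float(uint8_t e) {
    return e == 0xFF ? std::nanf("") : std::ldexp(1.0f, int(e) - 127);
}

// tensor_scale x per-block E8M0 scale x element value.
inline float dequant_mxfp6(uint8_t code, uint8_t block_scale, float tensor_scale) {
    return tensor_scale * e8m0_to_float(block_scale) * e2m3_to_float(code);
}
```

Since the tensor scale multiplies every element, skipping it rescales the whole tensor, which fits the garbage output seen when the header was dropped.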

Q: Reference to your CUDA branch?
Yes. Your mxfp6-cuda branch served as the primary reference for the tile layout and scale handling — thank you very much for posting it.

Q: How was TG optimized?
This is where RDNA4 differs from your CUDA results.
* PP/GEMM uses real FP8 WMMA (wmma_f32_16x16x16_fp8_fp8): we convert MXFP6 to FP8 E4M3 at tile-load time, quantize activations to FP8 E4M3, and execute on the RDNA4 FP8 matrix core. This resolves the PP regression and beats Q6_K by +21% to +29% across all prompt lengths.
* TG/matvec stays on dp4a (int8): routing batch-1 matvec through the FP8 WMMA unit is actually slower than dp4a on RDNA4 because of tile padding overhead. Instead, we optimized the dp4a path by hoisting the loop-invariant scales and chaining accumulations.
* As a result, the general MXFP6 GGUF is repacked internally at load time (auto-enabled when an RDNA4 device is present), which gives the best speed we measured for both PP and TG.
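
The dp4a-side optimization can be pictured with a CPU-side sketch. It is an illustration under assumptions, not the ROCm kernel: `dp4a` here is a plain C++ emulation of the instruction, the weights are assumed to be already expanded to int8, the block size is 32, and each operand has one float scale per block.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Emulates dp4a: four int8 products added to an int32 accumulator.
inline int32_t dp4a(const int8_t * a, const int8_t * b, int32_t acc) {
    for (int i = 0; i < 4; ++i) {
        acc += int32_t(a[i]) * int32_t(b[i]);
    }
    return acc;
}

// Dot product over n_blocks blocks of 32 int8 values each.
inline float block_dot(const int8_t * w, const float * w_scale,
                       const int8_t * x, const float * x_scale,
                       size_t n_blocks, float tensor_scale) {
    assert(w != nullptr && x != nullptr);
    float sum = 0.0f;
    for (size_t b = 0; b < n_blocks; ++b) {
        int32_t acc = 0;  // one accumulator chained through all eight dp4a calls
        for (int i = 0; i < 32; i += 4) {
            acc = dp4a(w + 32 * b + i, x + 32 * b + i, acc);
        }
        // The block scales are constant inside the inner loop, so apply them once per block.
        sum += w_scale[b] * x_scale[b] * float(acc);
    }
    // The tensor scale is constant across blocks, so apply it once at the end.
    return tensor_scale * sum;
}
```

The inner loop stays pure integer work, and each floating-point scale is applied as rarely as its scope allows.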

Corrected & Optimized Benchmark Numbers (Tested on AMD Radeon AI PRO R9700, gfx1201, ROCm 7.2, -ngl 99 -r 5)

| Quantization Format | PP128 | PP512 | PP1024 | TG128 | Wikitext-2 PPL (c=2048) |
| :--- | :---: | :---: | :---: | :---: | :---: |
| **MXFP6** *(Auto-repack + TG opt)* | **674.8** | **903.8** | **585.1** | **23.39** | **6.654** *(Lower is Better)* |
| **Q6_K** *(Baseline)* | 525.8 | 699.5 | 482.4 | 23.07 | 6.697 |

Note: TG serving speed practically doubles to **47.82 tok/s** (+104%) when utilising MTP self-speculation (`--spec-type draft-mtp --spec-draft-n-max 4`) on the `llama-server`.

@michaelw9999

Contributor Author

That's great! Thanks for this data. Just minutes before your reply, I was working on improving NVFP4 tg speeds on the repack version by trying FP8, as I had done with MXFP6. I also found that dp4a Q8 was much better for single-token generation than the tile path, the same as you are seeing here, but I got around it by making a row shadow for MMVQ. Maybe that could help on ROCm too:

CUDA MXFP6 current FP8 TG path:
pp512 17427.2
tg128   230.64

CUDA MXFP6 dp4a TG experiment:
pp512 17344.8
tg128   171.67

NVFP4 forced FP8 TG1: ~120 tok/s
NVFP4 old repack Q8 TG1: ~170 tok/s
NVFP4 row-shadow Q8 TG1: ~193.8 tok/s

@michaelw9999 michaelw9999 force-pushed the mxfp6-cpu branch 2 times, most recently from dbdd3e3 to d6c4c96 Compare June 5, 2026 06:09
@michaelw9999 michaelw9999 requested a review from a team as a code owner June 5, 2026 06:09
@github-actions github-actions Bot added the Nvidia GPU Issues specific to Nvidia GPUs label Jun 5, 2026