llama-quant : fail early on missing imatrix, refactor type selection, code cleanup#19770

Merged
ggerganov merged 13 commits into ggml-org:master from ddh0:llama-quant-refactor-2 on Mar 10, 2026

Conversation

ddh0 (Contributor) commented Feb 20, 2026

Currently, if a quantization requires an importance matrix and one isn't provided, the program doesn't discover this until it reaches the offending tensor during the main quantization loop. Depending on model size and target type, this can mean wasting minutes to hours before the process aborts with a partial GGUF.

This PR adds a preliminary metadata pass over all tensors that determines target types upfront, enabling early validation. The quantization logic in llama-quant.cpp is refactored for clarity and correctness.


Fail early for missing required imatrix

A preliminary pass now computes each tensor's target quantization type before the main loop begins. If an importance matrix is required but missing, quantization fails immediately with an error identifying the offending tensor and its target type:

(screenshot: the error message naming the offending tensor and its target type)
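The fail-early flow can be sketched as a two-pass structure. This is a toy model, not the actual llama.cpp code; all names here are illustrative:

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <vector>

// toy stand-ins for ggml types; only the distinction matters here
enum toy_type { TOY_Q8_0, TOY_IQ2_XS };

struct toy_tensor {
    std::string name;
    toy_type    target; // resolved by the preliminary metadata pass
};

// hypothetical predicate: does this target type need an importance matrix?
static bool toy_requires_imatrix(toy_type t) {
    return t == TOY_IQ2_XS;
}

// preliminary pass: validate every resolved target type before any
// quantization work is done, so a missing imatrix fails immediately
static void toy_validate(const std::vector<toy_tensor> & tensors, bool have_imatrix) {
    for (const auto & t : tensors) {
        if (toy_requires_imatrix(t.target) && !have_imatrix) {
            throw std::runtime_error("tensor '" + t.name + "' requires an imatrix");
        }
    }
}
```

With this structure the error surfaces before the main loop writes any output, instead of minutes or hours in.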

The old ftype-based imatrix guard in quantize.cpp is removed in favor of this per-tensor check. tensor_requires_imatrix (renamed from tensor_type_requires_imatrix) now uses a switch on dst_type and correctly exempts per_layer_token_embd.weight in addition to token_embd.weight.
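A minimal sketch of such a per-tensor check follows; the enum and the exempt list are illustrative stand-ins, not the real tensor_requires_imatrix:

```cpp
#include <cassert>
#include <string>

// illustrative subset of quantization types
enum sketch_type { SK_Q4_K, SK_IQ1_S, SK_IQ2_XXS };

static bool sketch_requires_imatrix(sketch_type dst_type, const std::string & name) {
    // embedding tensors are exempt even for imatrix-dependent target types
    if (name == "token_embd.weight" || name == "per_layer_token_embd.weight") {
        return false;
    }
    // switch on the destination type, as in the refactored check
    switch (dst_type) {
        case SK_IQ1_S:
        case SK_IQ2_XXS:
            return true; // very low-bit IQ types need an importance matrix
        default:
            return false;
    }
}
```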

Refactoring

Several functions were extracted to reduce the size of llama_model_quantize_impl, make the logic reusable across the preliminary and main passes, and improve clarity and maintainability:

  • tensor_allows_quantization: consolidates all "should we quantize this tensor?" checks (norm tensors, RWKV weights, conv1d, positional embeddings, etc.) previously inlined in the main loop
  • tensor_category + tensor_get_category: replaces repeated name.find(...) calls with a single categorization pass; used by type selection and the attention-v check (category_is_attn_v)
  • llama_tensor_get_type / llama_tensor_get_type_impl: splits type resolution into a wrapper (manual overrides, token_embedding/output overrides, fallbacks) and the core mixture/architecture logic
  • tensor_type_fallback: extracted from the inline incompatible-shape handling block; now also handles the rare case where the fallback type itself is incompatible (falls back to F16 with a warning)
  • llama_ftype_get_default_type: the ftype -> ggml_type switch, extracted and organized by category
  • tensor_name_match_token_embd / tensor_name_match_output_weight: small helpers used across multiple call sites
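The wrapper/impl split in the list above can be illustrated roughly like this (toy names and heuristics, not the real selection logic):

```cpp
#include <cassert>
#include <string>

enum pick_type_t { PK_Q4_K, PK_Q6_K };

// core heuristic: the mixture/architecture logic lives in the impl
static pick_type_t pick_type_impl(const std::string & name) {
    // e.g. give attention-v tensors more bits than the base type
    return name.find("attn_v") != std::string::npos ? PK_Q6_K : PK_Q4_K;
}

// wrapper: overrides (manual, token-embedding/output) are applied around
// the core logic, keeping each concern in one place
static pick_type_t pick_type(const std::string & name, bool embd_override) {
    if (embd_override && name == "token_embd.weight") {
        return PK_Q6_K; // override takes precedence over the impl's choice
    }
    return pick_type_impl(name);
}
```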

Other changes

  • Regex patterns for --tensor-type compiled once in the quantize_state_impl constructor instead of compiled per-tensor, per-flag
  • tensor_quantization struct moved from llama-quant.cpp to the header (shared with quantize.cpp via #include instead of duplicated)
  • tensor_category enum and tensor_metadata struct added to the header
  • has_output replaced with has_tied_embeddings (clearer semantics, inverted logic, same behavior)
  • n_k_quantized counter removed (fallback count now reported against ml.n_tensors)
  • Removed dead MXFP4 sanity check code (#if 0 block)
  • Some logging cleanup

@ddh0 ddh0 changed the title quantize : refactor llama-quant.cpp quantize : refactor llama-quant.cpp (imatrix fail-early) Feb 21, 2026

@ddh0 ddh0 changed the title quantize : refactor llama-quant.cpp (imatrix fail-early) quantize : fail-early on missing imatrix; refactor + optimize Feb 26, 2026
@ddh0 ddh0 marked this pull request as ready for review February 26, 2026 05:57
@ddh0 ddh0 requested a review from ggerganov as a code owner February 26, 2026 05:57
ddh0 (Contributor, Author) commented Feb 26, 2026

I am lucky to have had helpful feedback from @ubergarm, @AesSedai, @aldehir, @pwilkin, @compilade, and @bartowski - they all deserve a mention here. Thanks :) 🦙


ddh0 commented Feb 28, 2026

Marking back as draft until things have settled down a bit.

@ddh0 ddh0 marked this pull request as draft February 28, 2026 05:14
@ddh0 ddh0 changed the title quantize : fail-early on missing imatrix; refactor + optimize quantize : refactor; use quantization work scheduler for faster, more efficient quantization Mar 3, 2026
@ddh0 ddh0 force-pushed the llama-quant-refactor-2 branch from e314fa3 to decff8b on March 4, 2026 00:22
@ddh0 ddh0 changed the title quantize : refactor; use quantization work scheduler for faster, more efficient quantization quantize : imatrix fail-early, begin code cleanup Mar 4, 2026
@ddh0 ddh0 changed the title quantize : imatrix fail-early, begin code cleanup llama-quant : fail early on missing imatrix, refactor type selection, code cleanup Mar 4, 2026
@ddh0 ddh0 marked this pull request as ready for review March 4, 2026 01:37
ddh0 commented Mar 4, 2026

cc @ggerganov, @CISC - sorry for the messy thread, and for the re-marking as draft; there were previously a lot more changes here, but I'm going to leave some of my more aspirational improvements for future PRs.

I would like to get your thoughts on the changes here, as to whether or not they are acceptable overall.

@ddh0 ddh0 requested a review from compilade March 4, 2026 01:42
ggerganov (Member) left a comment

Overall the refactoring goals are good. Double-check that there aren't any functional changes by comparing a few quantizations before/after and we can merge.

#include "common.h"
#include "llama.h"
#include "gguf.h"
#include "../src/llama-quant.h"
ggerganov (Member):
Is this needed?

ddh0 (Author) commented Mar 4, 2026:

Yes, though right now it's just for the struct tensor_quantization which was previously duplicated between quantize.cpp and llama-quant.cpp.

(edit: replying to ggerganov)

ggerganov (Member):

Including internal sources from libllama is not OK, so better to keep the duplicated structs.

Generally, passing C++ structs across the C-style API is also not OK. This API has to be refactored at some point (#12511 (comment)).

ddh0 (Author):

I understand, thanks. Will revert this change.

ddh0 (Author):

Reverted in 4b7ebed.

ddh0 commented Mar 4, 2026

Some test results. Please let me know if there are more specific models / architectures that need to be checked.


Qwen3.5-35B-A3B bf16 -> Q8_0

Command:

rm output.gguf; time llama-quantize ~/gguf/Qwen3.5-35B-A3B-bf16.gguf output.gguf Q8_0 16

on master @ 24d2ee0

Result:

llama_model_quantize_impl: model size  = 66152.24 MiB (16.01 BPW)
llama_model_quantize_impl: quant size  = 35183.10 MiB (8.52 BPW)

main: quantize time = 144871.85 ms
main:    total time = 144871.85 ms

real    2m24.981s
user    3m16.528s
sys     0m56.611s

on llama-quant-refactor-2 @ 49fec40

Result:

llama_model_quantize_impl: model size  = 66152.24 MiB (16.01 BPW)
llama_model_quantize_impl: quant size  = 35183.10 MiB (8.52 BPW)

main: quantize time = 49671.08 ms
main:    total time = 49671.08 ms

real    0m49.713s
user    3m14.695s
sys     0m41.895s

Qwen3.5-35B-A3B bf16 -> Q4_K_M

Command:

rm output.gguf; time llama-quantize --imatrix ~/imatrices/Qwen_Qwen3.5-35B-A3B-imatrix.gguf ~/gguf/Qwen3.5-35B-A3B-bf16.gguf output.gguf Q4_K_M 16

on master @ 24d2ee0

Result:

llama_model_quantize_impl: model size  = 66152.24 MiB (16.01 BPW)
llama_model_quantize_impl: quant size  = 20177.96 MiB (4.88 BPW)

main: quantize time = 315551.57 ms
main:    total time = 315551.57 ms

real    5m15.832s
user    73m57.912s
sys     0m33.759s

on llama-quant-refactor-2 @ 49fec40

Result:

llama_model_quantize_impl: model size  = 66152.24 MiB (16.01 BPW)
llama_model_quantize_impl: quant size  = 20242.74 MiB (4.90 BPW)

main: quantize time = 311621.97 ms
main:    total time = 311621.97 ms

real    5m11.840s
user    73m12.041s
sys     0m33.219s

NOTE: This quant is 0.02 BPW larger than the same quant on master - this is due to the new category_is_attn_v checks.


L3.3 70B: bf16 -> Q8_0

Command:

rm output.gguf; time llama-quantize /media/T9/gguf/Llama-3.3-70B-Instruct-bf16.gguf output.gguf Q8_0 16

on master @ 24d2ee0

Result:

llama_model_quantize_impl: model size  = 134573.03 MiB (16.00 BPW)
llama_model_quantize_impl: quant size  = 71494.28 MiB (8.50 BPW)

main: quantize time = 384847.69 ms
main:    total time = 384847.69 ms

real    6m25.034s
user    7m30.412s
sys     1m32.364s

on llama-quant-refactor-2 @ 49fec40

Result:

llama_model_quantize_impl: model size  = 134573.03 MiB (16.00 BPW)
llama_model_quantize_impl: quant size  = 71494.28 MiB (8.50 BPW)

main: quantize time = 414279.31 ms
main:    total time = 414279.32 ms

real    6m54.481s
user    7m30.642s
sys     1m33.766s

ddh0 commented Mar 7, 2026

I am not sure how to interpret this failing CI: https://github.com/ggml-org/llama.cpp/actions/runs/22792516852/job/66121806962#step:7:1

CISC (Collaborator) commented Mar 7, 2026

I am not sure how to interpret this failing CI: https://github.com/ggml-org/llama.cpp/actions/runs/22792516852/job/66121806962#step:7:1

ILLEGAL means the ccache corrupted itself and is generating trash executables, not related to your PR.

ddh0 commented Mar 8, 2026

@ggerganov Any chance this can be merged?

ddh0 added 2 commits March 8, 2026 01:18
it's in the preliminary loop now, so needs to be on its own line
@pwilkin pwilkin requested a review from CISC March 8, 2026 12:39
ggerganov (Member):

Should we wait for #20112 to be sure about the correctness, or do you think it's fine to merge? @bartowski1182 WDYT?

bartowski1182 (Contributor):

based on the changes I'm pretty confident that it won't affect anything, and my own PR may need some iterations if I get some review comments

for completeness I'll try to rebase my PR off of this fork (just locally) and run the tests myself to confirm, shouldn't take long

bartowski1182 (Contributor):

There are some changes in this versus master, but one is probably worth keeping

Notably, attn_qkv now gets caught by category_is_attn_v, so here:

https://github.com/ddh0/llama.cpp/blob/fb38b8cce94e6c70e1b96227effa6148fe58c7a0/src/llama-quant.cpp#L484

it will get upcast to a higher type; before, it was left at the default.

The other issue is with token_embd.weight for gemma 3:

gemma-3-4b-it test results
=== gemma-3-4b-it ===
  FAIL  [F32] token_embd.weight                                  expected q6_K, got f32
  FAIL  [F16] token_embd.weight                                  expected q6_K, got f16
  FAIL  [Q4_0] token_embd.weight                                  expected q6_K, got q4_0
  FAIL  [Q4_1] token_embd.weight                                  expected q6_K, got q4_1
  FAIL  [Q5_0] token_embd.weight                                  expected q6_K, got q5_0
  FAIL  [Q5_1] token_embd.weight                                  expected q6_K, got q5_1
  FAIL  [Q2_K] token_embd.weight                                  expected q6_K, got q2_K
  FAIL  [Q3_K_S] token_embd.weight                                  expected q6_K, got q3_K
  FAIL  [Q3_K_M] token_embd.weight                                  expected q6_K, got q3_K
  FAIL  [Q3_K_L] token_embd.weight                                  expected q6_K, got q3_K
  FAIL  [Q4_K_S] token_embd.weight                                  expected q6_K, got q4_K
  FAIL  [Q4_K_M] token_embd.weight                                  expected q6_K, got q4_K
  FAIL  [Q5_K_S] token_embd.weight                                  expected q6_K, got q5_K
  FAIL  [Q5_K_M] token_embd.weight                                  expected q6_K, got q5_K
  FAIL  [IQ2_XXS] token_embd.weight                                  expected q5_K, got q2_K
  FAIL  [IQ2_XS] token_embd.weight                                  expected q5_K, got q2_K
  FAIL  [Q2_K_S] token_embd.weight                                  expected q6_K, got q2_K
  FAIL  [IQ3_XS] token_embd.weight                                  expected q6_K, got iq3_s
  FAIL  [IQ3_XXS] token_embd.weight                                  expected q5_K, got iq3_s
  FAIL  [IQ1_S] token_embd.weight                                  expected q5_K, got q2_K
  FAIL  [IQ4_NL] token_embd.weight                                  expected q6_K, got iq4_nl
  FAIL  [IQ3_S] token_embd.weight                                  expected q6_K, got iq3_s
  FAIL  [IQ3_M] token_embd.weight                                  expected q6_K, got iq3_s
  FAIL  [IQ2_S] token_embd.weight                                  expected q5_K, got iq3_s
  FAIL  [IQ2_M] token_embd.weight                                  expected q5_K, got iq3_s
  FAIL  [IQ4_XS] token_embd.weight                                  expected q6_K, got iq4_xs
  FAIL  [IQ1_M] token_embd.weight                                  expected q5_K, got q2_K
  FAIL  [BF16] token_embd.weight                                  expected q6_K, got bf16
  FAIL  [TQ1_0] token_embd.weight                                  expected q6_K, got q4_K
  FAIL  [TQ2_0] token_embd.weight                                  expected q6_K, got q4_K
  FAIL  gemma-3-4b-it: 3/33 ftype sections passed (239 tensors)

so it seems that we're no longer properly detecting the token_embd in gemma or something to that end? Seems unlikely that it should be a purposeful change, but I'll let @ddh0 comment

(ignore the bf16/f16/f32, I should probably update my script to not try to use those types since it'll be weird no matter what)

ddh0 commented Mar 9, 2026

so it seems that we're no longer properly detecting the token_embd in gemma or something to that end? Seems unlikely that it should be a purposeful change, but I'll let ddh0 comment

Yes, that is definitely wrong. Let me take a look...

ddh0 commented Mar 9, 2026

Should be fixed in f180055, if I'm not mistaken.

bartowski1182 (Contributor):

Looks better now with that change - only attn_qkv is different :)

@ggerganov ggerganov merged commit 1dab5f5 into ggml-org:master Mar 10, 2026
15 of 75 checks passed
@ddh0 ddh0 deleted the llama-quant-refactor-2 branch March 10, 2026 06:48
asyncd1spatch pushed a commit to asyncd1spatch/llama.cpp that referenced this pull request Mar 10, 2026
… code cleanup (ggml-org#19770)

* quantize : imatrix-fail early + code cleanup

* fix manual override printing

it's in the preliminary loop now, so needs to be on its own line

* revert header changes per ggerganov

* remove old #includes

* clarify naming

rename `tensor_quantization` to `tensor_type_option` to describe its
functionality

* fix per barto
ddh0 added a commit to ddh0/llama.cpp that referenced this pull request Mar 10, 2026
In ggml-org#19770, I introduced a regression in the way the
`quantize_state_impl` counter values were initialized. I was
incrementing and using `n_attention_wv` in the same loop, when it should
have been fixed by the time we're deciding tensor types in
`llama_tensor_get_type_impl` (for `use_more_bits`).

I never observed a difference in any of [my
tests](ggml-org#19770 (comment))
- it was only after @bartowski kindly pointed this out that I realized
it was incorrect. (Thanks!)
ggerganov pushed a commit that referenced this pull request Mar 10, 2026
* llama-quant : correct `n_attention_wv` usage

In #19770, I introduced a regression in the way the
`quantize_state_impl` counter values were initialized. I was
incrementing and using `n_attention_wv` in the same loop, when it should
have been fixed by the time we're deciding tensor types in
`llama_tensor_get_type_impl` (for `use_more_bits`).

I never observed a difference in any of [my
tests](#19770 (comment))
- it was only after @bartowski kindly pointed this out that I realized
it was incorrect. (Thanks!)

* simplify
asyncd1spatch pushed a commit to asyncd1spatch/llama.cpp that referenced this pull request Mar 10, 2026
* llama-quant : correct `n_attention_wv` usage

In ggml-org#19770, I introduced a regression in the way the
`quantize_state_impl` counter values were initialized. I was
incrementing and using `n_attention_wv` in the same loop, when it should
have been fixed by the time we're deciding tensor types in
`llama_tensor_get_type_impl` (for `use_more_bits`).

I never observed a difference in any of [my
tests](ggml-org#19770 (comment))
- it was only after @bartowski kindly pointed this out that I realized
it was incorrect. (Thanks!)

* simplify