convert : apply Q/K RoPE permutation in NVFP4 repack path #22611

Merged
ggerganov merged 1 commit into ggml-org:master from jmrobles:convert-nvfp4-qk-rope-permute on May 3, 2026

jmrobles commented May 2, 2026

Contributor

Overview

NVFP4 GGUFs produced by convert_hf_to_gguf.py for Llama-architecture models output gibberish at inference (e.g. Certainlyrics|assistant|assistant|... for "The capital of France is"). Root cause: ModelBase._repack_nvfp4 writes weights directly and bypasses LlamaModel.modify_tensors, so the axis-0 row permutation that GGML's RoPE convention requires for q_proj.weight / k_proj.weight is never applied. Attention heads end up scrambled.

This PR mirrors LlamaModel.permute inside _repack_nvfp4 and applies it to both the nibble-packed weight tensor and the per-block scale tensor when the module is q_proj/k_proj and arch == "llama".
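
For readers unfamiliar with the permutation, here is a minimal NumPy sketch of an axis-0 row permutation of this kind. It assumes the reshape/swapaxes form commonly used for Hugging Face-style Q/K weights; the name `permute_rows` and the toy shapes are illustrative and are not taken from the converter script.

```python
import numpy as np


def permute_rows(weights: np.ndarray, n_head: int) -> np.ndarray:
    """Interleave the two halves of each head's rows along axis 0."""
    rows = weights.shape[0]
    return (
        weights.reshape(n_head, 2, rows // n_head // 2, *weights.shape[1:])
        .swapaxes(1, 2)
        .reshape(weights.shape)
    )


# Toy weight: 2 heads x 4 rows per head, 3 columns; each row stores its own index.
w = np.repeat(np.arange(8)[:, None], 3, axis=1)
row_order = permute_rows(w, n_head=2)[:, 0].tolist()
print(row_order)  # [0, 2, 1, 3, 4, 6, 5, 7]
```

Within each head the first and second halves of the rows are interleaved, while columns are untouched. If a converter skips this step, the rows stay in the source order and the rotary pairs no longer line up with what the runtime expects.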

Reproducer (TinyLlama-1.1B)

| Build | Perplexity (wikitext-2 slice) |
| --- | --- |
| BF16 reference | 44.0 |
| NVFP4 GGUF, master (no fix) | 4419 (gibberish) |
| NVFP4 GGUF, this PR | 43.9 |

Also end-to-end verified on a 40B Llama-arch model (BSC ALIA-40b-instruct-2601, NVFP4-quantized via NVIDIA ModelOpt then converted with the patched script). Multilingual generation in Spanish/Catalan/Basque/Galician is coherent with the fix; gibberish without.

Why inline rather than calling LlamaModel.permute

_repack_nvfp4 lives on ModelBase; permute is a LlamaModel (subclass) staticmethod. Calling a subclass staticmethod from the parent is awkward. Extracting permute to a module-level helper would invite a larger refactor than this bug warrants. The duplication is 6 lines of math with a one-line comment referencing the canonical BF16 site.

Additional Information

Related: #19769 (which added GGML_TYPE_NVFP4 and the conversion path that this PR fixes).

Affects any Llama-architecture model (Llama 1/2/3, Mistral, Qwen2 derivatives that subclass LlamaModel, BSC ALIA, etc.) when quantized through the NVFP4 path. The BF16 path is unaffected.

Make sure to read the contributing guidelines before submitting a PR

  • I have read the contributing guidelines
  • Self-reported review complexity:
    • Low
    • Medium
    • High

@jmrobles jmrobles requested a review from CISC as a code owner May 2, 2026 09:53
@github-actions github-actions bot added the python label May 2, 2026

CISC commented May 2, 2026

Member

You should override _repack_nvfp4 in LlamaModel instead, see how it is done in _LinearAttentionVReorderBase.

Llama-architecture q_proj/k_proj weights need an axis-0 row permutation
to match GGML's RoPE convention. The BF16 path applies this in
LlamaModel.modify_tensors via LlamaModel.permute, but the NVFP4 path
bypasses modify_tensors and writes weights directly through
ModelBase._repack_nvfp4. Without the permutation, attention heads end
up scrambled at inference and the model produces gibberish.

This change overrides _repack_nvfp4 on LlamaModel and applies the same
permutation to both the nibble-packed weight and the per-block scale
before delegating to ModelBase._repack_nvfp4 via super(). Reuses the
existing LlamaModel.permute static helper and respects the existing
undo_permute flag, so subclasses (Mistral, Granite, Llama4, etc.)
inherit the fix automatically.

Verified on TinyLlama-1.1B reproducer: perplexity drops from 4419
(gibberish) to 43.9, matching the BF16-dequantized baseline (44.0).
Also verified end-to-end on ALIA-40b-instruct-2601 (BSC, Llama
architecture) with multilingual generation in Spanish/Catalan/Basque/
Galician all coherent with the fix applied.
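
Why the same permutation has to hit both tensors can be illustrated with a toy block-quantized format. The sketch below is an assumption-laden stand-in, not real NVFP4: it packs two 4-bit codes per byte, uses one float scale per 16-column block, and maps codes to values with a made-up linear rule. `BLOCK`, `toy_dequant`, and `permute_rows` are hypothetical names.

```python
import numpy as np

BLOCK = 16  # toy block size along the column axis (assumption)


def permute_rows(t: np.ndarray, n_head: int) -> np.ndarray:
    """Same axis-0 interleave applied to any tensor with one row per output row."""
    rows = t.shape[0]
    return (
        t.reshape(n_head, 2, rows // n_head // 2, *t.shape[1:])
        .swapaxes(1, 2)
        .reshape(t.shape)
    )


def toy_dequant(packed: np.ndarray, scale: np.ndarray) -> np.ndarray:
    """Unpack two 4-bit codes per byte, then scale each block of BLOCK columns."""
    lo = packed & 0x0F
    hi = packed >> 4
    codes = np.stack([lo, hi], axis=-1).reshape(packed.shape[0], -1)
    values = codes.astype(np.float32) - 8.0  # toy code-to-value map, not E2M1
    return values * np.repeat(scale, BLOCK, axis=1)


rng = np.random.default_rng(0)
n_head, rows, cols = 2, 8, 32
packed = rng.integers(0, 256, size=(rows, cols // 2), dtype=np.uint8)
scale = rng.random((rows, cols // BLOCK), dtype=np.float32)

# Permute the quantized pieces first, then dequantize ...
a = toy_dequant(permute_rows(packed, n_head), permute_rows(scale, n_head))
# ... versus dequantize first, then permute the full-precision rows.
b = permute_rows(toy_dequant(packed, scale), n_head)
same = bool(np.array_equal(a, b))

# Permuting only the packed weight pairs rows with the wrong scales.
c = toy_dequant(permute_rows(packed, n_head), scale)
mismatch = not np.array_equal(b, c)
print(same, mismatch)
```

Because packing and scale blocks run along columns while the permutation acts on rows, permuting both quantized tensors is equivalent to permuting the dequantized weight. Permuting only one of them would attach scales to the wrong rows.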
@jmrobles jmrobles force-pushed the convert-nvfp4-qk-rope-permute branch from d6322c5 to 332d229 Compare May 2, 2026 17:50

jmrobles commented May 2, 2026

Contributor Author

Thanks for the steer @CISC. Refactored to override _repack_nvfp4 on LlamaModel following the _LinearAttentionVReorderBase pattern. Got two bonuses out of the move:

  • reusing the existing LlamaModel.permute static (instead of inlining the reshape/swap math like v1 did)
  • automatic opt-out for subclasses via the existing undo_permute flag, so the NVFP4 path now exactly mirrors the BF16 site at LlamaModel.modify_tensors

Force-pushed; PR diff is now +14 lines on LlamaModel, no ModelBase change.
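
The override pattern described in this comment can be sketched with toy classes. Everything here is hypothetical: the class names, the three-argument `_repack_nvfp4` signature, the single `n_head` attribute, and the tensor-name check are illustrative assumptions and do not reproduce the converter's real API.

```python
import numpy as np


class ToyModelBase:
    def __init__(self) -> None:
        self.written: dict[str, tuple[np.ndarray, np.ndarray]] = {}

    def _repack_nvfp4(self, name: str, weight: np.ndarray, scale: np.ndarray) -> None:
        # Stand-in for the base writer: stores the tensors exactly as given.
        self.written[name] = (weight, scale)


class ToyLlamaModel(ToyModelBase):
    undo_permute = True  # subclasses may set this to False to opt out
    n_head = 2           # toy value; a real model reads this from its config

    @staticmethod
    def permute(t: np.ndarray, n_head: int) -> np.ndarray:
        rows = t.shape[0]
        return (
            t.reshape(n_head, 2, rows // n_head // 2, *t.shape[1:])
            .swapaxes(1, 2)
            .reshape(t.shape)
        )

    def _repack_nvfp4(self, name: str, weight: np.ndarray, scale: np.ndarray) -> None:
        # Permute Q/K weight and scale together, then delegate to the base class.
        if self.undo_permute and name.endswith(("q_proj.weight", "k_proj.weight")):
            weight = self.permute(weight, self.n_head)
            scale = self.permute(scale, self.n_head)
        super()._repack_nvfp4(name, weight, scale)


class ToyNoPermuteModel(ToyLlamaModel):
    undo_permute = False  # inherits the override but skips the permutation


w = np.arange(8, dtype=np.uint8).reshape(8, 1)
s = np.arange(8, dtype=np.float32).reshape(8, 1)

m = ToyLlamaModel()
m._repack_nvfp4("model.layers.0.self_attn.q_proj.weight", w, s)
m._repack_nvfp4("model.layers.0.self_attn.v_proj.weight", w, s)
q_rows = m.written["model.layers.0.self_attn.q_proj.weight"][0][:, 0].tolist()
v_rows = m.written["model.layers.0.self_attn.v_proj.weight"][0][:, 0].tolist()

n = ToyNoPermuteModel()
n._repack_nvfp4("model.layers.0.self_attn.q_proj.weight", w, s)
optout_rows = n.written["model.layers.0.self_attn.q_proj.weight"][0][:, 0].tolist()
print(q_rows, v_rows, optout_rows)
```

In this sketch the base class stays untouched, only Q/K tensors are reordered, and a subclass opts out by flipping one class attribute, which is the behaviour the comment attributes to the `undo_permute` flag.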

@CISC CISC added the merge ready label May 2, 2026
@ggerganov ggerganov merged commit db44417 into ggml-org:master May 3, 2026
6 checks passed
samuraieng pushed a commit to samuraieng/llama.cpp that referenced this pull request May 6, 2026
Co-authored-by: Chema <chema@montevive.ai>
ljubomirj pushed a commit to ljubomirj/llama.cpp that referenced this pull request May 6, 2026
cetarthoriphros pushed a commit to cetarthoriphros/llama.cpp that referenced this pull request May 9, 2026
meh pushed a commit to meh/llama.cpp that referenced this pull request May 10, 2026
pwilkin added a commit to pwilkin/llama.cpp that referenced this pull request May 13, 2026
Ports 15 upstream commits (05e141a..5d44db6) that touched the
monolithic convert_hf_to_gguf.py into the new conversion/*.py layout
introduced by the refactor split.

New text/mmproj architectures registered:
  GraniteSpeechForConditionalGeneration, MiMoV2ForCausalLM,
  MiniCPMV4_6ForConditionalGeneration, Sarashina2VisionForCausalLM,
  SarvamMoEForCausalLM (+ modeling_sarvam_moe.SarvamMoEForCausalLM).

Notable changes:
- filter_tensors classmethod added to ModelBase/TextModel/MmprojModel
  and wired into index_tensors; many model classes refactored to move
  tensor-name skip/rename logic out of modify_tensors and into
  filter_tensors (upstream ggml-org#22597).
- LlamaModel._repack_nvfp4 override (Q/K RoPE permutation, ggml-org#22611).
- MistralModel yarn apply_scale support (ggml-org#22612).
- Gemma4Model._generate_nvfp4_tensors override for 26B NVFP4 (ggml-org#22804).
- LlavaVisionModel image-break token fallback for Mistral params.json
  -1 placeholders (ggml-org#22914).
- Pixtral 12B --mistral-format conversion fixes (ggml-org#22981).
- FP8 KV-cache scales fix (ggml-org#22818) and uint dtype byteswap disable
  (ggml-org#18908).

New files:
  conversion/sarashina2.py (Sarashina2VL text + vision)
baramofme pushed a commit to baramofme/llama-cpp-turboquant that referenced this pull request May 23, 2026
carlosfundora pushed a commit to carlosfundora/llama.cpp-1-bit-turbo that referenced this pull request May 24, 2026
(cherry picked from commit db44417)
winstonma pushed a commit to winstonma/llama.cpp that referenced this pull request May 27, 2026
fewtarius pushed a commit to fewtarius/llama.cpp that referenced this pull request May 30, 2026