Commit 7229ade

llama : add support for Nemotron 3 Super

Committed by danbev, with co-authors ggerganov and mattjcly.

This commit adds support for the Nemotron 3 Super model (120B.A12B), enabling the model to be converted to GGUF format and run in llama.cpp.

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Matt Clayton <156335168+mattjcly@users.noreply.github.com>

1 parent: bd1ec81

File tree: 11 files changed (+87 additions, -12 deletions)

convert_hf_to_gguf.py (23 additions, 6 deletions)

@@ -9743,14 +9743,21 @@ def __init__(self, *args, **kwargs):
         # M: Mamba2, *: Attention, -: MLP
         # MoE:
         # M: Mamba2, *: Attention, E: Expert
-        hybrid_override_pattern = self.hparams["hybrid_override_pattern"]
-        self._ssm_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "M"]
-        self._mlp_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == ("E" if self.is_moe else "-")]
+        pattern = self.hparams.get("hybrid_override_pattern") or self.hparams.get("layers_block_type")
+        if isinstance(pattern, str):
+            self._ssm_layers = [i for i, val in enumerate(pattern) if val == "M"]
+            self._mlp_layers = [i for i, val in enumerate(pattern) if val == ("E" if self.is_moe else "-")]
+        else:
+            self._ssm_layers = [i for i, val in enumerate(pattern) if val == "mamba"]
+            self._mlp_layers = [i for i, val in enumerate(pattern) if val == "moe"]

     def get_attn_layers(self):
-        hybrid_override_pattern = self.hparams["hybrid_override_pattern"]
-        assert len(hybrid_override_pattern) == self.block_count, "Mismatch between hybrid override and num_hidden_layers!"
-        return [i for i, val in enumerate(hybrid_override_pattern) if val == "*"]
+        pattern = self.hparams.get("hybrid_override_pattern") or self.hparams.get("layers_block_type")
+        assert len(pattern) == self.block_count, f"Mismatch between pattern ({len(pattern)}) and block_count ({self.block_count})!"
+        if isinstance(pattern, str):
+            return [i for i, val in enumerate(pattern) if val == "*"]
+
+        return [i for i, val in enumerate(pattern) if val == "attention"]

     def set_gguf_parameters(self):
         super().set_gguf_parameters()

@@ -9784,6 +9791,9 @@ def set_gguf_parameters(self):
         if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
             self.gguf_writer.add_expert_used_count(n_experts_used)

+        if (latent_size := self.hparams.get("moe_latent_size")) is not None:
+            self.gguf_writer.add_moe_latent_size(latent_size)
+
     def set_vocab(self):
         super().set_vocab()

@@ -9803,6 +9813,13 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             name = name[len("language_model."):]

         if self.is_moe and bid is not None:
+            # Skip Multi-Token Prediction (MTP) tensors. These are used for
+            # for speculative decoding but we don't include them in this model
+            # conversion. See https://github.com/ggml-org/llama.cpp/pull/18886
+            if "mtp" in name:
+                print(f"Skipping MTP (Speculative) layer: {name}")
+                return []
+
             if name.endswith("mixer.gate.e_score_correction_bias"):
                 new_name = name.replace("e_score_correction_bias", "e_score_correction.bias")
                 yield from ModelBase.modify_tensors(self, data_torch, new_name, bid)

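For orientation, the converter change above accepts the layer layout in either of two shapes: the per-layer character string hybrid_override_pattern ("M" = Mamba2, "*" = attention, "E" or "-" = expert or plain MLP) or a layers_block_type list of strings ("mamba" / "attention" / "moe"). The following is a minimal, self-contained sketch of that classification logic; the six-layer pattern is invented purely for illustration and is not the real Nemotron 3 Super layout.

def classify_layers(pattern, is_moe: bool):
    # Mirrors the converter logic: returns (ssm, attn, mlp) layer index lists.
    if isinstance(pattern, str):
        ssm  = [i for i, v in enumerate(pattern) if v == "M"]
        attn = [i for i, v in enumerate(pattern) if v == "*"]
        mlp  = [i for i, v in enumerate(pattern) if v == ("E" if is_moe else "-")]
    else:
        ssm  = [i for i, v in enumerate(pattern) if v == "mamba"]
        attn = [i for i, v in enumerate(pattern) if v == "attention"]
        mlp  = [i for i, v in enumerate(pattern) if v == "moe"]
    return ssm, attn, mlp

# Both spellings of the same (made-up) 6-layer layout classify identically:
print(classify_layers("MM*EME", is_moe=True))
print(classify_layers(["mamba", "mamba", "attention", "moe", "mamba", "moe"], is_moe=True))
# -> ([0, 1, 4], [2], [3, 5]) in both cases

Either form must supply exactly one entry per layer, which is what the block_count assertion in get_attn_layers enforces.
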
ggml/src/ggml-metal/ggml-metal.metal (1 addition, 0 deletions)

@@ -9081,6 +9081,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_
 template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
 template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
 template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
+template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>;

 template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
 kernel void kernel_mul_mm_id(

gguf-py/gguf/constants.py (8 additions, 0 deletions)

@@ -125,6 +125,7 @@ class LLM:
         EXPERT_GROUP_SCALE = "{arch}.expert_group_scale"
         EXPERTS_PER_GROUP = "{arch}.experts_per_group"
         MOE_EVERY_N_LAYERS = "{arch}.moe_every_n_layers"
+        MOE_LATENT_SIZE = "{arch}.moe_latent_size"
         NEXTN_PREDICT_LAYERS = "{arch}.nextn_predict_layers"
         NUM_DEEPSTACK_LAYERS = "{arch}.n_deepstack_layers"
         POOLING_TYPE = "{arch}.pooling_type"

@@ -543,6 +544,8 @@ class MODEL_TENSOR(IntEnum):
     FFN_DOWN_CHEXP = auto()
     FFN_UP_CHEXP = auto()
     FFN_EXP_PROBS_B = auto()
+    MOE_LATENT_DOWN = auto()  # nemotron 3 super
+    MOE_LATENT_UP = auto()  # nemotron 3 super
     ATTN_Q_NORM = auto()
     ATTN_K_NORM = auto()
     LAYER_OUT_NORM = auto()

@@ -986,6 +989,8 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
     MODEL_TENSOR.FFN_GATE_UP_EXP: "blk.{bid}.ffn_gate_up_exps",
     MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b",
+    MODEL_TENSOR.MOE_LATENT_DOWN: "blk.{bid}.ffn_latent_down",  # nemotron 3 super
+    MODEL_TENSOR.MOE_LATENT_UP: "blk.{bid}.ffn_latent_up",  # nemotron 3 super
     MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
     MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: "per_layer_token_embd",  # gemma3n
     MODEL_TENSOR.PER_LAYER_MODEL_PROJ: "per_layer_model_proj",  # gemma3n

@@ -2913,6 +2918,9 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_GATE_INP,
         MODEL_TENSOR.FFN_UP_EXP,
         MODEL_TENSOR.FFN_DOWN_EXP,
+        # expert latent
+        MODEL_TENSOR.MOE_LATENT_DOWN,
+        MODEL_TENSOR.MOE_LATENT_UP,
         # shared expert
         MODEL_TENSOR.FFN_DOWN_SHEXP,
         MODEL_TENSOR.FFN_UP_SHEXP,

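As a quick sanity check of the new constants, the added name templates expand per block like any other TENSOR_NAMES entry. A small sketch, assuming the gguf-py package with this change applied (the block index 3 is arbitrary):

import gguf

down = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.MOE_LATENT_DOWN].format(bid=3)
up   = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.MOE_LATENT_UP].format(bid=3)
print(down, up)  # blk.3.ffn_latent_down blk.3.ffn_latent_up

These strings match the LLM_TENSOR_NAMES entries added to src/llama-arch.cpp below, which is how the runtime locates the tensors at load time.
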
gguf-py/gguf/gguf_writer.py (3 additions, 0 deletions)

@@ -859,6 +859,9 @@ def add_experts_per_group(self, count: int) -> None:
     def add_moe_every_n_layers(self, value: int) -> None:
         self.add_uint32(Keys.LLM.MOE_EVERY_N_LAYERS.format(arch=self.arch), value)

+    def add_moe_latent_size(self, value: int) -> None:
+        self.add_uint32(Keys.LLM.MOE_LATENT_SIZE.format(arch=self.arch), value)
+
     def add_nextn_predict_layers(self, count: int) -> None:
         self.add_uint32(Keys.LLM.NEXTN_PREDICT_LAYERS.format(arch=self.arch), count)

gguf-py/gguf/tensor_mapping.py (8 additions, 0 deletions)

@@ -571,6 +571,14 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.experts.gate_up_proj",
         ),

+        MODEL_TENSOR.MOE_LATENT_DOWN: (
+            "backbone.layers.{bid}.mixer.fc1_latent_proj",  # nemotron 3 super
+        ),
+
+        MODEL_TENSOR.MOE_LATENT_UP: (
+            "backbone.layers.{bid}.mixer.fc2_latent_proj",  # nemotron 3 super
+        ),
+
         # Feed-forward down
         MODEL_TENSOR.FFN_DOWN: (
             "gpt_neox.layers.{bid}.mlp.dense_4h_to_h",  # gptneox

src/llama-arch.cpp (8 additions, 0 deletions)

@@ -185,6 +185,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_EXPERT_GROUP_SCALE, "%s.expert_group_scale" },
     { LLM_KV_EXPERTS_PER_GROUP, "%s.experts_per_group" },
     { LLM_KV_MOE_EVERY_N_LAYERS, "%s.moe_every_n_layers" },
+    { LLM_KV_MOE_LATENT_SIZE, "%s.moe_latent_size" },
     { LLM_KV_NEXTN_PREDICT_LAYERS, "%s.nextn_predict_layers" },
     { LLM_KV_NUM_DEEPSTACK_LAYERS, "%s.n_deepstack_layers" },
     { LLM_KV_POOLING_TYPE, "%s.pooling_type" },

@@ -365,6 +366,8 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
     { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
     { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
     { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
+    { LLM_TENSOR_FFN_LATENT_DOWN, "blk.%d.ffn_latent_down" },
+    { LLM_TENSOR_FFN_LATENT_UP, "blk.%d.ffn_latent_up" },
     { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" },
     { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
     { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },

@@ -1879,6 +1882,8 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
         LLM_TENSOR_FFN_UP_EXPS,
         LLM_TENSOR_FFN_DOWN_EXPS,
         LLM_TENSOR_FFN_EXP_PROBS_B,
+        LLM_TENSOR_FFN_LATENT_DOWN,
+        LLM_TENSOR_FFN_LATENT_UP,
         // MoE shared expert layer
         LLM_TENSOR_FFN_DOWN_SHEXP,
         LLM_TENSOR_FFN_UP_SHEXP,

@@ -2754,6 +2759,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
     {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
+    // Nemotron 3 Super
+    {LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
 };

 LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}

src/llama-arch.h (3 additions, 0 deletions)

@@ -189,6 +189,7 @@ enum llm_kv {
     LLM_KV_EXPERT_GROUP_SCALE,
     LLM_KV_EXPERTS_PER_GROUP,
     LLM_KV_MOE_EVERY_N_LAYERS,
+    LLM_KV_MOE_LATENT_SIZE,
     LLM_KV_NEXTN_PREDICT_LAYERS,
     LLM_KV_NUM_DEEPSTACK_LAYERS,
     LLM_KV_POOLING_TYPE,

@@ -385,6 +386,8 @@ enum llm_tensor {
     LLM_TENSOR_FFN_GATE_CHEXPS,
     LLM_TENSOR_FFN_UP_CHEXPS,
     LLM_TENSOR_FFN_EXP_PROBS_B,
+    LLM_TENSOR_FFN_LATENT_DOWN,
+    LLM_TENSOR_FFN_LATENT_UP,
     LLM_TENSOR_ATTN_Q_NORM,
     LLM_TENSOR_ATTN_K_NORM,
     LLM_TENSOR_LAYER_OUT_NORM,

src/llama-hparams.h (1 addition, 0 deletions)

@@ -89,6 +89,7 @@ struct llama_hparams {
     bool expert_weights_norm = false;
     uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE;
     uint32_t moe_every_n_layers = 0;
+    uint32_t moe_latent_size = 0;
     uint32_t nextn_predict_layers = 0;

     float f_norm_eps;

src/llama-model.cpp (9 additions, 2 deletions)

@@ -135,6 +135,7 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_100B_A6B: return "100B.A6B";
         case LLM_TYPE_102B_A12B: return "102B.A12B";
         case LLM_TYPE_106B_A12B: return "106B.A12B";
+        case LLM_TYPE_120B_A12B: return "120B.A12B";
         case LLM_TYPE_122B_A10B: return "122B.A10B";
         case LLM_TYPE_196B_A11B: return "196B.A11B";
         case LLM_TYPE_230B_A10B: return "230B.A10B";

@@ -1861,10 +1862,12 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, false);
             ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
             ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false);
+            ml.get_key(LLM_KV_MOE_LATENT_SIZE, hparams.moe_latent_size, false);

             switch (hparams.n_layer) {
                 case 52: type = LLM_TYPE_31B_A3_5B; break; // Nemotron-H_MOE 31B
                 case 56: type = LLM_TYPE_9B; break;
+                case 88: type = LLM_TYPE_120B_A12B; break;
                 default: type = LLM_TYPE_UNKNOWN;
             }
         } break;

@@ -5544,6 +5547,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
         const int64_t n_ssm_head = hparams.ssm_dt_rank;
         const int64_t n_group = hparams.ssm_n_group;
         const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head;
+        const int64_t moe_n_embd = hparams.moe_latent_size > 0 ? hparams.moe_latent_size : n_embd;

         // embeddings
         tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

@@ -5603,8 +5607,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
         layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0);

         // MoE branch
-        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
-        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
+        layer.ffn_latent_down = create_tensor(tn(LLM_TENSOR_FFN_LATENT_DOWN, "weight", i), {n_embd, moe_n_embd}, TENSOR_NOT_REQUIRED);
+        layer.ffn_latent_up   = create_tensor(tn(LLM_TENSOR_FFN_LATENT_UP, "weight", i), {moe_n_embd, n_embd}, TENSOR_NOT_REQUIRED);
+
+        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, moe_n_embd, n_expert}, 0);
+        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {moe_n_embd, n_ff_exp, n_expert}, 0);

         // Shared expert branch
         layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0);

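To make the shape changes above concrete: when moe_latent_size is set, the routed experts operate at the smaller latent width (moe_n_embd) rather than the full n_embd, with ffn_latent_down and ffn_latent_up projecting in and out of that space. The graph-building code is not part of this excerpt, so the snippet below is only a rough numpy sketch of the presumed data flow for a single token; the sizes, top-k routing, and activation are simplified assumptions, not the model's actual configuration.

import numpy as np

n_embd, moe_n_embd, n_ff_exp, n_expert = 8, 4, 16, 6   # toy sizes, not the real hparams
rng = np.random.default_rng(0)

w_latent_down = rng.normal(size=(moe_n_embd, n_embd))              # blk.*.ffn_latent_down
w_latent_up   = rng.normal(size=(n_embd, moe_n_embd))              # blk.*.ffn_latent_up
w_up_exps     = rng.normal(size=(n_expert, n_ff_exp, moe_n_embd))  # blk.*.ffn_up_exps
w_down_exps   = rng.normal(size=(n_expert, moe_n_embd, n_ff_exp))  # blk.*.ffn_down_exps

def moe_latent_forward(x, expert_ids, expert_weights):
    # x: (n_embd,) activation for one token; expert_ids/expert_weights: routed top-k experts.
    z = w_latent_down @ x                        # project into the MoE latent space
    acc = np.zeros(moe_n_embd)
    for e, w in zip(expert_ids, expert_weights):
        h = np.maximum(w_up_exps[e] @ z, 0.0)    # expert up-projection + assumed ReLU-like activation
        acc += w * (w_down_exps[e] @ h)          # expert down-projection, still at the latent width
    return w_latent_up @ acc                     # project back to n_embd for the residual add

y = moe_latent_forward(rng.normal(size=n_embd), expert_ids=[1, 4], expert_weights=[0.6, 0.4])
print(y.shape)  # (8,) -- back at the full embedding width

This is also why ffn_down_exps and ffn_up_exps are now created with moe_n_embd in place of n_embd, while the two latent projections are marked TENSOR_NOT_REQUIRED so that checkpoints without a latent size keep loading unchanged.
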
src/llama-model.h (5 additions, 0 deletions)

@@ -126,6 +126,7 @@ enum llm_type {
     LLM_TYPE_100B_A6B,
     LLM_TYPE_102B_A12B, // Solar-Open
     LLM_TYPE_106B_A12B, // GLM-4.5-Air
+    LLM_TYPE_120B_A12B, // Nemotron 3 Super
     LLM_TYPE_122B_A10B, // Qwen3.5
     LLM_TYPE_196B_A11B, // Step3.5-Flash
     LLM_TYPE_230B_A10B, // Minimax M2

@@ -294,6 +295,10 @@ struct llama_layer {
     struct ggml_tensor * ffn_up_exps_b = nullptr;
     struct ggml_tensor * ffn_gate_up_exps_b = nullptr;

+    // ff MoE latent proj
+    struct ggml_tensor * ffn_latent_down = nullptr;
+    struct ggml_tensor * ffn_latent_up = nullptr;
+
     // ff shared expert (shexp)
     struct ggml_tensor * ffn_gate_inp_shexp = nullptr;
     struct ggml_tensor * ffn_gate_shexp = nullptr;
