-
Notifications
You must be signed in to change notification settings - Fork 5.8k
Expand file tree
/
Copy pathqwen3_5_mtp.py
More file actions
394 lines (340 loc) · 14.6 KB
/
qwen3_5_mtp.py
File metadata and controls
394 lines (340 loc) · 14.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Inference-only Qwen3_5 MTP model."""
import logging
from contextlib import ExitStack
from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
from sglang.srt.environ import envs
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
from sglang.srt.layers.layernorm import GemmaRMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen3_5 import Qwen3_5ForCausalLM
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix, is_npu
# Logger scoped to this module's import path.
logger = logging.getLogger(name=__name__)
class Qwen3_5ForCausalLMMTP(nn.Module):
    """Multi-token-prediction (MTP) draft model for Qwen3.5.

    Wraps a one-layer ``Qwen3_5ForCausalLM`` (constructed with
    ``is_nextn=True``). The input to that layer is
    ``fc(concat(norm(token_embedding), norm(spec_info.hidden_states)))``.
    Weights are loaded from the ``mtp.*`` entries of the checkpoint.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config=None,
        prefix: str = "",
    ) -> None:
        """Build the MTP head.

        Args:
            config: HF config. For multimodal checkpoints, ``text_config`` is
                used. NOTE: ``num_hidden_layers`` and
                ``full_attention_interval`` of the (text) config are
                overwritten in place.
            quant_config: Quantization config of the checkpoint. It is dropped
                (set to ``None``) in any of these cases:
                - the MTP weights ship unquantized (modelopt nvfp4);
                - all ``mtp.*`` layers are excluded (quark);
                - on NPU, no draft quantization was requested.
            prefix: Parameter-name prefix for submodules.
        """
        nn.Module.__init__(self)

        # Multimodal configs nest the language-model config under
        # `text_config`; the MTP head only uses the text part.
        self.is_multimodal = hasattr(config, "text_config")
        if self.is_multimodal:
            config = config.text_config

        # The MTP model is unquantized in the nvfp4 checkpoint.
        if quant_config and quant_config.get_name() == "modelopt_fp4":
            quant_config = None

        # On NPU, run the draft unquantized unless a draft-specific
        # quantization was explicitly requested.
        if (
            is_npu()
            and get_global_server_args().speculative_draft_model_quantization is None
        ):
            quant_config = None

        # Quark-quantized Qwen3.5 MXFP4 checkpoints ship the MTP module in
        # bf16; every `mtp.*` layer appears under the quantization exclude
        # list. Detect that and skip quantization here so linear/MoE weight
        # loaders allocate bf16 shapes (see sgl-project/sglang#23113).
        if quant_config and quant_config.get_name() == "quark":
            # `or []` also tolerates an explicit `exclude_layers = None`.
            exclude_layers = getattr(quant_config, "exclude_layers", None) or []
            if any(
                isinstance(layer, str) and layer.startswith("mtp.")
                for layer in exclude_layers
            ):
                quant_config = None

        self.config = config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.quant_config = quant_config
        self.pp_group = get_pp_group()

        # Projects concat(embedding, hidden_state) (2 * hidden) back to hidden.
        self.fc = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
        RMSNorm_cls = GemmaRMSNorm
        self.pre_fc_norm_embedding = RMSNorm_cls(
            config.hidden_size, config.rms_norm_eps
        )
        self.pre_fc_norm_hidden = RMSNorm_cls(config.hidden_size, config.rms_norm_eps)

        # The MTP branch is built as exactly one decoder layer;
        # `full_attention_interval = 1` presumably makes that layer a
        # full-attention layer (confirm against Qwen3_5ForCausalLM).
        # This mutates the config object in place.
        config.num_hidden_layers = 1
        config.full_attention_interval = 1
        self.model = Qwen3_5ForCausalLM(
            config,
            quant_config,
            prefix=add_prefix("mtp", prefix),
            is_nextn=True,
        )

        if self.pp_group.is_last_rank:
            if config.tie_word_embeddings:
                # Tied: the LM head shares the embedding module.
                self.lm_head = self.model.embed_tokens
            else:
                self.lm_head = ParallelLMHead(
                    config.vocab_size,
                    config.hidden_size,
                    quant_config=quant_config,
                    prefix=add_prefix("lm_head", prefix),
                )
        self.logits_processor = LogitsProcessor(config)

    @classmethod
    def get_model_config_for_expert_location(cls, config):
        """Describe the expert layout for expert-location (EPLB) bookkeeping.

        Args:
            config: HF config. ``text_config`` is used when present.

        Returns:
            A ``ModelConfigForExpertLocation``, or ``None`` when the config has
            no ``num_experts`` (dense variant, no routed experts to place).
        """
        text_config = getattr(config, "text_config", config)
        num_experts = getattr(text_config, "num_experts", None)
        if num_experts is None:
            # Dense variant; `load_weights` likewise treats a missing
            # `num_experts` as "no MoE".
            return None
        return ModelConfigForExpertLocation(
            num_layers=text_config.num_hidden_layers,
            num_logical_experts=num_experts,
            num_groups=None,
        )

    def get_embed_and_head(self):
        """Return ``(embedding weight, LM-head weight)``."""
        return self.model.embed_tokens.weight, self.lm_head.weight

    def set_embed_and_head(self, embed, head):
        """Replace the embedding and LM-head weights with the given tensors.

        With tied embeddings, ``lm_head`` is the embedding module itself. In
        that case, the final assignment leaves that shared module holding
        ``head``.
        """
        del self.model.embed_tokens.weight
        if not self.config.tie_word_embeddings:
            # When tied, this weight was already deleted just above.
            del self.lm_head.weight
        self.model.embed_tokens.weight = embed
        self.lm_head.weight = head
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Run one MTP draft step and return the logits-processor output.

        ``input_embeds`` must be ``None``: embeddings are taken from
        ``forward_batch.mm_input_embeds`` when present, otherwise computed
        from ``input_ids``. The previous hidden states are read from
        ``forward_batch.spec_info.hidden_states``. Extra ``kwargs`` are
        ignored.
        """
        # `with` guarantees the env overrides below are undone even if
        # anything in the draft step raises.
        with ExitStack() as exit_stack:
            if (
                is_npu()
                and self.quant_config is None
                and get_global_server_args().quantization is not None
            ):
                # ascend mtp unquant: the target is quantized but this
                # draft is not, so force bf16 dispatch for the draft step.
                exit_stack.enter_context(
                    envs.SGLANG_DEEPEP_BF16_DISPATCH.override(True)
                )
                exit_stack.enter_context(
                    envs.DEEP_NORMAL_MODE_USE_INT8_QUANT.override(False)
                )

            assert input_embeds is None
            input_embeds = forward_batch.mm_input_embeds
            if (
                forward_batch.forward_mode.is_extend()
                and forward_batch.contains_mm_inputs()
                and not forward_batch.forward_mode.is_draft_extend(include_v2=True)
            ):
                assert input_embeds is not None
                # Re-embed only the final position from `input_ids[-1]`;
                # presumably mm_input_embeds does not cover that token for
                # the draft (confirm against the caller).
                input_embeds = torch.cat(
                    [
                        input_embeds[:-1],
                        self.model.embed_tokens(input_ids[-1].unsqueeze(0)),
                    ]
                )
            if input_embeds is None:
                input_embeds = self.model.embed_tokens(input_ids)

            hidden_states = forward_batch.spec_info.hidden_states
            # The norms are skipped for idle batches; presumably such
            # batches can be empty.
            if not forward_batch.forward_mode.is_idle():
                input_embeds = self.pre_fc_norm_embedding(input_embeds)
                hidden_states = self.pre_fc_norm_hidden(hidden_states)
            hidden_states = torch.cat([input_embeds, hidden_states], dim=-1)
            hidden_states = self.fc(hidden_states)

            # Keep draft-model expert routing out of the expert-distribution
            # statistics.
            with get_global_expert_distribution_recorder().disable_this_region():
                hidden_states = self.model(
                    input_ids,
                    positions,
                    forward_batch,
                    hidden_states,
                )

        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def load_weights(
        self, weights: Iterable[Tuple[str, torch.Tensor]], is_mtp: bool = False
    ):
        """Load the ``mtp.*`` entries of a checkpoint into this module.

        Checkpoint names are remapped as follows:
        - ``mtp.`` becomes ``model.``;
        - ``fc`` / ``pre_fc_norm_*`` move to the top level;
        - ``.self_attn`` is removed.

        Weights are then loaded through one of three paths:
        - stacked (qkv / gate_up) shards;
        - MoE expert weights, per-expert or fused;
        - plain parameters.

        Args:
            weights: ``(name, tensor)`` pairs. Names without ``"mtp"`` are
                ignored.
            is_mtp: Unused; kept for interface compatibility.

        Returns:
            The set of (remapped) parameter names handled. A name whose
            target parameter does not exist is logged and NOT included.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for MoE experts (non-fused/fused); dense configs have none.
        num_experts = getattr(self.config, "num_experts", None)
        if num_experts is not None:
            expert_params_mapping = FusedMoE.make_expert_params_mapping(
                ckpt_gate_proj_name="gate_proj",
                ckpt_down_proj_name="down_proj",
                ckpt_up_proj_name="up_proj",
                num_experts=num_experts,
            )
        else:
            expert_params_mapping = []

        # Skip loading extra parameters for GPTQ/modelopt models.
        ignore_suffixes = (
            ".bias",
            "_bias",
            ".k_scale",
            "_k_scale",
            ".v_scale",
            "_v_scale",
            ".weight_scale",
            "_weight_scale",
            ".input_scale",
            "_input_scale",
        )

        # fused experts: experts.w13_weight / experts.w2_weight
        is_fused_expert = False
        fused_expert_params_mapping = [
            ("experts.w13_weight", "experts.gate_up_proj", 0, "w1"),
            ("experts.w2_weight", "experts.down_proj", 0, "w2"),
        ]

        def load_fused_expert_weights(
            name: str,
            params_dict: dict,
            loaded_weight: torch.Tensor,
            shard_id: str,
            num_experts: int,
        ):
            """Feed ``loaded_weight[e]`` to the param's loader for each expert e."""
            param = params_dict[name]
            weight_loader = param.weight_loader
            # Let EP MoE layer handle expert_ids that do not belong to local moe rank
            for expert_id in range(num_experts):
                curr_expert_weight = loaded_weight[expert_id]
                weight_loader(
                    param,
                    curr_expert_weight,
                    name,
                    shard_id,
                    expert_id,
                )
            return True

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            # Only process MTP branch weights
            if "mtp" not in name:
                continue

            if name.startswith("mtp."):
                # Remove the mtp. prefix for processing
                name = name.replace("mtp.", "model.")
            # fc / pre_fc_norm_* live on this module, not on self.model.
            name = name.replace("model.fc", "fc")
            name = name.replace("model.pre_fc", "pre_fc")

            # Attention projections sit directly on the decoder layer here.
            if ".self_attn." in name:
                name = name.replace(".self_attn", "")

            # A fused checkpoint stores all experts in one tensor per
            # projection; switch to the fused mapping once one is seen.
            if "experts.gate_up_proj" in name or "experts.down_proj" in name:
                is_fused_expert = True
                expert_params_mapping = fused_expert_params_mapping

            # 1) Process stacked parameters (q_proj/k_proj/v_proj & gate_proj/up_proj)
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-matching weights
                if weight_name not in name:
                    continue
                # Skip MoE experts.* here, handled separately below
                if "mlp.experts" in name:
                    continue

                name_mapped = name.replace(weight_name, param_name)
                # Unknown targets (including GPTQ/modelopt extras) are left
                # to the later paths.
                if name_mapped not in params_dict:
                    continue

                param = params_dict[name_mapped]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight, shard_id)
                name = name_mapped
                break
            else:
                # 2) Process MoE expert weights (including fused experts)
                is_expert_weight = False
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue

                    is_expert_weight = True
                    name_mapped = name.replace(weight_name, param_name)

                    # Fused experts: single checkpoint weight contains multiple experts
                    if is_fused_expert and num_experts is not None:
                        if "experts.gate_up_proj" in name:
                            # gate_up_proj fused: split into w1 / w3
                            loaded_w1, loaded_w3 = loaded_weight.chunk(2, dim=-2)
                            load_fused_expert_weights(
                                name_mapped,
                                params_dict,
                                loaded_w1,
                                "w1",
                                num_experts,
                            )
                            load_fused_expert_weights(
                                name_mapped,
                                params_dict,
                                loaded_w3,
                                "w3",
                                num_experts,
                            )
                        else:
                            # down_proj fused: distribute entire weight
                            load_fused_expert_weights(
                                name_mapped,
                                params_dict,
                                loaded_weight,
                                shard_id,
                                num_experts,
                            )
                    else:
                        # Non-fused expert, load by expert_id/shard
                        if (
                            name_mapped.endswith(ignore_suffixes)
                            and name_mapped not in params_dict
                        ):
                            continue
                        if name_mapped not in params_dict:
                            break
                        param = params_dict[name_mapped]
                        weight_loader = param.weight_loader
                        weight_loader(
                            param,
                            loaded_weight,
                            name_mapped,
                            shard_id=shard_id,
                            expert_id=expert_id,
                        )
                    name = name_mapped
                    break
                else:
                    # Skip expert weight if not handled by current rank
                    if is_expert_weight:
                        continue

                    # 3) Regular non-stacked / non-expert parameters, use default loader
                    if name.endswith(ignore_suffixes) and name not in params_dict:
                        continue

                    if name not in params_dict:
                        logger.warning_once(
                            f"Parameter {name} not found in params_dict, skip loading"
                        )
                        # Nothing was loaded, so do not report it as loaded.
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)

            loaded_params.add(name)
        return loaded_params
# Model classes this module exposes to the model registry.
EntryClass = [
    Qwen3_5ForCausalLMMTP,
]