-
Notifications
You must be signed in to change notification settings - Fork 5.8k
Expand file tree
/
Copy pathserver_args.py
More file actions
7652 lines (7031 loc) · 324 KB
/
server_args.py
File metadata and controls
7652 lines (7031 loc) · 324 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The arguments of the server."""
from __future__ import annotations
import argparse
import dataclasses
import importlib
import importlib.util
import json
import logging
import os
import random
import tempfile
from typing import Any, Callable, Dict, List, Literal, Optional, Union
from sglang.srt.configs.linear_attn_model_registry import get_linear_attn_spec_by_arch
from sglang.srt.connector import ConnectorType
from sglang.srt.environ import envs
from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.layers.attention.fla.chunk_delta_h import CHUNK_SIZE as FLA_CHUNK_SIZE
from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.parser.reasoning_parser import ReasoningParser
from sglang.srt.utils.common import (
LORA_TARGET_ALL_MODULES,
SUPPORTED_LORA_TARGET_MODULES,
cpu_has_amx_support,
get_device,
get_device_memory_capacity,
get_device_name,
get_device_sm,
get_nvidia_driver_version,
get_quantization_config,
has_fp8_weights_in_checkpoint,
human_readable_int,
is_blackwell_supported,
is_cpu,
is_cuda,
is_flashinfer_available,
is_hip,
is_hopper_with_cuda_12_3,
is_host_cpu_arm64,
is_mps,
is_musa,
is_no_spec_infer_or_topk_one,
is_npu,
is_remote_url,
is_sm90_supported,
is_sm100_supported,
is_sm120_supported,
is_triton_kernels_available,
is_xpu,
json_list_type,
nullable_str,
parse_connector_type,
torch_release,
xpu_has_xmx_support,
)
from sglang.srt.utils.hf_transformers_utils import check_gguf_file
from sglang.srt.utils.network import NetworkAddress, get_free_port, wait_port_available
from sglang.srt.utils.runai_utils import ObjectStorageModel, is_runai_obj_uri
from sglang.srt.utils.tensor_bridge import use_mlx
from sglang.utils import is_in_ci
# Module-level logger for this file.
logger = logging.getLogger(__name__)
# Define constants
# Default for ServerArgs.uvicorn_access_log_exclude_prefixes (copied into a
# fresh list per instance via default_factory); empty by default.
DEFAULT_UVICORN_ACCESS_LOG_EXCLUDE_PREFIXES = ()
# Architecture names used to special-case model families.
# NOTE(review): presumably compared against the model config's
# `architectures` list; the consumers are not in this view.
MIMO_V2_MODEL_ARCHS = (
    "MiMoV2ForCausalLM",
    "MiMoV2FlashForCausalLM",
)
LLAMA4_MODEL_ARCHS = (
    "Llama4ForConditionalGeneration",
    "Llama4ForCausalLM",
)
# NOTE(review): this is a set (unordered), unlike the other *_CHOICES lists,
# and it has no add_*_choices helper below. If it is rendered in CLI help, the
# listed order is not stable -- confirm that is acceptable.
SAMPLING_BACKEND_CHOICES = {"flashinfer", "pytorch", "ascend"}
# Several *_CHOICES lists below are intentionally mutable module globals: the
# add_*_choices helpers further down extend them in place so that external
# code can register extra options. Do not rebind them to new objects.
LOAD_FORMAT_CHOICES = [
    "auto",
    "pt",
    "safetensors",
    "npcache",
    "dummy",
    "sharded_state",
    "gguf",
    "bitsandbytes",
    "mistral",
    "layered",
    "flash_rl",
    "remote",
    "remote_instance",
    "fastsafetensors",
    "private",
    "runai_streamer",
]
# "unquant" is a sentinel: ServerArgs.__post_init__ converts it to None and
# records that quantization was explicitly disabled.
QUANTIZATION_CHOICES = [
    "awq",
    "fp8",
    "mxfp8",
    "gptq",
    "marlin",
    "gptq_marlin",
    "awq_marlin",
    "bitsandbytes",
    "gguf",
    "modelopt",
    "modelopt_fp8",
    "modelopt_fp4",
    "modelopt_mixed",
    "petit_nvfp4",
    "w8a8_int8",
    "w8a8_fp8",
    "moe_wna16",
    "qoq",
    "w4afp8",
    "mxfp4",
    "auto-round",
    "compressed-tensors",  # for Ktransformers
    "modelslim",  # for NPU
    "quark",  # AMD Quark quantizer (FP8 / MXFP4 / Int4FP8 etc.)
    "quark_int4fp8_moe",
    "unquant",
]
# Alias, not a copy: this is the same list object as QUANTIZATION_CHOICES, so
# add_quantization_method_choices() extends both names at once.
SPECULATIVE_DRAFT_MODEL_QUANTIZATION_CHOICES = QUANTIZATION_CHOICES
ATTENTION_BACKEND_CHOICES = [
    # Common
    "triton",
    "torch_native",
    "flex_attention",
    "nsa",
    "dsv4",
    "compressed",  # Deprecated alias for "dsv4"
    # NVIDIA specific
    "cutlass_mla",
    "fa3",
    "fa4",
    "flashinfer",
    "flashmla",
    "trtllm_mla",
    "trtllm_mha",
    "dual_chunk_flash_attn",
    # AMD specific
    "aiter",
    "wave",
    # Other platforms
    "intel_amx",
    "ascend",
    "intel_xpu",
]
# Attention backends allowed with deterministic inference; the second list is
# the subset that (by its name) also supports the radix cache in that mode.
DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3", "triton"]
RADIX_SUPPORTED_DETERMINISTIC_ATTENTION_BACKEND = ["fa3", "triton"]
DISAGG_TRANSFER_BACKEND_CHOICES = ["mooncake", "nixl", "ascend", "fake", "mori"]
GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"]
# Placeholder token inserted between items in Multi-Item Scoring sequences:
# query<delim>item1<delim>item2<delim>... Positions are pre-computed from item
# lengths (multi_item_delimiter_indices); the token only exists for FlashInfer
# attention mask compat and logprob column indexing. Will be removed once the
# attention backend supports position-only MIS.
MIS_DELIMITER_TOKEN_ID = 9999
MOE_RUNNER_BACKEND_CHOICES = [
    "auto",
    "deep_gemm",
    "triton",
    "triton_kernel",
    "flashinfer_trtllm",
    "flashinfer_trtllm_routed",
    "flashinfer_cutlass",
    "flashinfer_mxfp4",
    "flashinfer_cutedsl",
    "cutlass",
    "aiter",
    "marlin",
]
# NOTE: mirrors the Literal[...] annotation on ServerArgs.moe_a2a_backend;
# update both together.
MOE_A2A_BACKEND_CHOICES = [
    "none",
    "deepep",
    "mooncake",
    "nixl",
    "mori",
    "ascend_fuseep",
    "flashinfer",
]
FP8_GEMM_RUNNER_BACKEND_CHOICES = [
    "auto",
    "deep_gemm",
    "flashinfer_trtllm",
    "flashinfer_cutlass",
    "flashinfer_deepgemm",
    "cutlass",
    "triton",
    "aiter",
]
FP4_GEMM_RUNNER_BACKEND_CHOICES = [
    "auto",
    "cutlass",
    "flashinfer_cudnn",
    "flashinfer_cutedsl",
    "flashinfer_cutlass",
    "flashinfer_trtllm",
]
RADIX_EVICTION_POLICY_CHOICES = ["lru", "lfu", "slru", "priority"]
RL_ON_POLICY_TARGET_CHOICES = ["fsdp"]
LORA_BACKEND_CHOICES = ["triton", "csgmv", "ascend", "torch_native"]
# Order matters: entry [0] is the default for ServerArgs.encoder_transfer_backend.
ENCODER_TRANSFER_BACKEND_CHOICES = ["zmq_to_scheduler", "zmq_to_tokenizer", "mooncake"]
NSA_PREFILL_CP_SPLIT_CHOICES = ["in-seq-split", "round-robin-split"]
PREFILL_CP_SPLIT_CHOICES = ["in-seq-split"]
# NOTE(review): ServerArgs.lora_eviction_policy hard-codes the same "lru"
# literal instead of referencing this constant.
DEFAULT_LORA_EVICTION_POLICY = "lru"
NSA_CHOICES = [
    "flashmla_sparse",
    "flashmla_kv",
    "flashmla_auto",
    "fa3",
    "tilelang",
    "aiter",
    "trtllm",
]
MAMBA_SCHEDULER_STRATEGY_CHOICES = ["auto", "no_buffer", "extra_buffer"]
MAMBA_BACKEND_CHOICES = ["triton", "flashinfer"]
LINEAR_ATTN_KERNEL_BACKEND_CHOICES = ["triton", "cutedsl", "flashinfer"]
# Allow external code to add more choices
def _extend_choices(target, choices):
    """Append ``choices`` to the registry list ``target`` in place.

    ``target`` is mutated rather than rebound, so every existing reference to
    the list (including aliases such as
    ``SPECULATIVE_DRAFT_MODEL_QUANTIZATION_CHOICES``) sees the new entries.

    Args:
        target: The module-level choices list to extend.
        choices: An iterable of choice strings, or a single string. A bare
            string is treated as ONE choice; plain ``list.extend`` would
            split it into individual characters.

    Entries already present in ``target`` are skipped. This way a plugin that
    registers twice (or passes duplicates) does not produce repeated options.
    """
    if isinstance(choices, str):
        choices = [choices]
    for choice in choices:
        if choice not in target:
            target.append(choice)


def add_load_format_choices(choices):
    """Register extra ``--load-format`` choices."""
    _extend_choices(LOAD_FORMAT_CHOICES, choices)


def add_quantization_method_choices(choices):
    """Register extra quantization choices (also visible to the draft-model alias)."""
    _extend_choices(QUANTIZATION_CHOICES, choices)


def add_attention_backend_choices(choices):
    """Register extra attention backend choices."""
    _extend_choices(ATTENTION_BACKEND_CHOICES, choices)


def add_deterministic_attention_backend_choices(choices):
    """Register extra attention backends allowed for deterministic inference."""
    _extend_choices(DETERMINISTIC_ATTENTION_BACKEND_CHOICES, choices)


def add_radix_supported_deterministic_attention_backend_choices(choices):
    """Register extra deterministic attention backends that support the radix cache."""
    _extend_choices(RADIX_SUPPORTED_DETERMINISTIC_ATTENTION_BACKEND, choices)


def add_disagg_transfer_backend_choices(choices):
    """Register extra PD-disaggregation transfer backend choices."""
    _extend_choices(DISAGG_TRANSFER_BACKEND_CHOICES, choices)


def add_grammar_backend_choices(choices):
    """Register extra grammar backend choices."""
    _extend_choices(GRAMMAR_BACKEND_CHOICES, choices)


def add_moe_runner_backend_choices(choices):
    """Register extra MoE runner backend choices."""
    _extend_choices(MOE_RUNNER_BACKEND_CHOICES, choices)


def add_fp8_gemm_runner_backend_choices(choices):
    """Register extra FP8 GEMM runner backend choices."""
    _extend_choices(FP8_GEMM_RUNNER_BACKEND_CHOICES, choices)


def add_fp4_gemm_runner_backend_choices(choices):
    """Register extra FP4 GEMM runner backend choices."""
    _extend_choices(FP4_GEMM_RUNNER_BACKEND_CHOICES, choices)


def add_radix_eviction_policy_choices(choices):
    """Register extra radix-cache eviction policy choices."""
    _extend_choices(RADIX_EVICTION_POLICY_CHOICES, choices)


def add_rl_on_policy_target_choices(choices):
    """Register extra RL on-policy target choices."""
    _extend_choices(RL_ON_POLICY_TARGET_CHOICES, choices)
def _resolve_speculative_algorithm_alias(
speculative_algorithm: Optional[str],
speculative_draft_model_path: Optional[str],
trust_remote_code: bool = False,
) -> Optional[str]:
"""Resolve CLI speculative algorithm; NEXTN/EAGLE may become FROZEN_KV_MTP for Gemma4 assistant drafts."""
is_gemma4_draft = False
if speculative_draft_model_path:
from transformers import AutoConfig
cfg = AutoConfig.from_pretrained(
speculative_draft_model_path, trust_remote_code=trust_remote_code
)
is_gemma4_draft = "Gemma4AssistantForCausalLM" in (
getattr(cfg, "architectures", None) or []
)
if speculative_algorithm == "EAGLE3" and is_gemma4_draft:
raise ValueError(
"Gemma4AssistantForCausalLM draft requires "
"--speculative-algorithm NEXTN or EAGLE; EAGLE3 is "
"not supported for this draft architecture."
)
if speculative_algorithm == "NEXTN" or speculative_algorithm == "EAGLE":
if is_gemma4_draft:
logger.info(
"Detected Gemma4AssistantForCausalLM draft; "
f"promoting --speculative-algorithm {speculative_algorithm} to FROZEN_KV_MTP."
)
return "FROZEN_KV_MTP"
return "EAGLE"
return speculative_algorithm
@dataclasses.dataclass
class ServerArgs:
"""
The arguments of the server.
NOTE: When you add new arguments, please make sure the order
in this class definition the same as the order in the function
`ServerArgs.add_cli_args`.
Please follow the existing style to group the new arguments into related groups or create new groups.
"""
# Model and tokenizer
model_path: str
tokenizer_path: Optional[str] = None
tokenizer_mode: str = "auto"
tokenizer_backend: str = "huggingface"
tokenizer_worker_num: int = 1
skip_tokenizer_init: bool = False
load_format: str = "auto"
model_loader_extra_config: str = "{}"
trust_remote_code: bool = False
context_length: Optional[int] = None
is_embedding: bool = False
enable_multimodal: Optional[bool] = None
revision: Optional[str] = None
model_impl: str = "auto"
# HTTP server
host: str = "127.0.0.1"
port: int = 30000
fastapi_root_path: str = ""
grpc_mode: bool = False
skip_server_warmup: bool = False
warmups: Optional[str] = None
nccl_port: Optional[int] = None
checkpoint_engine_wait_weights_before_ready: bool = False
# SSL/TLS
ssl_keyfile: Optional[str] = None
ssl_certfile: Optional[str] = None
ssl_ca_certs: Optional[str] = None
ssl_keyfile_password: Optional[str] = None
enable_ssl_refresh: bool = False
enable_http2: bool = False
# Quantization and data type
dtype: str = "auto"
quantization: Optional[str] = None
quantization_param_path: Optional[str] = None
kv_cache_dtype: str = "auto"
enable_fp32_lm_head: bool = False
modelopt_quant: Optional[Union[str, Dict]] = None
modelopt_checkpoint_restore_path: Optional[str] = None
modelopt_checkpoint_save_path: Optional[str] = None
modelopt_export_path: Optional[str] = None
quantize_and_serve: bool = False
rl_quant_profile: Optional[str] = None # For flash_rl load format
# Memory and scheduling
mem_fraction_static: Optional[float] = None
max_running_requests: Optional[int] = None
max_queued_requests: Optional[int] = None
max_total_tokens: Optional[int] = None
chunked_prefill_size: Optional[int] = None
enable_dynamic_chunking: bool = False
max_prefill_tokens: int = 16384
prefill_max_requests: Optional[int] = None
schedule_policy: str = "fcfs"
enable_priority_scheduling: bool = False
disable_priority_preemption: bool = False
default_priority_value: Optional[int] = None
abort_on_priority_when_disabled: bool = False
schedule_low_priority_values_first: bool = False
priority_scheduling_preemption_threshold: int = 10
schedule_conservativeness: float = 1.0
page_size: Optional[int] = None
swa_full_tokens_ratio: float = 0.8
disable_hybrid_swa_memory: bool = False
radix_eviction_policy: str = "lru"
enable_prefill_delayer: bool = False
prefill_delayer_max_delay_passes: int = 30
prefill_delayer_token_usage_low_watermark: Optional[float] = None
prefill_delayer_forward_passes_buckets: Optional[List[float]] = None
prefill_delayer_wait_seconds_buckets: Optional[List[float]] = None
prefill_delayer_queue_min_ratio: Optional[float] = None
prefill_delayer_max_delay_ms: Optional[float] = None
# Runtime options
device: Optional[str] = None
tp_size: int = 1
pp_size: int = 1
pp_max_micro_batch_size: Optional[int] = None
pp_async_batch_depth: int = 0
stream_interval: int = 1
batch_notify_size: int = 16
stream_response_default_include_usage: bool = False
incremental_streaming_output: bool = False
enable_streaming_session: bool = False
random_seed: Optional[int] = None
constrained_json_whitespace_pattern: Optional[str] = None
constrained_json_disable_any_whitespace: bool = False
watchdog_timeout: float = 300
soft_watchdog_timeout: Optional[float] = None
dist_timeout: Optional[int] = None # timeout for torch.distributed
download_dir: Optional[str] = None
model_checksum: Optional[str] = None
base_gpu_id: int = 0
gpu_id_step: int = 1
sleep_on_idle: bool = False
use_ray: bool = False
custom_sigquit_handler: Optional[Callable] = None
# Logging
log_level: str = "info"
log_level_http: Optional[str] = None
log_requests: bool = False
log_requests_level: int = 2
log_requests_format: str = "text"
log_requests_target: Optional[List[str]] = None
uvicorn_access_log_exclude_prefixes: List[str] = dataclasses.field(
default_factory=lambda: list(DEFAULT_UVICORN_ACCESS_LOG_EXCLUDE_PREFIXES)
)
crash_dump_folder: Optional[str] = None
show_time_cost: bool = False
enable_metrics: bool = False
grpc_http_sidecar_port: Optional[int] = None
enable_mfu_metrics: bool = False
enable_metrics_for_all_schedulers: bool = False
tokenizer_metrics_custom_labels_header: str = "x-custom-labels"
tokenizer_metrics_allowed_custom_labels: Optional[List[str]] = None
extra_metric_labels: Optional[Dict[str, str]] = None
bucket_time_to_first_token: Optional[List[float]] = None
bucket_inter_token_latency: Optional[List[float]] = None
bucket_e2e_request_latency: Optional[List[float]] = None
prompt_tokens_buckets: Optional[List[str]] = None
generation_tokens_buckets: Optional[List[str]] = None
gc_warning_threshold_secs: float = 0.0
decode_log_interval: int = 40
enable_request_time_stats_logging: bool = False
kv_events_config: Optional[str] = None
enable_trace: bool = False
otlp_traces_endpoint: str = "localhost:4317"
# RequestMetricsExporter configuration
export_metrics_to_file: bool = False
export_metrics_to_file_dir: Optional[str] = None
# API related
api_key: Optional[str] = None
admin_api_key: Optional[str] = None
served_model_name: Optional[str] = None
weight_version: str = "default"
chat_template: Optional[str] = None
hf_chat_template_name: Optional[str] = None
completion_template: Optional[str] = None
file_storage_path: str = "sglang_storage"
enable_cache_report: bool = False
reasoning_parser: Optional[str] = None
strip_thinking_cache: bool = False
enable_strict_thinking: bool = False
tool_call_parser: Optional[str] = None
tool_server: Optional[str] = None
sampling_defaults: str = "model"
# Data parallelism
dp_size: int = 1
load_balance_method: str = "auto"
attn_cp_size: int = 1
moe_dp_size: int = 1
# Multi-node distributed serving
dist_init_addr: Optional[str] = None
nnodes: int = 1
node_rank: int = 0
# Model override args in JSON
json_model_override_args: str = "{}"
preferred_sampling_params: Optional[str] = None
# LoRA
enable_lora: Optional[bool] = None
enable_lora_overlap_loading: Optional[bool] = None
max_lora_rank: Optional[int] = None
lora_target_modules: Optional[Union[set[str], List[str]]] = None
lora_paths: Optional[
Union[dict[str, str], List[dict[str, str]], List[str], List[LoRARef]]
] = None
max_loaded_loras: Optional[int] = None
max_loras_per_batch: int = 8
lora_eviction_policy: str = "lru"
lora_backend: str = "csgmv"
max_lora_chunk_size: Optional[int] = 16
experts_shared_outer_loras: Optional[bool] = None
lora_use_virtual_experts: bool = False
lora_strict_loading: bool = False
lora_drain_wait_threshold: float = 0.0
# Kernel backend
attention_backend: Optional[str] = None
decode_attention_backend: Optional[str] = None
prefill_attention_backend: Optional[str] = None
sampling_backend: Optional[str] = None
grammar_backend: Optional[str] = None
mm_attention_backend: Optional[str] = None
fp8_gemm_runner_backend: str = "auto"
fp4_gemm_runner_backend: str = "auto"
nsa_prefill_backend: Optional[str] = (
None # None = auto-detect based on hardware/kv_cache_dtype
)
nsa_decode_backend: Optional[str] = (
None # auto-detect based on hardware/kv_cache_dtype
)
disable_flashinfer_autotune: bool = False
mamba_backend: str = "triton"
# Speculative decoding
speculative_algorithm: Optional[str] = None
speculative_draft_model_path: Optional[str] = None
speculative_draft_model_revision: Optional[str] = None
speculative_draft_load_format: Optional[str] = None
speculative_num_steps: Optional[int] = None
speculative_eagle_topk: Optional[int] = None
speculative_num_draft_tokens: Optional[int] = None
speculative_dflash_block_size: Optional[int] = None
speculative_dflash_draft_window_size: Optional[int] = None
speculative_accept_threshold_single: float = 1.0
speculative_accept_threshold_acc: float = 1.0
speculative_token_map: Optional[str] = None
speculative_attention_mode: str = "prefill"
speculative_draft_attention_backend: Optional[str] = None
speculative_moe_runner_backend: Optional[str] = None
speculative_moe_a2a_backend: Optional[str] = None
speculative_draft_model_quantization: Optional[str] = None
speculative_adaptive: bool = False
speculative_adaptive_config: Optional[str] = None
speculative_skip_dp_mlp_sync: bool = False
# Speculative decoding (ngram)
speculative_ngram_min_bfs_breadth: int = 1
speculative_ngram_max_bfs_breadth: int = 10
speculative_ngram_match_type: Literal["BFS", "PROB"] = "BFS"
speculative_ngram_max_trie_depth: int = 18
speculative_ngram_capacity: int = 10 * 1000 * 1000
speculative_ngram_external_corpus_path: Optional[str] = None
speculative_ngram_external_sam_budget: int = 0
speculative_ngram_external_corpus_max_tokens: int = 10000000
enable_multi_layer_eagle: bool = False
# Expert parallelism
ep_size: int = 1
moe_a2a_backend: Literal[
"none", "deepep", "mooncake", "nixl", "mori", "ascend_fuseep", "flashinfer"
] = "none"
moe_runner_backend: str = "auto"
record_nolora_graph: bool = True
flashinfer_mxfp4_moe_precision: Literal["default", "bf16"] = "default"
enable_flashinfer_allreduce_fusion: bool = False
enforce_disable_flashinfer_allreduce_fusion: bool = False
enable_aiter_allreduce_fusion: bool = False
deepep_mode: Literal["auto", "normal", "low_latency"] = "auto"
ep_num_redundant_experts: int = 0
ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None
init_expert_location: str = "trivial"
enable_eplb: bool = False
eplb_algorithm: str = "auto"
eplb_rebalance_num_iterations: int = 1000
eplb_rebalance_layers_per_chunk: Optional[int] = None
eplb_min_rebalancing_utilization_threshold: float = 1.0
expert_distribution_recorder_mode: Optional[
Literal["stat", "stat_approx", "per_pass", "per_token"]
] = None
expert_distribution_recorder_buffer_size: Optional[int] = None
enable_expert_distribution_metrics: bool = False
deepep_config: Optional[str] = None
moe_dense_tp_size: Optional[int] = None
elastic_ep_backend: Literal[None, "mooncake", "nixl"] = None
enable_elastic_expert_backup: bool = False
mooncake_ib_device: Optional[str] = None
elastic_ep_rejoin: bool = False
# Mamba cache
max_mamba_cache_size: Optional[int] = None
mamba_ssm_dtype: Optional[str] = None
mamba_full_memory_ratio: float = 0.9
mamba_scheduler_strategy: str = "auto"
mamba_track_interval: int = 256
linear_attn_backend: str = "triton"
linear_attn_decode_backend: Optional[str] = None
linear_attn_prefill_backend: Optional[str] = None
# Hierarchical cache
enable_hierarchical_cache: bool = False
hicache_ratio: float = 2.0
hicache_size: int = 0
hicache_write_policy: str = "write_through"
hicache_io_backend: str = "kernel"
hicache_mem_layout: str = "layer_first"
hicache_storage_backend: Optional[str] = None
hicache_storage_prefetch_policy: str = "best_effort"
hicache_storage_backend_extra_config: Optional[str] = None
# Hierarchical sparse attention
enable_hisparse: bool = False
hisparse_config: Optional[str] = None
# LMCache
enable_lmcache: bool = False
# Ktransformers/AMX expert parallelism
kt_weight_path: Optional[str] = None
kt_method: Optional[str] = None
kt_cpuinfer: Optional[int] = None
kt_threadpool_count: Optional[int] = None
kt_num_gpu_experts: Optional[int] = None
kt_max_deferred_experts_per_token: Optional[int] = None
# Diffusion LLM
dllm_algorithm: Optional[str] = None
dllm_algorithm_config: Optional[str] = None
# Offloading
cpu_offload_gb: int = 0
offload_group_size: int = -1
offload_num_in_group: int = 1
offload_prefetch_step: int = 1
offload_mode: str = "cpu"
# Scoring configuration
# Enable Multi-Item Scoring optimization. Combines query and multiple items
# into a single sequence for efficient batch processing. Item boundaries are
# determined by pre-computed delimiter indices (from item lengths), not by the
# placeholder token. See MIS_DELIMITER_TOKEN_ID for details.
enable_mis: bool = False
# Optimization/debug options
disable_radix_cache: bool = False
cuda_graph_max_bs: Optional[int] = None
cuda_graph_bs: Optional[List[int]] = None
disable_cuda_graph: bool = False
disable_cuda_graph_padding: bool = False
enable_breakable_cuda_graph: bool = False
enable_profile_cuda_graph: bool = False
enable_cudagraph_gc: bool = False
debug_cuda_graph: bool = False
enable_layerwise_nvtx_marker: bool = False
enable_nccl_nvls: bool = False
enable_symm_mem: bool = False
disable_flashinfer_cutlass_moe_fp4_allgather: bool = False
enable_tokenizer_batch_encode: bool = False
disable_tokenizer_batch_decode: bool = False
disable_outlines_disk_cache: bool = False
disable_custom_all_reduce: bool = False
enable_mscclpp: bool = False
enable_torch_symm_mem: bool = False
pre_warm_nccl: bool = dataclasses.field(
default_factory=lambda: is_hip()
) # Pre-warm NCCL/RCCL to reduce P99 TTFT cold-start latency (default: True for AMD/HIP, False for others)
disable_overlap_schedule: bool = False
enable_mixed_chunk: bool = False
enable_dp_attention: bool = False
enable_dp_attention_local_control_broadcast: bool = False
enable_dp_lm_head: bool = False
enable_two_batch_overlap: bool = False
enable_single_batch_overlap: bool = False
tbo_token_distribution_threshold: float = 0.48
enable_torch_compile: bool = False
disable_piecewise_cuda_graph: bool = False
enforce_piecewise_cuda_graph: bool = False
enable_torch_compile_debug_mode: bool = False
torch_compile_max_bs: int = 32
piecewise_cuda_graph_max_tokens: Optional[int] = None
piecewise_cuda_graph_tokens: Optional[List[int]] = None
piecewise_cuda_graph_compiler: str = "eager"
torchao_config: str = ""
enable_nan_detection: bool = False
enable_p2p_check: bool = False
triton_attention_reduce_in_fp32: bool = False
triton_attention_num_kv_splits: int = 8
triton_attention_split_tile_size: Optional[int] = None
num_continuous_decode_steps: int = 1
delete_ckpt_after_loading: bool = False
enable_memory_saver: bool = False
enable_weights_cpu_backup: bool = False
enable_draft_weights_cpu_backup: bool = False
allow_auto_truncate: bool = False
enable_custom_logit_processor: bool = False
flashinfer_mla_disable_ragged: bool = False
disable_shared_experts_fusion: bool = False
enforce_shared_experts_fusion: bool = False
disable_chunked_prefix_cache: bool = False
disable_fast_image_processor: bool = False
keep_mm_feature_on_device: bool = False
enable_return_hidden_states: bool = False
enable_return_routed_experts: bool = False
enable_return_indexer_topk: bool = False
scheduler_recv_interval: int = 1
numa_node: Optional[List[int]] = None
enable_deterministic_inference: bool = False
rl_on_policy_target: Optional[str] = None
enable_attn_tp_input_scattered: bool = False
gc_threshold: Optional[List[int]] = None
# Context parallelism used in the long sequence prefill phase of DeepSeek v3.2
enable_nsa_prefill_context_parallel: bool = False
nsa_prefill_cp_mode: str = "round-robin-split"
enable_fused_qk_norm_rope: bool = False
enable_precise_embedding_interpolation: bool = False
enable_fused_moe_sum_all_reduce: bool = False
# Context parallelism
enable_prefill_context_parallel: bool = False
prefill_cp_mode: str = "in-seq-split"
# Dynamic batch tokenizer
enable_dynamic_batch_tokenizer: bool = False
dynamic_batch_tokenizer_batch_size: int = 32
dynamic_batch_tokenizer_batch_timeout: float = 0.002
# Debug tensor dumps
debug_tensor_dump_output_folder: Optional[str] = None
# None means dump all layers.
debug_tensor_dump_layers: Optional[List[int]] = None
# TODO(guoyuhong): clean the old dumper code.
debug_tensor_dump_input_file: Optional[str] = None
debug_tensor_dump_inject: bool = False
# PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
disaggregation_mode: Literal["null", "prefill", "decode"] = "null"
disaggregation_transfer_backend: str = "mooncake"
disaggregation_bootstrap_port: int = 8998
disaggregation_ib_device: Optional[str] = None
disaggregation_decode_enable_radix_cache: bool = False
disaggregation_decode_enable_offload_kvcache: bool = False
num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD
# FIXME: hack to reduce ITL when decode bs is small
disaggregation_decode_polling_interval: int = 1
# Encode prefill disaggregation
encoder_only: bool = False
language_only: bool = False
encoder_transfer_backend: str = ENCODER_TRANSFER_BACKEND_CHOICES[0]
encoder_urls: List[str] = dataclasses.field(default_factory=list)
enable_adaptive_dispatch_to_encoder: bool = False
# For model weight update and weight loading
custom_weight_loader: Optional[List[str]] = None
weight_loader_disable_mmap: bool = False
weight_loader_prefetch_checkpoints: bool = False
weight_loader_prefetch_num_threads: int = 4
weight_loader_drop_cache_after_load: bool = False
remote_instance_weight_loader_seed_instance_ip: Optional[str] = None
remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None
remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None
remote_instance_weight_loader_backend: Literal[
"transfer_engine", "nccl", "modelexpress"
] = "nccl"
remote_instance_weight_loader_start_seed_via_transfer_engine: bool = False
engine_info_bootstrap_port: int = 6789
modelexpress_config: Optional[str] = None
# For PD-Multiplexing
enable_pdmux: bool = False
pdmux_config_path: Optional[str] = None
sm_group_num: int = 8
# For Multi-Modal
enable_broadcast_mm_inputs_process: bool = False
enable_prefix_mm_cache: bool = False
mm_enable_dp_encoder: bool = False
mm_process_config: Optional[Dict[str, Any]] = None
limit_mm_data_per_request: Optional[Union[str, Dict[str, int]]] = None
enable_mm_global_cache: bool = False
# For checkpoint decryption
decrypted_config_file: Optional[str] = None
decrypted_draft_config_file: Optional[str] = None
# For forward hooks
forward_hooks: Optional[List[dict[str, Any]]] = None
# For communications compression
enable_quant_communications: Optional[bool] = False
# For msProbe
msprobe_dump_config: Optional[str] = None
def __post_init__(self):
    """Normalize, default, and validate the server arguments after construction.

    Runs a fixed sequence of ``_handle_*`` steps plus a platform hook
    (``current_platform.apply_server_args_defaults``). The order is
    significant: some steps consume values produced by earlier ones. For
    example, ``gpu_mem`` feeds ``_handle_gpu_memory_settings``, and
    ``_handle_multi_item_scoring`` must run after the attention backend and
    ``chunked_prefill_size`` are final. Do not reorder calls casually.

    For dummy models (``model_path`` equal to ``"none"`` or ``"dummy"``,
    case-insensitive), only the early download/validation steps run.
    """
    # Runs before the dummy-model check, i.e. for every model_path.
    self._maybe_download_model_for_runai()
    # Normalize load balancing defaults early (before dummy-model short-circuit).
    self._handle_load_balance_method()
    # Validate mm_process_config before dummy-model early return.
    self._handle_multimodal()
    # Validate SSL arguments early (before dummy-model short-circuit).
    self._handle_ssl_validation()
    # Validate PD disaggregation flags early (before dummy-model short-circuit).
    self._handle_pd_disaggregation()
    if self.model_path.lower() in ["none", "dummy"]:
        # Skip for dummy models
        # NOTE(review): self._quantization_explicitly_unset is never set on this
        # path; any later reader would raise AttributeError for dummy models --
        # confirm nothing reads it in that case.
        return
    # Handle deprecated arguments.
    self._handle_deprecated_args()
    # Handle deprecated environment variables for prefill delayer.
    self._handle_prefill_delayer_env_compat()
    # Resolve --quantization unquant: explicitly opt out of quantization.
    # Convert to None now (before model config validation), but record
    # the intent so auto-detection in _handle_model_specific_adjustments
    # does not override it.
    if self.quantization == "unquant":
        self.quantization = None
        self._quantization_explicitly_unset = True
    else:
        self._quantization_explicitly_unset = False
    # Set missing default values.
    self._handle_missing_default_values()
    # Handle device-specific backends.
    self._handle_hpu_backends()
    self._handle_cpu_backends()
    self._handle_npu_backends()
    self._handle_mps_backends()
    self._handle_xpu_backends()
    # Allow OOT platform plugins to apply server args defaults.
    # Imported locally; the platform module is only needed at this point.
    from sglang.srt.platforms import current_platform

    current_platform.apply_server_args_defaults(self)
    # Handle piecewise CUDA graph.
    self._handle_piecewise_cuda_graph()
    # Get GPU memory capacity, which is a common dependency for several configuration steps.
    gpu_mem = get_device_memory_capacity(self.device)
    # Handle memory-related, chunked prefill, and CUDA graph batch size configurations.
    self._handle_gpu_memory_settings(gpu_mem)
    # Apply model-specific adjustments.
    self._handle_model_specific_adjustments()
    # Set kernel backends.
    self._handle_sampling_backend()
    self._handle_attention_backend_compatibility()
    self._handle_mamba_backend()
    self._handle_linear_attn_backend()
    self._handle_kv4_compatibility()
    self._handle_page_size()
    self._handle_amd_specifics()
    self._handle_nccl_pre_warm()
    self._handle_grammar_backend()
    # Handle multi-item scoring constraints. Must run after the above so
    # the final attention backend and chunked_prefill_size are in effect.
    self._handle_multi_item_scoring()
    # Handle Hicache settings.
    self._handle_hicache()
    # Handle data parallelism.
    self._handle_data_parallelism()
    # Handle context parallelism.
    self._handle_context_parallelism()
    # Handle MoE configurations.
    self._handle_moe_kernel_config()
    self._handle_a2a_moe()
    self._handle_eplb_and_dispatch()
    self._handle_expert_distribution_metrics()
    self._handle_elastic_ep()
    # Handle pipeline parallelism.
    self._handle_pipeline_parallelism()
    # Handle speculative decoding logic.
    self._handle_speculative_decoding()
    # Handle model loading format.
    self._handle_load_format()
    # Handle Encoder disaggregation.
    self._handle_encoder_disaggregation()
    # Validate tokenizer settings.
    self._handle_tokenizer_batching()
    # Propagate environment variables.
    self._handle_environment_variables()
    # Validate cache settings.
    self._handle_cache_compatibility()
    # Handle deterministic inference.
    self._handle_deterministic_inference()
    # Handle diffusion LLM inference.
    self._handle_dllm_inference()
    # Handle debug utilities.
    self._handle_debug_utils()
    # Handle any other necessary validations.
    self._handle_other_validations()
def _maybe_download_model_for_runai(self):
    """Fetch Run:ai object-storage URIs for the model and, if distinct, the tokenizer.

    ``ObjectStorageModel.download_and_get_path`` is invoked for
    ``model_path`` when it is a Run:ai object URI. It is also invoked for
    ``tokenizer_path`` when that path is set, is such a URI, and differs
    from ``model_path``, so a shared location is fetched only once.

    The returned paths are not stored here.
    NOTE(review): presumably the download location is deterministic and is
    resolved again later -- confirm.
    """
    model_uri = self.model_path
    if is_runai_obj_uri(model_uri):
        ObjectStorageModel.download_and_get_path(model_uri)

    tokenizer_uri = self.tokenizer_path
    if tokenizer_uri is None:
        return
    if is_runai_obj_uri(tokenizer_uri) and tokenizer_uri != model_uri:
        ObjectStorageModel.download_and_get_path(tokenizer_uri)
def _handle_load_balance_method(self):
if self.disaggregation_mode not in ("null", "prefill", "decode"):
raise ValueError(
f"Invalid disaggregation_mode={self.disaggregation_mode!r}"
)
if self.load_balance_method == "auto":
# Default behavior:
# - non-PD: round_robin
# - PD prefill: follow_bootstrap_room
# - PD decode: round_robin
self.load_balance_method = (
"follow_bootstrap_room"
if self.disaggregation_mode == "prefill"
else "round_robin"
)
return
def _handle_ssl_validation(self):
"""Ensure SSL arguments are consistent and referenced files exist."""
if self.ssl_keyfile and not self.ssl_certfile:
raise ValueError(
"--ssl-keyfile requires --ssl-certfile to be specified as well."
)
if self.ssl_certfile and not self.ssl_keyfile:
raise ValueError(
"--ssl-certfile requires --ssl-keyfile to be specified as well."
)
if not self.ssl_certfile and not self.ssl_keyfile:
if self.ssl_ca_certs: