
[Bug] GLM-4.5-Air-FP8 can't run consistently. OutOfResources error #8513

@17Reset

Description

Checklist

  • 1. I have searched related issues but cannot get the expected help.
  • 2. The bug has not been fixed in the latest version.
  • 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
  • 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/sglang/discussions/new/choose. Otherwise, it will be closed.
  • 5. Please use English, otherwise it will be closed.

Describe the bug

SGLang crashes as soon as the front end submits a request that carries a longer context: the first short request completes normally, but the next request (a prefill of roughly 2,000 tokens) hits a Triton OutOfResources (shared memory) error in the fused MoE kernel and the server is killed.

Running Log:

xaccel@xaccel:/xrepo/Athena/Agent/SGLang$ ./sglang_run.sh
[2025-07-29 15:04:43] server_args=ServerArgs(model_path='/xrepo/Athena/Model/GLM/GLM-4.5-Air-FP8', tokenizer_path='/xrepo/Athena/Model/GLM/GLM-4.5-Air-FP8', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', model_loader_extra_config='{}', trust_remote_code=False, context_length=None, is_embedding=False, enable_multimodal=None, revision=None, model_impl='auto', host='0.0.0.0', port=30000, skip_server_warmup=False, warmups=None, nccl_port=None, dtype='auto', quantization=None, quantization_param_path=None, kv_cache_dtype='auto', mem_fraction_static=0.7, max_running_requests=128, max_total_tokens=None, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, hybrid_kvcache_ratio=None, swa_full_tokens_ratio=0.8, disable_hybrid_swa_memory=False, device='cuda', tp_size=4, pp_size=1, max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=183703364, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, sleep_on_idle=False, log_level='info', log_level_http=None, log_requests=False, log_requests_level=0, crash_dump_folder=None, show_time_cost=False, enable_metrics=False, enable_metrics_for_all_schedulers=False, bucket_time_to_first_token=None, bucket_inter_token_latency=None, bucket_e2e_request_latency=None, collect_tokens_histogram=False, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, api_key=None, served_model_name='/xrepo/Athena/Model/GLM/GLM-4.5-Air-FP8', chat_template=None, completion_template=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser='glm45', tool_call_parser=None, dp_size=1, load_balance_method='round_robin', dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, enable_lora=None, max_lora_rank=None, lora_target_modules=None, lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, decode_attention_backend=None, prefill_attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', mm_attention_backend=None, speculative_algorithm=None, speculative_draft_model_path=None, speculative_num_steps=None, speculative_eagle_topk=None, speculative_num_draft_tokens=None, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, ep_size=1, enable_ep_moe=False, enable_deepep_moe=False, enable_flashinfer_cutlass_moe=False, enable_flashinfer_trtllm_moe=False, enable_flashinfer_allreduce_fusion=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm='static', init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, eplb_rebalance_layers_per_chunk=None, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, moe_dense_tp_size=None, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through_selective', hicache_io_backend='', hicache_storage_backend=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, cuda_graph_max_bs=None, cuda_graph_bs=None, disable_cuda_graph=False, disable_cuda_graph_padding=False, 
enable_profile_cuda_graph=False, enable_nccl_nvls=False, enable_tokenizer_batch_encode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, enable_mscclpp=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, enable_torch_compile=False, torch_compile_max_bs=32, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, flashinfer_mla_disable_ragged=False, disable_shared_experts_fusion=False, disable_chunked_prefix_cache=False, disable_fast_image_processor=False, enable_return_hidden_states=False, enable_triton_kernel_moe=False, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, debug_tensor_dump_prefill_only=False, disaggregation_mode='null', disaggregation_transfer_backend='mooncake', disaggregation_bootstrap_port=8998, disaggregation_decode_tp=None, disaggregation_decode_dp=None, disaggregation_prefill_pp=1, disaggregation_ib_device=None, num_reserved_decode_tokens=512, pdlb_url=None, custom_weight_loader=[], weight_loader_disable_mmap=False, enable_pdmux=False, sm_group_num=3)
[2025-07-29 15:04:49 TP0] Attention backend not explicitly specified. Use flashinfer backend by default.
[2025-07-29 15:04:49 TP0] Init torch distributed begin.
[2025-07-29 15:04:50 TP0] sglang is using nccl==2.26.2
[2025-07-29 15:04:50 TP3] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly.
[2025-07-29 15:04:50 TP1] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly.
[2025-07-29 15:04:50 TP2] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly.
[2025-07-29 15:04:50 TP0] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly.
[2025-07-29 15:04:50 TP0] Init torch distributed ends. mem usage=0.19 GB
[2025-07-29 15:04:51 TP0] Load weight begin. avail mem=46.21 GB
Loading safetensors checkpoint shards:   0% Completed | 0/47 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:   9% Completed | 4/47 [00:00<00:02, 14.66it/s]
Loading safetensors checkpoint shards:  30% Completed | 14/47 [00:00<00:00, 43.62it/s]
Loading safetensors checkpoint shards:  51% Completed | 24/47 [00:00<00:00, 61.17it/s]
Loading safetensors checkpoint shards:  74% Completed | 35/47 [00:00<00:00, 75.50it/s]
Loading safetensors checkpoint shards:  94% Completed | 44/47 [00:00<00:00, 79.47it/s]
Loading safetensors checkpoint shards: 100% Completed | 47/47 [00:00<00:00, 65.09it/s]

[2025-07-29 15:05:10 TP0] Load weight end. type=Glm4MoeForCausalLM, dtype=torch.bfloat16, avail mem=20.75 GB, mem usage=25.46 GB.
[2025-07-29 15:05:10 TP2] KV Cache is allocated. #tokens: 156868, K size: 3.44 GB, V size: 3.44 GB
[2025-07-29 15:05:10 TP0] KV Cache is allocated. #tokens: 156868, K size: 3.44 GB, V size: 3.44 GB
[2025-07-29 15:05:10 TP1] KV Cache is allocated. #tokens: 156868, K size: 3.44 GB, V size: 3.44 GB
[2025-07-29 15:05:10 TP0] Memory pool end. avail mem=13.68 GB
[2025-07-29 15:05:10 TP3] KV Cache is allocated. #tokens: 156868, K size: 3.44 GB, V size: 3.44 GB
[2025-07-29 15:05:10 TP0] Capture cuda graph begin. This can take up to several minutes. avail mem=13.17 GB
[2025-07-29 15:05:11 TP0] Capture cuda graph bs [1, 2, 4, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128]
Capturing batches (bs=128 avail_mem=13.09 GB):   0%|                                                   | 0/19 [00:00<?, ?it/s][2025-07-29 15:05:12 TP3] Using default MoE kernel config. Performance might be sub-optimal! Config file not found at /xrepo/Athena/Agent/SGLang/sglang_env/lib/python3.12/site-packages/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=129,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json, you can create them with https://github.com/sgl-project/sglang/tree/main/benchmark/kernels/fused_moe_triton
[2025-07-29 15:05:12 TP0] Using default MoE kernel config. Performance might be sub-optimal! Config file not found at /xrepo/Athena/Agent/SGLang/sglang_env/lib/python3.12/site-packages/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=129,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json, you can create them with https://github.com/sgl-project/sglang/tree/main/benchmark/kernels/fused_moe_triton
[2025-07-29 15:05:12 TP1] Using default MoE kernel config. Performance might be sub-optimal! Config file not found at /xrepo/Athena/Agent/SGLang/sglang_env/lib/python3.12/site-packages/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=129,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json, you can create them with https://github.com/sgl-project/sglang/tree/main/benchmark/kernels/fused_moe_triton
[2025-07-29 15:05:12 TP2] Using default MoE kernel config. Performance might be sub-optimal! Config file not found at /xrepo/Athena/Agent/SGLang/sglang_env/lib/python3.12/site-packages/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=129,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json, you can create them with https://github.com/sgl-project/sglang/tree/main/benchmark/kernels/fused_moe_triton
Capturing batches (bs=1 avail_mem=11.89 GB): 100%|████████████████████████████████████████████| 19/19 [00:18<00:00,  1.04it/s]
[2025-07-29 15:05:29 TP0] Capture cuda graph end. Time elapsed: 18.29 s. mem usage=1.30 GB. avail mem=11.87 GB.
[2025-07-29 15:05:29 TP0] max_total_num_tokens=156868, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=128, context_len=131072, available_gpu_mem=11.87 GB
[2025-07-29 15:05:30] INFO:     Started server process [867080]
[2025-07-29 15:05:30] INFO:     Waiting for application startup.
[2025-07-29 15:05:30] INFO:     Application startup complete.
[2025-07-29 15:05:30] INFO:     Uvicorn running on http://0.0.0.0:30000 (Press CTRL+C to quit)
[2025-07-29 15:05:31] INFO:     127.0.0.1:54580 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-07-29 15:05:31 TP0] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-07-29 15:05:32] INFO:     127.0.0.1:54596 - "POST /generate HTTP/1.1" 200 OK
[2025-07-29 15:05:32] The server is fired up and ready to roll!
/xrepo/Athena/Agent/SGLang/sglang_env/lib/python3.12/site-packages/pydantic/_internal/_generate_schema.py:2225: UnsupportedFieldAttributeWarning: The 'deprecated' attribute with value 'max_tokens is deprecated in favor of the max_completion_tokens field' was provided to the `Field()` function, which has no effect in the context it was used. 'deprecated' is field-specific metadata, and can only be attached to a model field using `Annotated` metadata or by assignment. This may have happened because an `Annotated` type alias using the `type` statement was used, or if the `Field()` function was attached to a single member of a union type.
  warnings.warn(
[2025-07-29 15:06:06] INFO:     192.168.18.118:53796 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-07-29 15:06:06 TP0] Prefill batch. #new-seq: 1, #new-token: 20, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-07-29 15:06:07 TP0] Decode batch. #running-req: 1, #token: 53, token usage: 0.00, cuda graph: True, gen throughput (token/s): 1.07, #queue-req: 0,
[2025-07-29 15:06:08 TP0] Decode batch. #running-req: 1, #token: 93, token usage: 0.00, cuda graph: True, gen throughput (token/s): 62.02, #queue-req: 0,
[2025-07-29 15:06:08 TP0] Decode batch. #running-req: 1, #token: 133, token usage: 0.00, cuda graph: True, gen throughput (token/s): 67.87, #queue-req: 0,
[2025-07-29 15:06:09 TP0] Decode batch. #running-req: 1, #token: 173, token usage: 0.00, cuda graph: True, gen throughput (token/s): 63.40, #queue-req: 0,
[2025-07-29 15:06:09 TP0] Decode batch. #running-req: 1, #token: 213, token usage: 0.00, cuda graph: True, gen throughput (token/s): 60.14, #queue-req: 0,
[2025-07-29 15:06:10 TP0] Decode batch. #running-req: 1, #token: 253, token usage: 0.00, cuda graph: True, gen throughput (token/s): 73.58, #queue-req: 0,
[2025-07-29 15:06:11 TP0] Decode batch. #running-req: 1, #token: 293, token usage: 0.00, cuda graph: True, gen throughput (token/s): 62.62, #queue-req: 0,
[2025-07-29 15:06:11 TP0] Decode batch. #running-req: 1, #token: 333, token usage: 0.00, cuda graph: True, gen throughput (token/s): 56.69, #queue-req: 0,
[2025-07-29 15:06:12 TP0] Decode batch. #running-req: 1, #token: 373, token usage: 0.00, cuda graph: True, gen throughput (token/s): 65.00, #queue-req: 0,
[2025-07-29 15:06:13 TP0] Decode batch. #running-req: 1, #token: 413, token usage: 0.00, cuda graph: True, gen throughput (token/s): 58.00, #queue-req: 0,
[2025-07-29 15:06:13 TP0] Decode batch. #running-req: 1, #token: 453, token usage: 0.00, cuda graph: True, gen throughput (token/s): 60.50, #queue-req: 0,
[2025-07-29 15:06:14 TP0] Decode batch. #running-req: 1, #token: 493, token usage: 0.00, cuda graph: True, gen throughput (token/s): 63.47, #queue-req: 0,
[2025-07-29 15:06:15 TP0] Decode batch. #running-req: 1, #token: 533, token usage: 0.00, cuda graph: True, gen throughput (token/s): 56.36, #queue-req: 0,
[2025-07-29 15:06:15 TP0] Decode batch. #running-req: 1, #token: 573, token usage: 0.00, cuda graph: True, gen throughput (token/s): 62.97, #queue-req: 0,
[2025-07-29 15:06:16 TP0] Decode batch. #running-req: 1, #token: 613, token usage: 0.00, cuda graph: True, gen throughput (token/s): 62.31, #queue-req: 0,
[2025-07-29 15:06:17 TP0] Decode batch. #running-req: 1, #token: 653, token usage: 0.00, cuda graph: True, gen throughput (token/s): 51.59, #queue-req: 0,
[2025-07-29 15:06:17 TP0] Decode batch. #running-req: 1, #token: 693, token usage: 0.00, cuda graph: True, gen throughput (token/s): 63.50, #queue-req: 0,
[2025-07-29 15:06:18 TP0] Decode batch. #running-req: 1, #token: 733, token usage: 0.00, cuda graph: True, gen throughput (token/s): 60.90, #queue-req: 0,
[2025-07-29 15:06:19 TP0] Decode batch. #running-req: 1, #token: 773, token usage: 0.00, cuda graph: True, gen throughput (token/s): 56.30, #queue-req: 0,
[2025-07-29 15:06:19 TP0] Decode batch. #running-req: 1, #token: 813, token usage: 0.01, cuda graph: True, gen throughput (token/s): 67.61, #queue-req: 0,
[2025-07-29 15:06:20 TP0] Decode batch. #running-req: 1, #token: 853, token usage: 0.01, cuda graph: True, gen throughput (token/s): 59.10, #queue-req: 0,
[2025-07-29 15:06:21 TP0] Decode batch. #running-req: 1, #token: 893, token usage: 0.01, cuda graph: True, gen throughput (token/s): 59.77, #queue-req: 0,
[2025-07-29 15:06:21 TP0] Decode batch. #running-req: 1, #token: 933, token usage: 0.01, cuda graph: True, gen throughput (token/s): 70.48, #queue-req: 0,
[2025-07-29 15:06:22 TP0] Decode batch. #running-req: 1, #token: 973, token usage: 0.01, cuda graph: True, gen throughput (token/s): 56.92, #queue-req: 0,
[2025-07-29 15:06:23 TP0] Decode batch. #running-req: 1, #token: 1013, token usage: 0.01, cuda graph: True, gen throughput (token/s): 65.28, #queue-req: 0,
[2025-07-29 15:06:23 TP0] Decode batch. #running-req: 1, #token: 1053, token usage: 0.01, cuda graph: True, gen throughput (token/s): 66.55, #queue-req: 0,
[2025-07-29 15:06:24 TP0] Decode batch. #running-req: 1, #token: 1093, token usage: 0.01, cuda graph: True, gen throughput (token/s): 57.44, #queue-req: 0,
[2025-07-29 15:06:24 TP0] Decode batch. #running-req: 1, #token: 1133, token usage: 0.01, cuda graph: True, gen throughput (token/s): 58.79, #queue-req: 0,
[2025-07-29 15:06:25 TP0] Decode batch. #running-req: 1, #token: 1173, token usage: 0.01, cuda graph: True, gen throughput (token/s): 62.77, #queue-req: 0,
[2025-07-29 15:06:26 TP0] Decode batch. #running-req: 1, #token: 1213, token usage: 0.01, cuda graph: True, gen throughput (token/s): 56.79, #queue-req: 0,
[2025-07-29 15:06:26 TP0] Decode batch. #running-req: 1, #token: 1253, token usage: 0.01, cuda graph: True, gen throughput (token/s): 64.95, #queue-req: 0,
[2025-07-29 15:06:27 TP0] Decode batch. #running-req: 1, #token: 1293, token usage: 0.01, cuda graph: True, gen throughput (token/s): 67.43, #queue-req: 0,
[2025-07-29 15:06:28 TP0] Decode batch. #running-req: 1, #token: 1333, token usage: 0.01, cuda graph: True, gen throughput (token/s): 74.53, #queue-req: 0,
[2025-07-29 15:06:28 TP0] Decode batch. #running-req: 1, #token: 1373, token usage: 0.01, cuda graph: True, gen throughput (token/s): 63.44, #queue-req: 0,
[2025-07-29 15:06:29 TP0] Decode batch. #running-req: 1, #token: 1413, token usage: 0.01, cuda graph: True, gen throughput (token/s): 64.64, #queue-req: 0,
[2025-07-29 15:06:29 TP0] Decode batch. #running-req: 1, #token: 1453, token usage: 0.01, cuda graph: True, gen throughput (token/s): 69.11, #queue-req: 0,
[2025-07-29 15:06:30 TP0] Decode batch. #running-req: 1, #token: 1493, token usage: 0.01, cuda graph: True, gen throughput (token/s): 58.16, #queue-req: 0,
[2025-07-29 15:06:31 TP0] Decode batch. #running-req: 1, #token: 1533, token usage: 0.01, cuda graph: True, gen throughput (token/s): 67.44, #queue-req: 0,
[2025-07-29 15:06:31 TP0] Decode batch. #running-req: 1, #token: 1573, token usage: 0.01, cuda graph: True, gen throughput (token/s): 69.98, #queue-req: 0,
[2025-07-29 15:06:32 TP0] Decode batch. #running-req: 1, #token: 1613, token usage: 0.01, cuda graph: True, gen throughput (token/s): 66.60, #queue-req: 0,
[2025-07-29 15:06:32 TP0] Decode batch. #running-req: 1, #token: 1653, token usage: 0.01, cuda graph: True, gen throughput (token/s): 63.13, #queue-req: 0,
[2025-07-29 15:06:33 TP0] Decode batch. #running-req: 1, #token: 1693, token usage: 0.01, cuda graph: True, gen throughput (token/s): 60.80, #queue-req: 0,
[2025-07-29 15:06:34 TP0] Decode batch. #running-req: 1, #token: 1733, token usage: 0.01, cuda graph: True, gen throughput (token/s): 61.92, #queue-req: 0,
[2025-07-29 15:06:34 TP0] Decode batch. #running-req: 1, #token: 1773, token usage: 0.01, cuda graph: True, gen throughput (token/s): 66.09, #queue-req: 0,
[2025-07-29 15:06:35 TP0] Decode batch. #running-req: 1, #token: 1813, token usage: 0.01, cuda graph: True, gen throughput (token/s): 68.39, #queue-req: 0,
[2025-07-29 15:06:36 TP0] Decode batch. #running-req: 1, #token: 1853, token usage: 0.01, cuda graph: True, gen throughput (token/s): 64.43, #queue-req: 0,
[2025-07-29 15:06:36 TP0] Decode batch. #running-req: 1, #token: 1893, token usage: 0.01, cuda graph: True, gen throughput (token/s): 70.14, #queue-req: 0,
[2025-07-29 15:06:37 TP0] Decode batch. #running-req: 1, #token: 1933, token usage: 0.01, cuda graph: True, gen throughput (token/s): 63.96, #queue-req: 0,
[2025-07-29 15:06:37 TP0] Decode batch. #running-req: 1, #token: 1973, token usage: 0.01, cuda graph: True, gen throughput (token/s): 61.67, #queue-req: 0,
[2025-07-29 15:06:38 TP0] Decode batch. #running-req: 1, #token: 2013, token usage: 0.01, cuda graph: True, gen throughput (token/s): 64.69, #queue-req: 0,
[2025-07-29 15:06:39 TP0] Decode batch. #running-req: 1, #token: 2053, token usage: 0.01, cuda graph: True, gen throughput (token/s): 59.01, #queue-req: 0,
[2025-07-29 15:06:39 TP0] Decode batch. #running-req: 1, #token: 2093, token usage: 0.01, cuda graph: True, gen throughput (token/s): 65.50, #queue-req: 0,
[2025-07-29 15:06:40 TP0] Decode batch. #running-req: 1, #token: 2133, token usage: 0.01, cuda graph: True, gen throughput (token/s): 71.42, #queue-req: 0,
[2025-07-29 15:06:41 TP0] Decode batch. #running-req: 1, #token: 2173, token usage: 0.01, cuda graph: True, gen throughput (token/s): 58.20, #queue-req: 0,
[2025-07-29 15:06:41 TP0] Decode batch. #running-req: 1, #token: 2213, token usage: 0.01, cuda graph: True, gen throughput (token/s): 62.94, #queue-req: 0,
[2025-07-29 15:06:42 TP0] Decode batch. #running-req: 1, #token: 2253, token usage: 0.01, cuda graph: True, gen throughput (token/s): 63.71, #queue-req: 0,
[2025-07-29 15:06:42 TP0] Decode batch. #running-req: 1, #token: 2293, token usage: 0.01, cuda graph: True, gen throughput (token/s): 63.38, #queue-req: 0,
[2025-07-29 15:06:51] INFO:     192.168.18.118:53823 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-07-29 15:06:51 TP0] Prefill batch. #new-seq: 1, #new-token: 1938, #cached-token: 22, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-07-29 15:06:51 TP0] TpModelWorkerClient hit an exception: Traceback (most recent call last):
  File "/xrepo/Athena/Agent/SGLang/sglang_env/lib/python3.12/site-packages/sglang/srt/managers/tp_worker_overlap_thread.py", line 140, in forward_thread_func
    self.forward_thread_func_()
  File "/xrepo/Athena/Agent/SGLang/sglang_env/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/xrepo/Athena/Agent/SGLang/sglang_env/lib/python3.12/site-packages/sglang/srt/managers/tp_worker_overlap_thread.py", line 175, in forward_thread_func_
    self.worker.forward_batch_generation(
  File "/xrepo/Athena/Agent/SGLang/sglang_env/lib/python3.12/site-packages/sglang/srt/managers/tp_worker.py", line 229, in forward_batch_generation
    logits_output, can_run_cuda_graph = self.model_runner.forward(
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/xrepo/Athena/Agent/SGLang/sglang_env/lib/python3.12/site-packages/sglang/srt/model_executor/model_runner.py", line 1606, in forward
    output = self._forward_raw(
             ^^^^^^^^^^^^^^^^^^
  File "/xrepo/Athena/Agent/SGLang/sglang_env/lib/python3.12/site-packages/sglang/srt/model_executor/model_runner.py", line 1651, in _forward_raw
    ret = self.forward_extend(
          ^^^^^^^^^^^^^^^^^^^^
  File "/xrepo/Athena/Agent/SGLang/sglang_env/lib/python3.12/site-packages/sglang/srt/model_executor/model_runner.py", line 1551, in forward_extend
    return self.model.forward(
           ^^^^^^^^^^^^^^^^^^^
  File "/xrepo/Athena/Agent/SGLang/sglang_env/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/xrepo/Athena/Agent/SGLang/sglang_env/lib/python3.12/site-packages/sglang/srt/models/deepseek_v2.py", line 2177, in forward
    hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/xrepo/Athena/Agent/SGLang/sglang_env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/xrepo/Athena/Agent/SGLang/sglang_env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/xrepo/Athena/Agent/SGLang/sglang_env/lib/python3.12/site-packages/sglang/srt/models/deepseek_v2.py", line 2070, in forward
    hidden_states, residual = layer(
                              ^^^^^^
  File "/xrepo/Athena/Agent/SGLang/sglang_env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/xrepo/Athena/Agent/SGLang/sglang_env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/xrepo/Athena/Agent/SGLang/sglang_env/lib/python3.12/site-packages/sglang/srt/models/glm4_moe.py", line 643, in forward
    hidden_states = self.mlp(hidden_states, forward_batch)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/xrepo/Athena/Agent/SGLang/sglang_env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/xrepo/Athena/Agent/SGLang/sglang_env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/xrepo/Athena/Agent/SGLang/sglang_env/lib/python3.12/site-packages/sglang/srt/models/deepseek_v2.py", line 463, in forward
    return self.forward_normal(hidden_states, can_fuse_mlp_allreduce)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/xrepo/Athena/Agent/SGLang/sglang_env/lib/python3.12/site-packages/sglang/srt/models/deepseek_v2.py", line 508, in forward_normal
    final_hidden_states = self.experts(**kwargs)
                          ^^^^^^^^^^^^^^^^^^^^^^
  File "/xrepo/Athena/Agent/SGLang/sglang_env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/xrepo/Athena/Agent/SGLang/sglang_env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/xrepo/Athena/Agent/SGLang/sglang_env/lib/python3.12/site-packages/sglang/srt/layers/moe/fused_moe_triton/layer.py", line 578, in forward
    final_hidden_states = self.quant_method.apply(
                          ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/xrepo/Athena/Agent/SGLang/sglang_env/lib/python3.12/site-packages/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py", line 281, in apply
    return fused_experts(
           ^^^^^^^^^^^^^^
  File "/xrepo/Athena/Agent/SGLang/sglang_env/lib/python3.12/site-packages/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py", line 1168, in fused_experts
    torch.ops.sglang.inplace_fused_experts(
  File "/xrepo/Athena/Agent/SGLang/sglang_env/lib/python3.12/site-packages/torch/_ops.py", line 1158, in __call__
    return self._op(*args, **(kwargs or {}))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/xrepo/Athena/Agent/SGLang/sglang_env/lib/python3.12/site-packages/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py", line 1000, in inplace_fused_experts
    fused_experts_impl(
  File "/xrepo/Athena/Agent/SGLang/sglang_env/lib/python3.12/site-packages/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py", line 1436, in fused_experts_impl
    invoke_fused_moe_kernel(
  File "/xrepo/Athena/Agent/SGLang/sglang_env/lib/python3.12/site-packages/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py", line 741, in invoke_fused_moe_kernel
    fused_moe_kernel[grid](
  File "/xrepo/Athena/Agent/SGLang/sglang_env/lib/python3.12/site-packages/triton/runtime/jit.py", line 347, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/xrepo/Athena/Agent/SGLang/sglang_env/lib/python3.12/site-packages/triton/runtime/jit.py", line 591, in run
    kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata,
    ^^^^^^^^^^
  File "/xrepo/Athena/Agent/SGLang/sglang_env/lib/python3.12/site-packages/triton/compiler/compiler.py", line 413, in __getattribute__
    self._init_handles()
  File "/xrepo/Athena/Agent/SGLang/sglang_env/lib/python3.12/site-packages/triton/compiler/compiler.py", line 401, in _init_handles
    raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 147456, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.

[2025-07-29 15:06:51] Received sigquit from a child process. It usually means the child failed.
[2025-07-29 15:06:51] Dumping requests before crash. self.crash_dump_folder=None
./sglang_run.sh: line 19: 867080 Killed                  python -m sglang.launch_server --model-path $MODEL_PATH --tp-size $GPU_NUM --host $SERVER_IPV4 --port $SERVER_PORT --max-running-requests $MAX_REQUESTS --mem-fraction-static $MEM_FRACTION_STATIC --reasoning-parser $REASONING_PARSER
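
From the traceback, the crash comes from the Triton fused MoE kernel during the ~2,000-token prefill. The earlier warnings show that no tuned kernel config was found for E=129,N=352 on the NVIDIA RTX 6000 Ada, so a default config is used, and its block sizes request 147,456 bytes of shared memory, above this GPU's 101,376-byte per-block limit (compute capability 8.9). A possible mitigation, as the warning itself suggests, is to generate a device-specific config with smaller block sizes or fewer pipeline stages. A hedged sketch, assuming the tuning script and flags found in benchmark/kernels/fused_moe_triton (check that directory's README for the exact invocation):

# Sketch only: generate a fused-MoE Triton config tuned for this GPU, so the resulting
# E=129,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json
# stays within the 101376-byte shared-memory limit reported in the error above.
# The script name and flags are assumptions based on the benchmark directory
# referenced in the warning; they may differ between sglang versions.
source sglang_env/bin/activate
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
    --model /xrepo/Athena/Model/GLM/GLM-4.5-Air-FP8 \
    --tp-size 4 \
    --dtype fp8_w8a8 \
    --tune

The generated JSON would then go into the configs directory shown in the warning (sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/) so the server loads it instead of the default config.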

Reproduction

#!/bin/bash


GPU_NUM=4
SERVER_IPV4="0.0.0.0"
SERVER_PORT=30000
MODEL_PATH="/xrepo/Athena/Model/GLM/GLM-4.5-Air-FP8"
MAX_REQUESTS=128
REASONING_PARSER="glm45"	# deepseek-r1 qwen3 kimi glm45
MEM_FRACTION_STATIC="0.7"


# Activate the Python virtual environment
source sglang_env/bin/activate


# Start the API server
python -m sglang.launch_server --model-path $MODEL_PATH --tp-size $GPU_NUM --host $SERVER_IPV4 --port $SERVER_PORT --max-running-requests $MAX_REQUESTS --mem-fraction-static $MEM_FRACTION_STATIC --reasoning-parser $REASONING_PARSER
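
Note that the launch script alone runs fine; the crash is triggered by the second chat-completion request, whose prefill is about 1,938 tokens (see the log above). A hedged sketch of such a request, run from the server host (the actual prompt is not recorded in the log, so a repeated placeholder message stands in for it):

# Sketch only: send a long-context chat completion to the running server.
# LONG_PROMPT is a placeholder for the real ~2k-token prompt, which is not in the log.
LONG_PROMPT=$(printf 'Summarize the following text about the system under test. %.0s' {1..250})
curl -s http://localhost:30000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d "{
    \"model\": \"/xrepo/Athena/Model/GLM/GLM-4.5-Air-FP8\",
    \"messages\": [{\"role\": \"user\", \"content\": \"${LONG_PROMPT}\"}],
    \"max_tokens\": 256
  }"

Short prompts (like the first request in the log) complete without issue; only the longer prefill path hits the shared-memory limit.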

Environment

xaccel@xaccel:/xrepo/Athena/Agent/SGLang$ source sglang_env/bin/activate
(sglang_env) xaccel@xaccel:/xrepo/Athena/Agent/SGLang$ python -m sglang.check_env
Python: 3.12.3 (main, Jun 18 2025, 17:59:45) [GCC 13.3.0]
CUDA available: True
GPU 0,1,2,3: NVIDIA RTX 6000 Ada Generation
GPU 0,1,2,3 Compute Capability: 8.9
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 12.8, V12.8.93
CUDA Driver Version: 570.172.08
PyTorch: 2.7.1+cu126
sglang: 0.4.9.post5
sgl_kernel: 0.2.7
flashinfer_python: 0.2.9rc2
triton: 3.3.1
transformers: 4.54.0
torchao: 0.9.0
numpy: 2.3.2
aiohttp: 3.12.15
fastapi: 0.116.1
hf_transfer: 0.1.9
huggingface_hub: 0.35.0rc0
interegular: 0.3.3
modelscope: 1.28.1
orjson: 3.11.1
outlines: 0.1.11
packaging: 25.0
psutil: 7.0.0
pydantic: 2.12.0a1
python-multipart: 0.0.20
pyzmq: 27.0.0
uvicorn: 0.35.0
uvloop: 0.21.0
vllm: Module Not Found
xgrammar: 0.1.21
openai: 1.97.1
tiktoken: 0.9.0
anthropic: 0.60.0
litellm: 1.74.9.post1
decord: 0.6.0
NVIDIA Topology:
        GPU0    GPU1    GPU2    GPU3    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      NODE    SYS     SYS     20-29   2               N/A
GPU1    NODE     X      SYS     SYS     20-29   2               N/A
GPU2    SYS     SYS      X      NODE    60-69   6               N/A
GPU3    SYS     SYS     NODE     X      60-69   6               N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

ulimit soft: 1024
