
Integration with elasticmem #13581

Open
pansicheng wants to merge 5 commits into sgl-project:main from pansicheng:emem

Conversation

@pansicheng
Collaborator

Motivation

This PR implements dynamic scaling between different attention-type pools within the hybrid model in sglang, based on elasticmem.

[Flowchart image: a-Flowchart (10)]

Modifications

Accuracy Tests

Benchmarking and Profiling

export SGLANG_ELASTIC_MEM_POOL=true
export SGLANG_RATIO=1.0
nohup python3 -m sglang.launch_server \
  --log-level debug \
  --model /home/t4/models/lvm-data/Llama-4-Scout-17B-16E-Instruct \
  --tp 2 \
  --attention-backend fa3 \
  --hybrid-kvcache-ratio ${SGLANG_RATIO} \
  --context-length 200000 \
  > nohup.emem.${SGLANG_ELASTIC_MEM_POOL}.ratio.${SGLANG_RATIO}.out 2>&1 \
  &

export SGLANG_ELASTIC_MEM_POOL=true
export SGLANG_RATIO=1.0
nohup python3 -m sglang.bench_serving --backend sglang \
  --dataset-name random --dataset-path /home/t4/models/lvm-data/ShareGPT_V3_unfiltered_cleaned_split.json \
  --num-prompts 1024 --random-input 1024 --random-output 1024 --random-range-ratio 1 \
  --max-concurrency 128 \
  > nohup.bench.${SGLANG_ELASTIC_MEM_POOL}.ratio.${SGLANG_RATIO}.out 2>&1 \
  &
[Figure: token usage, running requests, and generation throughput per Prefill/Decode time step, for three pool configurations]
  • Horizontal axis: Each time step represents a log entry for Prefill/Decode.
  • Vertical axis: as labeled in the figure, one of:
    • Token usage for different pools,
    • Real-time running requests,
    • Generation throughput during decoding (set to 0 during prefill).
  • First row: Current static pool configuration with --hybrid-kvcache-ratio=0.6 (relatively balanced allocation between the full and swa pools).
  • Second row: Current static pool configuration with --hybrid-kvcache-ratio=1.0 (most GPU memory allocated to the full pool).
  • Third row: Elastic pool configuration with --hybrid-kvcache-ratio=1.0 (initial allocation favors the full pool but supports dynamic runtime adjustments).

Checklist

Comment thread python/sglang/srt/managers/scheduler.py Outdated
Comment thread scripts/emem/plot/main.py
# TODO: a more efficient way
@override
def alloc(self, need_size: int):
    self.merge_and_sort_free()
Collaborator
make it more efficient

Collaborator Author
now we sort only during defragmentation
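The approach described above — allocating without sorting, and merging/sorting the free list only during an explicit defragmentation pass — can be sketched as follows. This is a minimal illustration of the idea, not the PR's actual allocator; the class and method names here are hypothetical:

```python
# Sketch: a page free-list that defers merge/sort to defragmentation,
# so the alloc hot path never pays the sorting cost. Illustrative only.

class FreePageList:
    def __init__(self, num_pages: int):
        # (start, length) runs of free pages; the pool starts fully free.
        self.free_runs = [(0, num_pages)]

    def alloc(self, need_size: int):
        # Fast path: first-fit over existing runs, no sorting per call.
        for i, (start, length) in enumerate(self.free_runs):
            if length >= need_size:
                if length == need_size:
                    self.free_runs.pop(i)
                else:
                    self.free_runs[i] = (start + need_size, length - need_size)
                return start
        return None  # caller may trigger defragment() and retry

    def free(self, start: int, length: int):
        # O(1): just append; adjacency is resolved lazily.
        self.free_runs.append((start, length))

    def defragment(self):
        # Sort and merge adjacent runs only here, amortizing the cost
        # across many alloc/free calls.
        self.free_runs.sort()
        merged = []
        for start, length in self.free_runs:
            if merged and merged[-1][0] + merged[-1][1] == start:
                merged[-1][1] += length
            else:
                merged.append([start, length])
        self.free_runs = [(s, l) for s, l in merged]
```

For example, after `alloc(4)` followed by `free(0, 4)`, a request for all 10 pages fails until `defragment()` coalesces the two runs back into one.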

if self.token_usage() > 0.9:
    return False

self.evict(self.evictable_size())
Collaborator

does can_unmap need to evict and merge_and_sort, since both seem to be time-consuming?

Collaborator Author

Now can_unmap skips eviction and merge_and_sort, using an unused_pages tensor to track consecutive tail pages
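The cheap `can_unmap` check can be illustrated as follows, under the assumed semantics that the pool only unmaps physical pages from its tail, so it suffices to maintain a count of consecutive unused tail pages. This is a sketch with hypothetical names (`TailTracker`), not the PR's actual tensor-based bookkeeping:

```python
# Sketch: O(1) can_unmap by tracking consecutive unused pages at the
# pool tail, instead of evicting and re-sorting the free list on every
# check. Illustrative only; the PR tracks this with an unused_pages tensor.

class TailTracker:
    def __init__(self, num_pages: int):
        self.used = [False] * num_pages   # per-page usage flags
        self.free_tail = num_pages        # unused pages at the tail

    def mark_used(self, page: int):
        self.used[page] = True
        if page >= len(self.used) - self.free_tail:
            # The used page falls inside the tracked tail run: recount.
            self._recount()

    def mark_free(self, page: int):
        self.used[page] = False
        if page == len(self.used) - self.free_tail - 1:
            # Freed page is adjacent to the tail run: extend it.
            self._recount()

    def _recount(self):
        # Walk backwards from the end until the first used page.
        n = 0
        for u in reversed(self.used):
            if u:
                break
            n += 1
        self.free_tail = n

    def can_unmap(self, pages: int) -> bool:
        # Constant-time check against the maintained counter.
        return pages <= self.free_tail
```

The recount is only triggered when the tail run itself changes; a plain `can_unmap` call touches no page metadata at all.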

Use oversubscribe instead of expand

Implement elastic memory pool for KV cache

Implement elastic memory pool allocator

ElasticMempoolOrchestrator

Fix resizing timing of elastic mempool during prefill batch creation

Fix can_unmap

Simplify reduction

Enhance elastic memory management with free_all, improved token tracking, and optimized orchestration

Add CUDA synchronization in orchestrator resize operations

Clean code
Comment thread python/sglang/srt/mem_cache/elastic/elasticmem_orchestrator.py
Comment thread python/sglang/srt/mem_cache/elastic/elasticmem_orchestrator.py Outdated
Comment thread python/sglang/srt/mem_cache/elastic/elasticmem_orchestrator.py Outdated
@hanming-lu
Collaborator

hanming-lu commented Dec 13, 2025

Nice. I see how we try to improve max running batch size with this. In parallel, do we also target improving the prefix cache hit rate, by analyzing which of swa or full is causing cache misses?

@pansicheng
Collaborator Author

Nice. I see how we try to improve max running batch size with this. In parallel, do we also target improving the prefix cache hit rate, by analyzing which of swa or full is causing cache misses?

@hanming-lu No problem. The current PR focuses on balancing pool usage to maximize batch size when some pools are near capacity. Next, we’ll monitor cache hit rates per pool and optimize scaling strategies to boost hit rates under balanced loads. Metrics and adaptive scaling will need further design — let’s tackle this next!

YAMY1234 added a commit to YAMY1234/sglang that referenced this pull request Mar 30, 2026
After merging upstream main into the PR sgl-project#13581 branch, several
compatibility issues arose due to SWA code being refactored from
memory_pool.py to swa_memory_pool.py:

- Add page_size parameter to SWATokenToKVPoolAllocator in allocator.py
- Fix elastic_allocator.py to import SWATokenToKVPoolAllocator from
  swa_memory_pool instead of allocator (fixes isinstance check in
  SWARadixCache)
- Rewrite ElasticSWATokenToKVPoolAllocator to replace parent allocators
  post-init instead of overriding _create_allocator (which parent no
  longer calls)
- Rewrite ElasticSWAKVPool to pass ElasticMHATokenToKVPool as pool
  class and recreate pools with pool_name parameter
- Fix isinstance check in model_runner_kv_cache_mixin.py (use isinstance
  instead of __class__ ==)
- Add missing get_float_env_var utility function to utils/common.py

Made-with: Cursor
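The `isinstance` fix in `model_runner_kv_cache_mixin.py` mentioned above matters because an exact `__class__ ==` comparison rejects subclasses, which breaks once elastic pools subclass the SWA allocator. A minimal illustration (the classes below are placeholders, not the real sglang definitions):

```python
# Why isinstance is preferred over `__class__ ==` for type checks:
# an exact-class comparison fails for subclasses, such as an elastic
# allocator deriving from the base SWA allocator. Placeholder classes.

class SWATokenToKVPoolAllocator:
    pass

class ElasticSWATokenToKVPoolAllocator(SWATokenToKVPoolAllocator):
    pass

alloc = ElasticSWATokenToKVPoolAllocator()

# Exact-class check: False for the elastic subclass.
print(alloc.__class__ == SWATokenToKVPoolAllocator)   # False

# isinstance: True for the base class and any subclass of it.
print(isinstance(alloc, SWATokenToKVPoolAllocator))   # True
```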
@ZelinMa557

hi, will this feature support GDN models and mamba models?

@pansicheng
Collaborator Author

hi, will this feature support GDN models and mamba models?

There’s a PR for Qwen3-Next support here: #14597. I’ll try to move it forward as soon as possible.

@ZelinMa557

hi, will this feature support GDN models and mamba models?

There’s a PR for Qwen3-Next support here: #14597. I’ll try to move it forward as soon as possible.

Thanks for your reply. I'm very interested in supporting a dynamic memory pool for mamba/GDN models — is there anything I can help with?
