NixlKVManager: async multi-threaded KV transfer #20680

Open
usernamehaha2022 wants to merge 7 commits into sgl-project:main from usernamehaha2022:nixl-async-transfer

Conversation

@usernamehaha2022

Motivation

This PR improves the performance of NixlKVManager by making KV transfer asynchronous and multi-threaded on the prefill node. Previously, add_transfer_request performed each chunk transfer synchronously, and the caller (NixlKVSender) had to track and poll all transfer handles. With many decode instances and chunked transfers, this blocked the prefill scheduler on transfer completion and limited throughput. This change brings NIXL in line with the queue-based, multi-worker transfer design used elsewhere in SGLang's disaggregation stack.

Performance

We ran Qwen3-32B PD disaggregation with NIXL and observed a clear improvement in transfer latency via NIXL telemetry:

  • Mean transfer time: 162,225 μs → 41,225 μs (about 4× lower).
  • Distribution: Before, transfer times had high variance with many samples in the 250k–1.2M μs range and a long tail; after the change, the vast majority of samples sit in the 34k–42k μs band with much lower variance and no large outliers.

Async multi-worker transfer removes the synchronous bottleneck on the prefill path: chunks are processed in parallel by worker threads, and decode instances are sharded across queues for better overlap, which explains the lower mean and significantly improved tail (P95/P99) latency.

Modifications

  1. Async transfer with queue + worker pool (PREFILL mode); a sketch follows this list.

    • Introduced multiple FastQueue instances (count controlled by SGLANG_DISAGGREGATION_QUEUE_SIZE) and a ThreadPoolExecutor per queue (total worker count from SGLANG_DISAGGREGATION_THREAD_POOL_SIZE).
    • Added a TransferKVChunk dataclass and daemon transfer_worker threads that consume chunks from the queues and execute send_kvcache / send_kvcache_slice, maybe_send_extra, and send_aux in the worker.
    • Default thread pool size when the env var is not set: min(max(4, (0.5 * cpu_count) // 8), 12), e.g. 8 workers on a 128-core machine; the queue count defaults to 4.
  2. Non-blocking add_transfer_request

    • add_transfer_request no longer performs the transfer inline; it enqueues a TransferKVChunk to transfer_queues[bootstrap_room % len(transfer_queues)] and returns None.
    • Workers update request_status (e.g. Transferring, Success, Failed), so the sender no longer needs to hold or poll transfer handles.
  3. NixlKVSender simplifications

    • Removed xfer_handles; poll() now relies on kv_mgr.check_status(bootstrap_room) only.
    • Added clear() to remove bootstrap_room from request_status when appropriate.
    • Last-chunk path no longer deletes request_status in the sender; the worker clears transfer_infos and sets status to Success when the last chunk is done.
  4. Scheduler handling of Bootstrapping

    • In prefill.py, requests in KVPoll.Bootstrapping are now treated as undone (together with WaitingForInput and Transferring) so the scheduler does not consider them complete before transfer progress.
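For orientation, here is a minimal sketch of the pattern in items 1 and 2. It is not the actual diff: the class name is hypothetical, queue.SimpleQueue stands in for FastQueue, and the transfer bodies are elided.

import queue
import threading
from dataclasses import dataclass


@dataclass
class TransferKVChunk:
    room: int
    prefill_kv_indices: list  # prefill-side KV page indices for this chunk
    index_slice: slice        # this chunk's position within the full request
    is_last: bool


class NixlKVManagerSketch:
    def __init__(self, num_queues: int = 4):
        # One queue per shard; all chunks of a given room hash to the same
        # queue, so one worker processes a request's chunks in order.
        self.transfer_queues = [queue.SimpleQueue() for _ in range(num_queues)]
        for q in self.transfer_queues:
            threading.Thread(
                target=self.transfer_worker, args=(q,), daemon=True
            ).start()

    def add_transfer_request(self, chunk: TransferKVChunk) -> None:
        # Non-blocking: enqueue and return immediately; workers do the I/O.
        self.transfer_queues[chunk.room % len(self.transfer_queues)].put(chunk)

    def transfer_worker(self, q: queue.SimpleQueue) -> None:
        while True:
            chunk = q.get()
            # ... send_kvcache / maybe_send_extra / send_aux run here, then
            # request_status[chunk.room] is updated (Transferring/Success/Failed).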

Testing

  • Ran test/registered/disaggregation/test_disaggregation_basic.py locally: 7 tests OK (229s). Example GSM8K metrics: accuracy 0.725, invalid 0.005, output_throughput ~2908 token/s.
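For reference, the suite can typically be invoked directly (assuming a local PD-capable environment; the exact runner may differ):

python3 test/registered/disaggregation/test_disaggregation_basic.py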

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so. (/tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci)
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a fundamental architectural change to the NixlKVManager by transforming its KV transfer mechanism from synchronous to asynchronous and multi-threaded. The primary goal is to eliminate performance bottlenecks on the prefill path, which previously occurred due to synchronous chunk transfers blocking the scheduler. By offloading transfer operations to a dedicated queue and worker pool, the system achieves substantial reductions in transfer latency and variance, leading to improved overall throughput and a more responsive prefill process.

Highlights

  • Asynchronous Multi-threaded KV Transfer: Implemented an asynchronous and multi-threaded KV transfer mechanism on the prefill node, utilizing multiple FastQueue instances and ThreadPoolExecutors to process KV chunks in parallel. This significantly reduces the blocking nature of KV transfers.
  • Non-blocking add_transfer_request: The add_transfer_request method now enqueues TransferKVChunk objects into a sharded queue for processing by worker threads, instead of performing synchronous transfers. This allows the caller to proceed without waiting for transfer completion.
  • Simplified NixlKVSender: The NixlKVSender no longer tracks individual transfer handles (xfer_handles) and its poll() method now relies on the NixlKVManager's check_status for transfer progress. A clear() method was added for proper status cleanup.
  • Scheduler Awareness of Bootstrapping: The prefill scheduler now explicitly treats requests in the KVPoll.Bootstrapping state as undone, ensuring that the scheduler waits for transfer progress before considering them complete.
  • Performance Improvement: Observed a significant performance improvement in KV transfer latency, with mean transfer time reduced by approximately 4x (from 162,225 μs to 41,225 μs) and much lower variance, especially in Qwen3-32B PD disaggregation.


Changelog
  • python/sglang/srt/disaggregation/nixl/conn.py
    • Added concurrent.futures and os imports for multi-threading and CPU count detection.
    • Imported FastQueue from sglang.srt.disaggregation.common.utils.
    • Defined a new TransferKVChunk dataclass to encapsulate KV transfer details.
    • Initialized transfer_queues and ThreadPoolExecutors in NixlKVManager.__init__ based on environment variables, and started daemon transfer_worker threads.
    • Implemented check_status method in NixlKVManager to retrieve request status.
    • Added transfer_worker method to NixlKVManager to asynchronously process KV chunks from queues, performing send_kvcache, maybe_send_extra, and send_aux operations.
    • Modified NixlKVManager.add_transfer_request to enqueue TransferKVChunk objects and return None, making it non-blocking.
    • Removed self.xfer_handles from NixlKVSender.__init__.
    • Updated NixlKVSender.send to reflect the non-blocking nature of add_transfer_request and removed direct status deletion.
    • Refactored NixlKVSender.poll to query kv_mgr.check_status directly instead of polling local transfer handles.
    • Added clear() method to NixlKVSender for explicit cleanup of request status.
  • python/sglang/srt/disaggregation/prefill.py
    • Updated process_disagg_prefill_inflight_queue to include KVPoll.Bootstrapping in the conditions for undone_reqs, ensuring the scheduler properly handles bootstrapping requests.
Activity
  • The author ran test/registered/disaggregation/test_disaggregation_basic.py locally, with all 7 tests passing in 229 seconds.
  • The author confirmed that the code is formatted according to pre-commit hooks and follows SGLang code style guidance.
  • The author added unit tests by running test_disaggregation_basic.py for regression testing.
  • The author has not yet updated documentation or provided accuracy/speed benchmark results as per the checklist.

@gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces a significant performance improvement to NixlKVManager by refactoring the KV transfer to be asynchronous and multi-threaded. The change from a synchronous, polling-based approach in the sender to a queue-based worker pool design on the prefill node is well-executed and aligns with modern high-performance patterns. The performance gains are impressive, with a ~4x reduction in mean transfer latency.

The code changes are clear and well-structured. The introduction of TransferKVChunk and the transfer_worker effectively decouples the transfer logic. The simplifications in NixlKVSender are a great consequence of this new design.

I have a couple of suggestions for improvement. One is a critical fix for a potential race condition due to unsynchronized access to a shared dictionary from multiple threads. The other is a minor cleanup of an unused parameter.

Overall, this is an excellent contribution that significantly boosts performance.

Comment on lines +1046 to +1050
def clear(self):
    try:
        self.kv_mgr.request_status.pop(self.bootstrap_room, None)
    except Exception:
        pass
Contributor

high

The self.kv_mgr.request_status dictionary is accessed from multiple threads (the main scheduler thread via NixlKVSender and multiple transfer_worker threads) without any synchronization, which can lead to race conditions. The try...except Exception block here is too broad and likely hides such concurrency issues, as pop(key, None) does not raise an exception for a missing key.

To fix this, a threading.Lock should be introduced in NixlKVManager to protect all accesses to self.request_status.

You would need to:

  1. Add self.request_status_lock = threading.Lock() in NixlKVManager.__init__.
  2. Wrap accesses to self.request_status in check_status and update_status with this lock. You will need to override update_status in NixlKVManager.
  3. Update this clear method to use the lock and remove the try-except block.
    def clear(self):
        with self.kv_mgr.request_status_lock:
            self.kv_mgr.request_status.pop(self.bootstrap_room, None)
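For item 2, a hedged sketch of the lock-guarded accessors as methods on NixlKVManager; the merge rule in update_status is illustrative, not the PR's exact code:

def check_status(self, bootstrap_room: int):
    with self.request_status_lock:
        return self.request_status.get(bootstrap_room, KVPoll.Bootstrapping)

def update_status(self, bootstrap_room: int, status: "KVPoll"):
    with self.request_status_lock:
        # Keep Failed terminal: a late Transferring update must not mask it.
        if self.request_status.get(bootstrap_room) != KVPoll.Failed:
            self.request_status[bootstrap_room] = status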

def check_status(self, bootstrap_room: int):
    return self.request_status.get(bootstrap_room, KVPoll.Bootstrapping)

def transfer_worker(self, queue: FastQueue, executor: concurrent.futures.Executor):
Contributor

medium

The executor parameter is passed to transfer_worker but is not used within the function. Similarly, self.executors is created in __init__ but appears to be unused. This seems to be a remnant of a previous design.

To simplify the code, I suggest removing self.executors and the executor parameter.

The __init__ method can be simplified to:

# In NixlKVManager.__init__
...
            self.transfer_queues: List[FastQueue] = [
                FastQueue() for _ in range(transfer_queue_size)
            ]
            # The self.executors list is not used.
            for queue in self.transfer_queues:
                threading.Thread(
                    target=self.transfer_worker, args=(queue,), daemon=True
                ).start()
            self._start_bootstrap_thread()
...

And the signature of transfer_worker should be updated accordingly.

Suggested change
def transfer_worker(self, queue: FastQueue, executor: concurrent.futures.Executor):
def transfer_worker(self, queue: FastQueue):

@iyastreb

Could you please share the benchmark that you used, to reproduce the results?

@usernamehaha2022
Author

usernamehaha2022 commented Mar 17, 2026

Could you please share the benchmark that you used, to reproduce the results?

Hello, thanks for the review. We ran PD disaggregation tests in both a two-machine and a one-machine setup.
Two machines:
SGLANG_DEEPEP_BF16_DISPATCH=1 python3 -m sglang.launch_server --disaggregation-ib-device mlx5_1,mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8 --model-path /models/DeepSeek-R1-W4AFP8/ --served-model-name DeepSeek-R1-W4AFP8 --enable-nccl-nvls --tp-size 8 --dp-size 1 --pp-size 1 --ep-size 8 --dist-init-addr $IP --port 8081 --nnodes 1 --node-rank 0 --trust-remote-code --disaggregation-mode prefill --host 0.0.0.0 --mem-fraction-static 0.75 --chunked-prefill-size -1 --moe-dense-tp-size 1 --max-running-requests 16 --context-length 32768 --watchdog-timeout 3600 --page-size 64 --deepep-mode normal --moe-a2a-backend deepep --log-level info --disaggregation-transfer-backend nixl 2>&1 | tee prefill.log
SGLANG_DEEPEP_BF16_DISPATCH=1 python3 -m sglang.launch_server --disaggregation-ib-device mlx5_1,mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8 --model-path /models/DeepSeek-R1-W4AFP8/ --served-model-name DeepSeek-R1-W4AFP8 --enable-nccl-nvls --tp-size 8 --dp-size 1 --ep-size 8 --dist-init-addr $IP --nnodes 1 --node-rank 0 --disaggregation-mode decode --watchdog-timeout 3600 --host 0.0.0.0 --port 8081 --trust-remote-code --mem-fraction-static 0.75 --moe-dense-tp-size 1 --max-running-requests 16 --context-length 32768 --log-level info --cuda-graph-max-bs 16 --page-size 64 --moe-a2a-backend deepep --deepep-mode low_latency --disaggregation-transfer-backend nixl 2>&1 | tee decode.log
python3 -m sglang.bench_serving --model /models/DeepSeek-R1-W4AFP8 --served-model-name DeepSeek-R1-W4AFP8 --backend sglang-oai-chat --host $IP --port 8000 --request-rate 2 --warmup-requests 10 --num-prompts 50 --random-input-len 10240 --random-output-len 50 --random-range-ratio 1 --seed 42 --dataset-name random --dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json

One machine:
CUDA_VISIBLE_DEVICES=0,1,2,3 UCX_TLS=^cuda_ipc python3 -m sglang.launch_server \
  --disaggregation-ib-device mlx5_bond_0,mlx5_bond_1,mlx5_bond_2,mlx5_bond_3 \
  --model-path /models/DeepSeek-R1-Distill-Qwen-32B \
  --served-model-name DeepSeek-R1-Distill-Qwen-32B \
  --enable-nccl-nvls \
  --tp-size 4 \
  --dist-init-addr $IP:29500 \
  --port 8081 \
  --nnodes 1 \
  --node-rank 0 \
  --trust-remote-code \
  --host 0.0.0.0 \
  --disaggregation-mode prefill \
  --chunked-prefill-size -1 \
  --mem-fraction-static 0.75 \
  --max-running-requests 16 \
  --context-length 32768 \
  --watchdog-timeout 3600 \
  --page-size 64 \
  --disable-radix-cache \
  --disaggregation-transfer-backend nixl \
  --log-level info 2>&1 | tee log/prefill.log

CUDA_VISIBLE_DEVICES=4,5,6,7 UCX_TLS=^cuda_ipc python3 -m sglang.launch_server \
  --disaggregation-ib-device mlx5_bond_4,mlx5_bond_5,mlx5_bond_6,mlx5_bond_7 \
  --model-path /models/DeepSeek-R1-Distill-Qwen-32B \
  --served-model-name DeepSeek-R1-Distill-Qwen-32B \
  --enable-nccl-nvls \
  --tp-size 4 \
  --dist-init-addr $IP:29501 \
  --port 8082 \
  --nnodes 1 \
  --node-rank 0 \
  --trust-remote-code \
  --host 0.0.0.0 \
  --disaggregation-mode decode \
  --mem-fraction-static 0.75 \
  --max-running-requests 16 \
  --context-length 32768 \
  --watchdog-timeout 3600 \
  --page-size 64 \
  --cuda-graph-max-bs 16 \
  --disable-radix-cache \
  --disaggregation-transfer-backend nixl \
  --log-level info 2>&1 | tee log/decode.log

The benchmark command is the same as in the two-machine setup.

The launch command determines the size and number of NIXL transfers. With the configuration above, NIXL transmits several hundred MB of data after each prefill forward pass, which improves bandwidth utilization and stability; the goal is for NIXL to transfer at full link bandwidth.

We use NIXL telemetry data to measure the performance difference. In our two-machine experiments, this change reduced NIXL's average transfer time from 25,000 µs to 19,000 µs.


@iyastreb left a comment


LGTM besides minor comments

kv_chunk: TransferKVChunk = queue.get()
room = kv_chunk.room
try:
    if (

Why is this check needed?
Under what circumstances might it be true?

Author

This check is needed to handle the case where a previous chunk for the same room has already failed in this worker thread.

All chunks for the same room are routed to the same queue (bootstrap_room % len(self.transfer_queues)), so they are processed sequentially by the same worker. If chunk N fails (e.g., send_kvcache raises), the except block marks the room as KVPoll.Failed. When chunk N+1 is dequeued, without this check it would:

  1. Call self.update_status(room, KVPoll.Transferring), overwriting the Failed status back to Transferring — breaking status monotonicity and potentially confusing the scheduler.
  2. Attempt the NIXL transfer again, which would likely fail for the same reason, wasting RDMA resources.

The continue skips all remaining chunks for the failed room. Retries (if any) are handled by the decode-side scheduler with a new room number, so they won't be affected by this check.
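As a hedged sketch (simplified; the transfer body is elided), the guard sits at the top of the worker loop:

def transfer_worker(self, queue: FastQueue):
    while True:
        kv_chunk: TransferKVChunk = queue.get()
        room = kv_chunk.room
        try:
            if self.check_status(room) == KVPoll.Failed:
                continue  # an earlier chunk for this room already failed
            self.update_status(room, KVPoll.Transferring)
            # ... perform the NIXL transfer for this chunk ...
            if kv_chunk.is_last:
                self.update_status(room, KVPoll.Success)
        except Exception:
            self.update_status(room, KVPoll.Failed)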

    ):
        continue

    if room not in self.transfer_infos:

Why is this needed? Maybe just an assert is enough?

Author

Thanks. Fixed in 89467c9.

        raise RuntimeError(f"NIXL transfer encountered ERR room={room}")
    if all(s == "DONE" for s in states):
        break
    time.sleep(0.001)

@iyastreb commented Apr 1, 2026


I have tested this polling with my benchmark, and I see that using time.sleep(0) gives an even bigger boost:

# p2d4, best Mean TTFT
num_prompts   main   PR #20680   PR #20680 + sleep(0)
128            523         562                    372
256            597         485                    490
512           1167         954                    855
1024          2724        2350                   2245

So please consider using time.sleep(0) or maybe even os.sched_yield().
How would that impact your benchmark results?

Author

Thanks for the benchmark data! We've updated to time.sleep(0) in the latest push.
We've also been experimenting with moving this polling loop into NIXL's C++ layer, which shows even further improvement by avoiding the Python overhead entirely. We may follow up on that in a subsequent PR after more testing.
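For reference, a hedged sketch of the polling loop with time.sleep(0) applied; the check_xfer_state call and handle plumbing follow the quoted snippet above and are illustrative:

while True:
    states = [self.agent.check_xfer_state(h) for h in handles]
    if any(s == "ERR" for s in states):
        raise RuntimeError(f"NIXL transfer encountered ERR room={room}")
    if all(s == "DONE" for s in states):
        break
    # sleep(0) releases the GIL and yields to other threads without the
    # roughly 1 ms floor of sleep(0.001), which is what the TTFT table reflects.
    time.sleep(0)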

@ShangmingCai
Collaborator

CC: @zackyoray


def poll(self) -> KVPoll:
    status = self.kv_mgr.check_status(self.bootstrap_room)
    if not self.has_sent:
Contributor

It looks like we return status here in all cases, so we can just return it directly and remove the ifs.
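A hedged sketch of that simplification:

def poll(self) -> KVPoll:
    return self.kv_mgr.check_status(self.bootstrap_room)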

        continue

    chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice]
    if len(chunked_dst_kv_indice) < len(kv_chunk.prefill_kv_indices):
Contributor

The old code has:

assert len(chunked_dst_kv_indice) == len(kv_indices)

Why was this changed?


if kv_chunk.is_last:
    if room in self.transfer_infos:
        del self.transfer_infos[room]
Contributor

I see multiple threads modifying and reading self.transfer_infos. Do we need locking? Or pass the data to the workers in a safer way?

@ovidiusm mentioned this pull request Apr 28, 2026
@ishandhanani
Collaborator

/tag-and-rerun-ci

@ovidiusm
Contributor

Opened #23967 and rebased onto main
