[ROCm][PD] Add MoRIIO KV connector #29304
Conversation
Documentation preview: https://vllm--29304.org.readthedocs.build/en/29304/
Signed-off-by: inkcherry <mingzhi.liu@amd.com>
Hi @inkcherry, have you also tested with the ray executor backend?
```python
req_data_to_prefill["kv_transfer_params"]["remote_dp_size"] = (
    decode_instance_endpoint["dp_size"]
)
req_data_to_prefill["kv_transfer_params"]["remote_tp_size"] = (
    decode_instance_endpoint["tp_size"]
)
```
why does P need to know the dp and tp size of D for transfer?
This is due to the RDMA-based push(write) mode.
If using pull (read) mode, this is not required.
can you also provide a proxy example for the pull mode?
@kouroshHakha do you have any other comments other than this example request?
Let's add this example in a follow-up PR. In this PR, @inkcherry has added an example script: `examples/online_serving/disaggregated_serving/moriio_toy_proxy_server.py`.
Both pull mode and push mode use a unified proxy. For pull mode, you can refer to "1P1D TP8 (PULL MODE)" in the test plan of this PR.
```python
async def send_request_to_prefill(
    endpoint, req_data, request_id, p_endpoint, pip, pports, selected_prefill_dp_rank
```
why is there a p_endpoint? Isn't that supposed to be the d_endpoint?
yes, thanks, renamed
This PR has not been tested yet, but I think this shouldn't be a problem; further testing will be conducted subsequently.
This pull request has merge conflicts that must be resolved before it can be merged.
```python
except UnicodeDecodeError:
    logger.warning("Received non-UTF8 message: %s", msg_str)
if not handled:
    raise MoRIIOError(f"Unhandled message format: {msg_str}")
```
Undefined variable when UnicodeDecodeError is raised
High Severity
In _handle_message, when msg.decode("UTF-8") raises a UnicodeDecodeError, the variable msg_str is never assigned because the exception occurs during its assignment. However, the exception handler on line 532 and line 534 both try to use msg_str, which will cause a NameError crash. The logging statement should use msg directly or handle the case where msg_str is not defined.
🔬 Verification Test
Why verification test was not possible: This code path requires the MORI library which is only available on ROCm hardware, and requires simulating receiving a non-UTF8 message through ZMQ. The bug is apparent from static analysis - the variable msg_str is assigned inside the try block on line 527, but if that assignment raises UnicodeDecodeError, the except block tries to use msg_str which was never assigned.
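A minimal sketch of the suggested fix (the `logger` and `MoRIIOError` names are stand-ins modeled on the connector code, and the dispatch logic is elided): initialize `msg_str` before the `try`, and log the raw bytes on decode failure, so the handler raises the intended `MoRIIOError` instead of a `NameError`.

```python
import logging

logger = logging.getLogger(__name__)

class MoRIIOError(Exception):
    """Stand-in for the connector's error type."""

def handle_message(msg: bytes):
    handled = False
    msg_str = None  # defined up front so the except/raise paths can't NameError
    try:
        msg_str = msg.decode("UTF-8")
        # ... dispatch on msg_str, setting handled = True on success ...
    except UnicodeDecodeError:
        # Log the raw bytes: msg_str was never assigned at this point.
        logger.warning("Received non-UTF8 message: %r", msg)
    if not handled:
        raise MoRIIOError(f"Unhandled message format: {msg_str!r}")
```

With this shape, a non-UTF8 message produces the intended `MoRIIOError` rather than crashing the message-handling thread.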
```python
# In dp(prefill)<->dp(decode) communication, we require an all-to-all handshake.

for cur_dp_rank in range(remote_dp_size):
```
Undefined variables when handshake future already exists
High Severity
In _background_moriio_handshake, the variables host, port, tp_size, and remote_dp_size are only assigned inside the if fut is None block. However, these variables are used unconditionally on line 1036 in range(remote_dp_size) and line 1039. If a handshake future already exists for this remote_engine_id, fut will not be None, the variables won't be assigned, and the function will crash with a NameError. This can occur when a second request for the same remote engine arrives while handshaking is in progress.
🔬 Verification Test
Why verification test was not possible: This requires the MORI library on ROCm hardware and simulating a race condition where two requests for the same remote engine are processed concurrently. The bug is clear from static analysis - if the if fut is None block is not entered, remote_dp_size (and other variables) are never defined, but range(remote_dp_size) is called unconditionally afterward.
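A sketch of the guarded flow (the dict and metadata names are modeled on the report, and the handshake body is a placeholder, not the connector's real logic): everything derived from the remote metadata stays inside the `fut is None` branch, and an in-flight future is simply reused.

```python
from concurrent.futures import ThreadPoolExecutor

def ensure_handshake(handshake_futures, engine_id, remote_meta, executor):
    fut = handshake_futures.get(engine_id)
    if fut is not None:
        # A handshake for this engine is already in flight: reuse it instead
        # of touching host/port/tp_size/remote_dp_size, which only exist below.
        return fut
    host, port = remote_meta["host"], remote_meta["port"]
    tp_size, remote_dp_size = remote_meta["tp_size"], remote_meta["dp_size"]

    def do_all_to_all():
        # Placeholder for the per-dp-rank handshake submissions.
        return [(host, port, tp_size, rank) for rank in range(remote_dp_size)]

    fut = executor.submit(do_all_to_all)
    handshake_futures[engine_id] = fut
    return fut
```

A second concurrent request for the same remote engine then gets the existing future back instead of hitting undefined variables.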
```python
    moriio_mem_metadata
)

self.local_kv_cache_size.append(cache.nelement() * cache.element_size())
```
Wrong variable used for KV cache size calculation
Medium Severity
In register_kv_caches, line 1120 uses cache from a previous loop instead of kv_cache from the current iteration. The variable cache was last assigned in the loop on lines 1103-1109, and always points to the last cache of the last layer processed. This causes incorrect size calculations where all layers are recorded with the same size from the final cache, rather than their actual sizes. The line should use kv_cache.nelement() * kv_cache.element_size().
🔬 Verification Test
Why verification test was not possible: This requires the MORI library on ROCm hardware. The bug is clear from code inspection - the loop at line 1111 iterates over layer_name, kv_cache, but line 1120 uses cache which is from the earlier loop (lines 1103-1109).
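The fix is just to size each layer from the loop's own variable. A torch-free sketch (`FakeCache` is a hypothetical stand-in for a tensor's size API, not vLLM code):

```python
class FakeCache:
    """Tiny stand-in for a torch tensor's nelement()/element_size() API."""
    def __init__(self, numel, element_size):
        self._numel, self._element_size = numel, element_size
    def nelement(self):
        return self._numel
    def element_size(self):
        return self._element_size

def collect_kv_cache_sizes(kv_caches):
    sizes = []
    for layer_name, kv_cache in kv_caches.items():
        # Use the current iteration's kv_cache, not a leftover `cache`
        # variable from an earlier loop.
        sizes.append(kv_cache.nelement() * kv_cache.element_size())
    return sizes
```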
```python
def example_round_robin_dp_loader(request_number, dp_size):
    return request_nums % dp_size
```
Function ignores parameter, uses global instead
Medium Severity
The function example_round_robin_dp_loader accepts a parameter request_number but ignores it entirely, instead using the global variable request_nums. This makes the function parameter meaningless and the call on line 231 passes a calculated value that is never used. The function body should use request_number instead of request_nums.
🔬 Verification Test
Test code:
```python
# Test to verify the bug
request_nums = 10  # Global
def example_round_robin_dp_loader(request_number, dp_size):
    return request_nums % dp_size  # Bug: uses global, not parameter
# The parameter is ignored - result is based on global `request_nums`, not `request_number`
result = example_round_robin_dp_loader(5, 3)
print(f"Called with request_number=5, dp_size=3")
print(f"Expected: 5 % 3 = 2")
print(f"Actual: {result} (because it uses global request_nums={request_nums})")
print(f"Bug confirmed: {result != 2}")
```
Command run:
```shell
python3 -c "
request_nums = 10
def example_round_robin_dp_loader(request_number, dp_size):
    return request_nums % dp_size
result = example_round_robin_dp_loader(5, 3)
print(f'Called with request_number=5, dp_size=3')
print(f'Expected: 5 % 3 = 2')
print(f'Actual: {result} (uses global request_nums={request_nums})')
print(f'Bug confirmed: {result != 2}')
"
```
Output:
```
Called with request_number=5, dp_size=3
Expected: 5 % 3 = 2
Actual: 1 (uses global request_nums=10)
Bug confirmed: True
```
Why this proves the bug: The function returns 1 (10 % 3) instead of 2 (5 % 3), proving it uses the global request_nums instead of the passed parameter request_number.
```python
    target=self._write_worker_loop, daemon=True, name="moriio-write-worker"
)
thread.start()
logger.info("Started MoRIIO write worker thread")
```
Race condition in write worker thread initialization
Medium Severity
In ensure_worker_started, there's a TOCTOU (time-of-check-time-of-use) race condition. The check if self._write_worker_started and the flag assignment self._write_worker_started = True both happen outside the lock. Multiple threads could pass the initial check simultaneously, then each would set the flag and proceed to acquire the lock and start separate worker threads. The flag should be checked and set inside the lock with a double-check pattern.
🔬 Verification Test
Why verification test was not possible: Reproducing race conditions reliably requires precise timing control and is non-deterministic. However, the bug is clear from code structure - lines 95-97 check and set the flag before acquiring the lock on line 98, creating a classic TOCTOU vulnerability.
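The suggested double-check pattern can be sketched like this (a simplified model of the method, not the connector's actual class):

```python
import threading

class WriteWorkerOwner:
    """Sketch of the fix: the started flag is checked and set under the
    lock, so only one worker thread can ever be spawned."""
    def __init__(self):
        self._write_worker_started = False
        self._write_worker_lock = threading.Lock()
        self.threads_started = 0

    def _write_worker_loop(self):
        pass  # placeholder for the real drain loop

    def ensure_worker_started(self):
        if self._write_worker_started:      # cheap fast path, may be stale
            return
        with self._write_worker_lock:
            if self._write_worker_started:  # re-check under the lock
                return
            thread = threading.Thread(
                target=self._write_worker_loop,
                daemon=True,
                name="moriio-write-worker",
            )
            thread.start()
            self.threads_started += 1
            self._write_worker_started = True
```

Even with many concurrent callers, exactly one worker thread is started.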
```python
    )
    break
else:
    break
```
Infinite loop when no requests need saving
High Severity
In save_kv_layer, when metadata.reqs_to_save is empty, remote_engine_id remains None from its initialization. The while loop at line 1249 then checks remote_engine_id not in self.write_ready_flags, which is always True when remote_engine_id is None. Combined with an empty _ready_requests queue, the first condition is satisfied indefinitely, causing an infinite spin loop that will hang the entire system. Unlike start_load_kv which has a wait_handshake_readd_req guard, this function lacks such protection.
🔬 Verification Test
Why verification test was not possible: This requires the MORI library on ROCm hardware and simulating a forward pass with no KV transfer requests. The bug is clear from static analysis - when reqs_to_save is empty, remote_engine_id stays None, and the while loop condition empty() and None not in dict evaluates to True, causing infinite continue iterations.
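The missing guard can be sketched as an early return before the readiness loop (a simplified model of the control flow, with `max_spins` added purely so the sketch can fail fast instead of hanging):

```python
def save_kv_layer(reqs_to_save, ready_requests, write_ready_flags, max_spins=10_000):
    """Return the remote engine id once writes are ready, or None if there
    is nothing to save this step."""
    if not reqs_to_save:
        # Guard: with no requests to save, remote_engine_id would stay None
        # and the loop below could spin forever.
        return None
    remote_engine_id = reqs_to_save[0]["remote_engine_id"]
    spins = 0
    while not ready_requests and remote_engine_id not in write_ready_flags:
        spins += 1
        if spins > max_spins:
            raise TimeoutError("write readiness never signalled")
    return remote_engine_id
```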
```python
selected_prefill_dp_rank = None
if prefill_instance_endpoint["dp_size"] > 1:
    selected_prefill_dp_rank = example_round_robin_dp_loader(
        request_nums // len(prefill_instance_endpoint),
```
len() called on dictionary instead of list
Medium Severity
The expression len(prefill_instance_endpoint) is called on a dictionary (the endpoint configuration dict containing keys like "dp_size", "tp_size", etc.), not on the list of prefill instances. This returns the number of dictionary keys (around 7-8) rather than the intended number of prefill instances. The load balancing calculation is therefore incorrect. This should likely be len(prefill_instances) to get the actual instance count.
🔬 Verification Test
Test code:
```python
# Simulating the proxy server data structures
prefill_instance_endpoint = {
    "dp_size": 2,
    "tp_size": 4,
    "request_address": "http://10.0.0.1:20005/v1/completions",
    "handshake_port": 6301,
    "notify_port": 61005,
}
prefill_instances = [prefill_instance_endpoint]  # List with 1 instance
request_nums = 100
# Bug: uses dict length instead of list length
result = request_nums // len(prefill_instance_endpoint)
correct_result = request_nums // len(prefill_instances)
print(f"len(prefill_instance_endpoint) = {len(prefill_instance_endpoint)} (dict keys)")
print(f"len(prefill_instances) = {len(prefill_instances)} (actual instances)")
print(f"Bug result: {result}, Correct result: {correct_result}")
```
Command run:
```shell
python3 -c "
prefill_instance_endpoint = {'dp_size': 2, 'tp_size': 4, 'request_address': 'http://10.0.0.1:20005', 'handshake_port': 6301, 'notify_port': 61005}
prefill_instances = [prefill_instance_endpoint]
request_nums = 100
result = request_nums // len(prefill_instance_endpoint)
correct_result = request_nums // len(prefill_instances)
print(f'len(prefill_instance_endpoint) = {len(prefill_instance_endpoint)} (dict keys)')
print(f'len(prefill_instances) = {len(prefill_instances)} (actual instances)')
print(f'Bug result: {result}, Correct result: {correct_result}')
"
```
Output:
```
len(prefill_instance_endpoint) = 5 (dict keys)
len(prefill_instances) = 1 (actual instances)
Bug result: 20, Correct result: 100
```
Why this proves the bug: The code uses the dictionary's key count (5) instead of the list's length (1), producing completely different load balancing values.
```python
if new_block_ids is not None:
    block_ids = new_block_ids[0]
    # TODO : hybrid attn, etc
req, existing_blocks = self._reqs_need_pending_save[req_id]
```
Missing key check causes KeyError for non-transfer requests
High Severity
In build_connector_meta, the code iterates over all requests in scheduled_cached_reqs.req_ids (which includes ALL running and resumed requests), then directly accesses self._reqs_need_pending_save[req_id] without checking if the key exists. However, _reqs_need_pending_save only contains requests that have do_remote_decode=True and are in chunked prefill mode. For any regular request without KV transfer enabled, this will crash with a KeyError. The code needs to check if req_id in self._reqs_need_pending_save before accessing it.
🔬 Verification Test
Why verification test was not possible: This requires setting up a full vLLM scheduler environment with the MORI connector and having a mix of requests with and without do_remote_decode enabled. The bug is clear from static analysis - scheduled_cached_reqs.req_ids contains all cached requests from _make_cached_request_data (which processes all running_reqs and resumed_reqs), while _reqs_need_pending_save only contains the subset of requests added via update_state_after_alloc when do_remote_decode=True and they are in chunked prefill mode.
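The suggested membership check can be sketched as follows (a simplified, self-contained model of the lookup, not the scheduler's real data structures):

```python
def build_connector_meta(scheduled_req_ids, reqs_need_pending_save):
    """Sketch of the fix: only requests registered for pending save are
    looked up; regular requests are skipped instead of raising KeyError."""
    pending_updates = []
    for req_id in scheduled_req_ids:
        if req_id not in reqs_need_pending_save:
            continue  # request has no KV transfer scheduled; nothing to do
        req, existing_blocks = reqs_need_pending_save[req_id]
        pending_updates.append((req_id, req, existing_blocks))
    return pending_updates
```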
Is building from source a preferred way of installing this ROCm component?
```diff
 # Install Python and other dependencies
 RUN apt-get update -y \
-    && apt-get install -y software-properties-common git curl sudo vim less libgfortran5 \
+    && apt-get install -y software-properties-common git curl sudo vim less libgfortran5 libopenmpi-dev libpci-dev \
```
Are these libraries needed at runtime, or just in the build phase?
Currently, building from source is the recommended method. We will update to the pip method later.
@inkcherry Trying to get this to work, but it just seems to hang. In particular, it looks like the decode service is hanging.
Purpose
This PR introduces the mori-io KV connector for AMD devices. Built on top of the MORI project, the mori-io connector supports both PULL and PUSH modes for KV cache transfer. Key features include:
- MORI backend integration.
- MORI-related components (buffer merge, session cache management, batched I/O).
- PULL mode (serial interaction of prefill and decode).
- PUSH mode (parallel interaction of prefill and decode, with non-blocking layer-wise transfer).
- A unified proxy example for both push and pull logic.
- xPyD parallel strategy support.
High Level Design
Push Mode Implementation Details:
For prefill, we use a dedicated thread to maintain the transfer queue. Layers are enqueued layer by layer while the forward pass proceeds normally. Once prefill receives the block-allocation signal from decode, it asynchronously schedules the transfers.
For decode, requests first send their allocated blocks to the prefill stage and are then placed into a waiting queue. Once the write-completion signal is received from prefill, the request is extracted from the waiting queue and directly scheduled via continuous batching.
This approach introduces no blocking overhead between prefill and decode, and it is compatible with both decode graph mode and the chunked-prefill feature.
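The prefill-side flow described above can be sketched as a producer/consumer queue (a toy model under the assumptions stated in the comments, not the connector's actual code): the forward pass enqueues finished layers, and a dedicated worker thread drains and "transfers" them, so prefill never blocks on the write.

```python
import queue
import threading

class LayerTransferQueue:
    """Toy model of push-mode prefill: layer-wise enqueue + async drain."""
    _STOP = object()

    def __init__(self):
        self._q = queue.Queue()
        self.transferred = []
        self._worker = threading.Thread(target=self._drain, daemon=True)
        self._worker.start()

    def enqueue_layer(self, layer_name, blocks):
        # Called from the forward pass as each layer's KV becomes ready;
        # this never blocks on the actual transfer.
        self._q.put((layer_name, blocks))

    def flush(self, timeout=5.0):
        # Signal end-of-request and wait for the worker to finish draining.
        self._q.put(self._STOP)
        self._worker.join(timeout=timeout)

    def _drain(self):
        while True:
            item = self._q.get()
            if item is self._STOP:
                return
            # Placeholder for the RDMA write of this layer's blocks.
            self.transferred.append(item)
```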
Regarding the transfer process:
TODO:
Test Plan
device: MI300
Accuracy
Performance
Accuracy launch scripts
launch proxy:
```shell
python examples/online_serving/disaggregated_serving/moriio_toy_proxy_server.py
```
gsm8k task: The GSM8K proxy remains consistent across all tasks. The startup commands for the instances are listed below.
1P1D TP8 or 3P2D TP8
Launch prefill Instance
Launch decode instance
1P1D DP8
Launch prefill Instance
Launch decode instance
1P1D TP8 (PULL MODE)
Launch prefill instance
Launch decode instance
1P1D TP2 (Qwen3-32B, non-MLA)
Launch prefill instance
Launch decode instance
Performance launch scripts
server
The performance test uses the same startup scripts as the accuracy test, with the following key parameters standardized for all performance benchmarks to ensure comparability:
bench
Test Result
Accuracy:
1P1D TP8
3P2D TP8
1P1D DP8
1P1D TP8 (PULL MODE, fp8 kv cache dtype)
1P1D TP2 (Qwen3-32B, non-MLA)
Performance:
1P1D TP8 (PULL/PUSH MODE)
TTFT
Prerequisites
mori-io and mori-ep have the same dependencies; see mori-ep: [AMD][ROCm] MoRI EP: a high-performance all2all backend #28664
Essential Elements of an Effective PR Description Checklist
Update `supported_models.md` and `examples` for a new model.
Note
Enables RDMA-based disaggregated KV transfer on ROCm via MORI.
- Adds `MoRIIOConnector` (scheduler/worker) under `vllm/distributed/kv_transfer/kv_connector/v1/moriio/*` with ZMQ handshakes, notify flow, READ/PUSH modes, TP/DP support, and session/offset management
- Registers the connector in `kv_connector/factory.py` and adds env toggles (`VLLM_MORIIO_*`) in `vllm/envs.py`
- Adds `examples/online_serving/disaggregated_serving/moriio_toy_proxy_server.py` for prefill/decode service discovery and request routing
- Updates `docker/Dockerfile.rocm_base` to build/install MORI (new args/envs, deps) and bundle its wheel
- Adds `tests/v1/kv_connector/unit/test_moriio_connector.py`

Written by Cursor Bugbot for commit 7e1d978.