[KVConnector] Support 3FS KVConnector #37636

Merged
simon-mo merged 6 commits into vllm-project:main from ibifrost:feature/support_hf3fs_connector_opensource
Apr 7, 2026
Conversation

@ibifrost

@ibifrost ibifrost commented Mar 20, 2026

Contributor

Overview

This PR introduces the implementation of the 3FS KVConnector for vLLM.

The 3FS KVConnector enables efficient offloading and sharing of KV caches across nodes, significantly accelerating long-context inference scenarios. Alongside the core implementation, we provide the 3FS Operator for one-click deployment and a mini3fs setup for easy local verification.

Deployment Operator: aliyun/kvc-3fs-operator
Mini3FS Quick Start: Mini3FS Deployment Guide

Performance Result

We ran tests to evaluate performance in long-context QA scenarios.

  1. Dataset: Loogle dataset
    100 groups of prompts.
    Each group: 20K shared prefix + 20 queries.
    Query structure: 20K (prefix) + 128 (input) tokens → 64 (output) tokens.
  2. Model: Qwen3-Coder 480B
  3. Hardware: 8x NVIDIA H20-3e GPUs
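
The workload shape above implies a large amount of repeated prefix work. The following is a rough back-of-the-envelope sketch, not a measured result: it assumes "20K" means 20,000 tokens and that an ideal cache computes each shared prefix exactly once. The constant names are illustrative only.

```python
# Workload shape taken from the dataset description above.
GROUPS = 100
QUERIES_PER_GROUP = 20
PREFIX_TOKENS = 20_000  # assumption: "20K" read as 20,000 tokens
QUERY_TOKENS = 128

requests = GROUPS * QUERIES_PER_GROUP

# Every request prefills its full input when nothing is cached.
prefill_without_reuse = requests * (PREFIX_TOKENS + QUERY_TOKENS)

# Idealized bound: each shared prefix is computed once per group.
prefill_with_ideal_reuse = GROUPS * PREFIX_TOKENS + requests * QUERY_TOKENS

print(requests, prefill_without_reuse, prefill_with_ideal_reuse)
```

Under these assumptions the prefix dominates the input, which is why prefix KV cache reuse matters for this kind of benchmark.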

Results
Comparing vLLM with 3FS Connector vs. vLLM with L1 Only:

[image: benchmark results]

Note: In the "3FS Cold Start" scenario, a new node accesses KV caches generated by other nodes immediately, without recomputation or slow network transfer.

cc @ApostaC @KuntaiDu

@github-actions


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces the HF3FS KVConnector, a significant feature for efficient KV cache offloading. The implementation is extensive, covering the connector itself, a client, a metadata server, and various utilities. Overall, the design is solid, leveraging asynchronous I/O and modern Python features. However, I've identified a few critical issues, including a logic bug in input validation, an error in metrics collection that would lead to a runtime crash, and a couple of instances of duplicated code that should be cleaned up for better maintainability. Addressing these points will improve the robustness and quality of this new connector.

Comment on lines +219 to +225
all(
    [
        offset < 0 or offset + size > self.size
        for offset, size in zip(offsets, sizes)
    ]
),
all([size > self.bytes_per_page for size in sizes]),

Contributor

critical

The validation logic here is incorrect. Using all() means an error is raised only if all offsets are out of bounds or all tensor sizes exceed the page size. This should be any(), so that an error is raised if any item is invalid. This bug could lead to out-of-bounds memory access or silent data corruption.

                any(offset < 0 or offset + size > self.size
                    for offset, size in zip(offsets, sizes)),
                any(size > self.bytes_per_page for size in sizes),
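
To make the difference concrete, here is a minimal standalone sketch of the validation. The function name `find_invalid_ranges` and its parameters are hypothetical and not taken from the PR; only the two conditions mirror the snippet above.

```python
def find_invalid_ranges(offsets, sizes, total_size, bytes_per_page):
    """Return a list of problems; an empty list means the input is valid."""
    problems = []
    # any(): a single out-of-bounds range is enough to reject the request.
    if any(off < 0 or off + sz > total_size for off, sz in zip(offsets, sizes)):
        problems.append("offset out of range")
    # any(): a single oversized entry is enough to reject the request.
    if any(sz > bytes_per_page for sz in sizes):
        problems.append("size exceeds page size")
    return problems


# Only the last range overruns a 12288-byte region.
offsets, sizes = [0, 4096, 9000], [4096, 4096, 4096]
out_of_bounds = [o < 0 or o + s > 12288 for o, s in zip(offsets, sizes)]
print(all(out_of_bounds))  # False: the all() check would let this input through
print(any(out_of_bounds))  # True: the any() check rejects it
```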

Collaborator

Gemini's comment seems to be correct. Can you check, @ibifrost?

Contributor Author

Using all() here was indeed a logic error. I've fixed it.

Comment on lines +1147 to +1148
for list_item in transfer_stats_data[counter_item_key]:
    counter_obj[engine_idx].inc(list_item)

Contributor

high

The observe method incorrectly attempts to iterate over num_failed_save and num_failed_load from transfer_stats_data. These values are integers, not iterables, which will cause a TypeError at runtime when metrics are being observed. The loop should be removed, and inc() should be called directly with the integer value.

            counter_obj[engine_idx].inc(transfer_stats_data[counter_item_key])
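
A small sketch of the failure mode and the direct-increment fix. `FakeCounter` is a hypothetical stand-in for a metrics counter exposing `inc(amount)`; the stats keys come from the comment above, and everything else is illustrative.

```python
class FakeCounter:
    """Hypothetical stand-in for a metrics counter with an inc(amount) method."""

    def __init__(self):
        self.value = 0

    def inc(self, amount=1):
        self.value += amount


# The failure counts are plain integers, as described in the review comment.
transfer_stats_data = {"num_failed_save": 3, "num_failed_load": 0}


def observe(counters, stats):
    # Add each integer once; there is nothing to iterate over.
    for key, counter in counters.items():
        counter.inc(stats[key])


counters = {key: FakeCounter() for key in transfer_stats_data}
observe(counters, transfer_stats_data)

# The original pattern: looping over an int raises TypeError.
try:
    for item in transfer_stats_data["num_failed_save"]:
        pass
    loop_error = ""
except TypeError as exc:
    loop_error = str(exc)
```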

Contributor Author

Also fixed.

Comment on lines +322 to +328

# When EPLB is enabled, redundant physical expert slots may map to
# logical experts that belong to other ranks in the default partition.
# The weight loader needs to see ALL logical expert weights so it can
# populate these redundant slots. Skip the filter entirely.
if parallel_config.enable_eplb:
    return

Contributor

high

This block of code, which skips the expert weight filter when EPLB is enabled, is a duplicate of the code block that immediately follows it. This duplication should be removed to improve code clarity and maintainability.

Comment on lines +226 to +256
# Reuse pre-allocated zero-init output buffer to avoid a memset
# kernel on every CUDA graph replay.
# q is 4D: (batch, q_len_per_req, num_heads, head_dim)
# FlashInfer has a bug where out= validation hardcodes 3D shape
# (batch, num_heads, kv_lora_rank), but the kernel writes 4D
# (batch, q_len, num_heads, kv_lora_rank) when q_len > 1.
# So we can only pass out= for single-token decode (q_len == 1).
# For q_len > 1, we zero padding slots after the kernel returns.
# TODO: upstream fix to FlashInfer
B, q_len_per_req = q.shape[0], q.shape[1]
out_kwargs: dict[str, torch.Tensor] = {}
if q_len_per_req == 1:
    dtype = (
        torch.bfloat16
        if is_quantized_kv_cache(self.kv_cache_dtype)
        else q.dtype
    )
    if (
        self._decode_out is None
        or self._decode_out.shape[0] < B
        or self._decode_out.dtype != dtype
    ):
        self._decode_out = torch.zeros(
            B,
            q.shape[2],
            self.kv_lora_rank,
            dtype=dtype,
            device=q.device,
        )
    out_kwargs["out"] = self._decode_out[:B]

Contributor

high

This block of code, which prepares the out_kwargs for the FlashInfer kernel, is a duplicate of the code block at lines 195-224. This redundant code should be removed to improve maintainability.

@ibifrost ibifrost force-pushed the feature/support_hf3fs_connector_opensource branch from 8d38bb8 to c451357 Compare March 20, 2026 03:45
@ibifrost ibifrost changed the title Feature/support hf3fs connector opensource [KVConnector] Support 3FS KVConnector Mar 20, 2026
@ibifrost ibifrost force-pushed the feature/support_hf3fs_connector_opensource branch from c451357 to 786ec6c Compare March 20, 2026 03:54
Comment thread pyproject.toml Outdated
@@ -48,6 +48,11 @@ lora_hf_hub_resolver = "vllm.plugins.lora_resolvers.hf_hub_resolver:register_hf_
[tool.setuptools_scm]
# no extra settings needed, presence enables setuptools-scm

[tool.setuptools.package-data]
"vllm" = [

Collaborator

A naive question: why do we need to define extra package-data here?

Contributor Author

We currently rely on runtime compilation in hf3fs_client.py:

    from torch.utils.cpp_extension import load
    ...
    hf3fs_utils = load(
        name="hf3fs_utils",
        sources=[f"{root}/utils/hf3fs_utils.cpp"],
        extra_include_paths=[cuda_include_path],
    )

That is why the hf3fs_utils.cpp file has to be included in the package here.
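
Because the extension is compiled at runtime, the .cpp source must exist next to the installed Python modules. The sketch below shows one way to fail early with a clear message when it was not packaged. The helper `resolve_extension_source` is hypothetical and not part of the PR; only the relative path mirrors the snippet above.

```python
import tempfile
from pathlib import Path


def resolve_extension_source(package_root, relative="utils/hf3fs_utils.cpp"):
    """Return the path to the C++ source, or raise if it was not packaged."""
    source = Path(package_root) / relative
    if not source.is_file():
        raise FileNotFoundError(
            f"{source} is missing; the .cpp file must be shipped as package "
            "data for runtime compilation to work"
        )
    return source


# Demo with a throwaway directory standing in for an installed package.
with tempfile.TemporaryDirectory() as root:
    try:
        resolve_extension_source(root)
        missing_detected = False
    except FileNotFoundError:
        missing_detected = True  # simulates a wheel built without the .cpp file

    (Path(root) / "utils").mkdir()
    (Path(root) / "utils" / "hf3fs_utils.cpp").write_text("// placeholder\n")
    found = resolve_extension_source(root).name
```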

Collaborator

@simon-mo Just wondering: is it common practice to introduce C/C++ sources like this, or is there a more common approach?

Collaborator

Package data is fine, but we put those in setup.py.

Contributor Author

Sure, I've moved the logic to setup.py as suggested.

@ApostaC ApostaC left a comment

Collaborator

Did a more detailed pass. There seem to be a few small problems.
Please add some unit tests as well. But overall, the structure looks good to me. Thanks for the terrific job!

return self._fail_task("Saved", "Write operation failed", request_id, future)

except Exception as e:
    return self._fail_task("Saved", f"Task execution error: {e}", request_id, future)

Collaborator

Do we need to call free_buffer before self._fail_task?

Contributor Author

Yes, we need to release the buffers in error paths. I've added a check before calling _fail_task to ensure resources are cleaned up.
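
One common way to guarantee release on every exit path is try/finally, sketched below. This is a different technique from the explicit check described in the reply, and `BufferPool` and `save` are hypothetical names used only for illustration.

```python
class BufferPool:
    """Hypothetical fixed-size pool; names are illustrative only."""

    def __init__(self, count):
        self.free = list(range(count))

    def acquire(self):
        return self.free.pop()

    def release(self, buf):
        self.free.append(buf)


def save(pool, write):
    """Run write(buf); return True on success and False on any failure."""
    buf = pool.acquire()
    try:
        return bool(write(buf))
    except Exception:
        return False
    finally:
        # Runs on success, on a failed write, and on an exception.
        pool.release(buf)


def failing_write(buf):
    raise IOError("simulated write failure")


pool = BufferPool(2)
results = [
    save(pool, lambda buf: True),   # success
    save(pool, lambda buf: False),  # write reported failure
    save(pool, failing_write),      # write raised
]
```

After all three calls the pool holds its original two buffers, so neither error path leaks.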

if key in self.key_metadata:
    key_meta = self.key_metadata[key]
    if key_meta.is_complete() and rank in key_meta.rank_to_page:
        allocation_results[key] = key_meta.rank_to_page[rank]

Collaborator

Do we need to free the existing allocation_results[key] before overwriting it? I'm not so sure about the logic here.

Contributor Author

Right, just to clarify the logic here:
When a key is already complete (key_meta.is_complete()), we reuse its existing page from the metadata. That page is still valid and in use, so we do not need to free it.

However, because we pre-allocate a new page for every key at the start of the function, hitting this branch means we hold a newly allocated page that we will never use. The original code failed to release this unused pre-allocated page, which could leak pages.

So the fix is to keep the existing mapping but release the unused page from allocated_pages. I'll also add a test case to ensure the free page count stays consistent when reusing existing keys.
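
The reuse-and-release logic can be sketched with a toy allocator. This is a simplification under stated assumptions: `PageAllocator` is hypothetical, it ignores ranks, and it treats every key as complete as soon as it is allocated, which the real metadata server does not.

```python
class PageAllocator:
    """Toy allocator; hypothetical and much simpler than the PR's code."""

    def __init__(self, num_pages):
        self.free_pages = list(range(num_pages))
        self.key_to_page = {}  # keys treated as complete once recorded

    def allocate(self, keys):
        # Pre-allocate one page per key, mirroring the flow described above.
        allocated = {key: self.free_pages.pop() for key in keys}
        results = {}
        for key, new_page in allocated.items():
            if key in self.key_to_page:
                results[key] = self.key_to_page[key]  # keep the existing mapping
                self.free_pages.append(new_page)      # release the unused page
            else:
                results[key] = new_page
                self.key_to_page[key] = new_page
        return results


alloc = PageAllocator(4)
first = alloc.allocate(["a", "b"])
free_after_first = len(alloc.free_pages)
second = alloc.allocate(["a", "c"])  # "a" already has a page
free_after_second = len(alloc.free_pages)
```

With four pages and three distinct keys, exactly one page should remain free after the second call; without the release step the count would drop to zero.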

Comment on lines +240 to +248
"""Close the client and clean up resources."""
deregister_fd(self.file)
os.close(self.file)
del self.ior_r
del self.ior_w
del self.iov_r
del self.iov_w
self.shm_r.close()
self.shm_w.close()

Collaborator

Do we need to do any check to prevent double-close or double-del?

Contributor Author

You're right. I've moved the cleanup logic into a _release_resources method and added checks to ensure the close operation is idempotent.
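
A minimal sketch of an idempotent close guarded by a flag. Only the method name `_release_resources` comes from the reply above; the `Client` class, the pipe it owns, and the `release_count` field are assumptions made for illustration.

```python
import os


class Client:
    """Hypothetical client that owns two OS file descriptors."""

    def __init__(self):
        self.read_fd, self.write_fd = os.pipe()
        self._closed = False
        self.release_count = 0  # tracked only to make the demo observable

    def _release_resources(self):
        os.close(self.read_fd)
        os.close(self.write_fd)
        self.release_count += 1

    def close(self):
        # Guard flag: later calls return before touching freed resources.
        if self._closed:
            return
        self._closed = True
        self._release_resources()


client = Client()
client.close()
client.close()  # no-op instead of closing the descriptors a second time
```

Setting the flag before releasing means a failure partway through cleanup is not retried on the next call; whether that trade-off is right depends on the resources involved.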

@ibifrost ibifrost force-pushed the feature/support_hf3fs_connector_opensource branch 2 times, most recently from db51247 to 3e16a87 Compare March 23, 2026 07:47
Signed-off-by: wuchenxin <wuchenxin.wcx@alibaba-inc.com>
@ibifrost ibifrost force-pushed the feature/support_hf3fs_connector_opensource branch from 3e16a87 to f2f5f88 Compare March 24, 2026 03:49
@mergify mergify Bot added the ci/build label Mar 24, 2026

@ApostaC ApostaC left a comment

Collaborator

LGTM!

@github-project-automation github-project-automation Bot moved this to Ready in NVIDIA Apr 2, 2026
Signed-off-by: ibifrost <47308427+ibifrost@users.noreply.github.com>
@KuntaiDu KuntaiDu added the ready ONLY add when PR is ready to merge/full CI is needed label Apr 3, 2026
@simon-mo simon-mo enabled auto-merge (squash) April 6, 2026 19:33
@mergify

mergify Bot commented Apr 6, 2026

Contributor

Hi @ibifrost, the pre-commit checks have failed. Please run:

uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@simon-mo

simon-mo commented Apr 6, 2026

Collaborator

@ibifrost can you fix the format?

auto-merge was automatically disabled April 7, 2026 08:02

Head branch was pushed to by a user without write access

@ibifrost ibifrost force-pushed the feature/support_hf3fs_connector_opensource branch from 5c200aa to 0d61dc6 Compare April 7, 2026 08:06
@ibifrost

ibifrost commented Apr 7, 2026

Contributor Author

@ibifrost can you fix the format?

Sure

@ibifrost ibifrost closed this Apr 7, 2026
@github-project-automation github-project-automation Bot moved this from Ready to Done in NVIDIA Apr 7, 2026
@ibifrost ibifrost reopened this Apr 7, 2026
Signed-off-by: wuchenxin <wuchenxin.wcx@alibaba-inc.com>
@ibifrost ibifrost force-pushed the feature/support_hf3fs_connector_opensource branch from ce544f2 to 9ab5961 Compare April 7, 2026 09:56
@simon-mo simon-mo enabled auto-merge (squash) April 7, 2026 15:31
@simon-mo simon-mo merged commit 96b5004 into vllm-project:main Apr 7, 2026
140 checks passed
mtparet pushed a commit to blackfuel-ai/vllm that referenced this pull request Apr 9, 2026
Signed-off-by: wuchenxin <wuchenxin.wcx@alibaba-inc.com>
Signed-off-by: ibifrost <47308427+ibifrost@users.noreply.github.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
mystous pushed a commit to mystous/vllm_hybrid that referenced this pull request May 10, 2026
Signed-off-by: wuchenxin <wuchenxin.wcx@alibaba-inc.com>
Signed-off-by: ibifrost <47308427+ibifrost@users.noreply.github.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
my-other-github-account pushed a commit to my-other-github-account/vllm that referenced this pull request May 15, 2026
Signed-off-by: wuchenxin <wuchenxin.wcx@alibaba-inc.com>
Signed-off-by: ibifrost <47308427+ibifrost@users.noreply.github.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
jhu960213 pushed a commit to jhu960213/vllm that referenced this pull request May 20, 2026
Signed-off-by: wuchenxin <wuchenxin.wcx@alibaba-inc.com>
Signed-off-by: ibifrost <47308427+ibifrost@users.noreply.github.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
huangyibo pushed a commit to huangyibo/vllm that referenced this pull request May 21, 2026
Signed-off-by: wuchenxin <wuchenxin.wcx@alibaba-inc.com>
Signed-off-by: ibifrost <47308427+ibifrost@users.noreply.github.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
huangyibo pushed a commit to huangyibo/vllm that referenced this pull request Jun 4, 2026
Signed-off-by: wuchenxin <wuchenxin.wcx@alibaba-inc.com>
Signed-off-by: ibifrost <47308427+ibifrost@users.noreply.github.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
mvanhorn pushed a commit to mvanhorn/vllm that referenced this pull request Jun 4, 2026
Signed-off-by: wuchenxin <wuchenxin.wcx@alibaba-inc.com>
Signed-off-by: ibifrost <47308427+ibifrost@users.noreply.github.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
Signed-off-by: Matt Van Horn <455140+mvanhorn@users.noreply.github.com>
Labels

ci/build kv-connector nvidia ready ONLY add when PR is ready to merge/full CI is needed v1

Projects

Status: Done

4 participants