feat: add coordinated checkpoint prefetch for network filesystem loading #20843
Conversation
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a performance enhancement for loading large language model weights, especially in distributed environments using network filesystems. By prefetching checkpoint files into the OS page cache, it eliminates the substantial overhead caused by multiple processes independently accessing the same data over the network. This optimization significantly accelerates model initialization for high-performance computing setups.
Code Review
The pull request introduces a coordinated checkpoint prefetch mechanism to significantly improve weight loading times, especially on network filesystems. The implementation correctly distributes file prefetching across ranks and includes detailed logging and performance metrics. Unit tests have been added to verify the correctness of the prefetching logic and ensure bit-identical weights with and without prefetch. The changes are well-structured and address a critical performance bottleneck.
@Fridge003 Could you assign this to the PICs for this part? Thanks!
Wonderful feature! |
When multiple DP ranks on the same node load the same checkpoint via mmap (e.g. DP-attention), each rank independently page-faults all files over NFS/Lustre. With N ranks this causes N × checkpoint_size bytes of redundant network I/O.

This adds a `--weight-loader-prefetch-checkpoints` flag that, before loading, distributes a sequential read of the checkpoint files across all ranks (each reads 1/Nth of the shards into the shared OS page cache). Subsequent mmap accesses then hit warm page cache instead of the network filesystem.

Measured on DeepSeek R1-671B FP8 with 8 DP-attention ranks:

| Setup                    | B300 weight load time |
|--------------------------|-----------------------|
| Baseline (mmap)          | deadlock / 40 min     |
| max_workers=4            | 2390s (40 min)        |
| max_workers=4 + prefetch | 236s (3.9 min)        |

10x improvement on Lustre, with no infrastructure changes required.

The implementation (a minimal sketch follows below):
- `_prefetch_checkpoint_file()`: reads a file in 16 MB blocks to warm the page cache
- `_prefetch_all_checkpoints()`: distributes files across ranks with a barrier, uses 4 threads per rank for concurrent prefetch
- Wired through both `safetensors_weights_iterator` and `buffered_multi_thread_safetensors_weights_iterator`
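A minimal sketch of the block-read warm-up described above; the function and constant names follow the PR description, but the body is illustrative rather than the actual diff:

```python
# Sketch only: sequentially read a checkpoint shard in fixed-size blocks so its
# pages land in the shared OS page cache before mmap-based loading touches them.
_PREFETCH_BLOCK_SIZE = 16 * 1024 * 1024  # 16 MB, matching the default described above


def _prefetch_checkpoint_file(path: str) -> None:
    """Warm the page cache for one safetensors shard by reading it end to end."""
    with open(path, "rb", buffering=0) as f:
        while f.read(_PREFETCH_BLOCK_SIZE):
            pass
```

The data read is discarded; the only purpose is the side effect of populating the page cache so later mmap accesses avoid the network filesystem.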
Tests verify:
- `_prefetch_checkpoint_file` reads every byte of the file
- `_prefetch_all_checkpoints` only reads the files assigned to the current rank (rank::world_size partitioning)
- Single-rank mode prefetches all files
- Weights loaded with prefetch=True are bit-identical to prefetch=False
- Extract block_size as module-level constant `_PREFETCH_BLOCK_SIZE`
- Make prefetch num_threads configurable via `--weight-loader-prefetch-num-threads` (default: 4), wired through server_args -> loader -> weight_utils
Adopt vLLM's approach (vllm-project/vllm#36012): run prefetch in a background daemon thread instead of blocking with a barrier (sketched below). This pipelines I/O with loading — the loader benefits from pages already cached while the prefetch thread continues reading ahead.

Advantages over the blocking approach:
- Works well even when checkpoint > available RAM (sliding window)
- No barrier overhead between ranks
- Loading starts immediately instead of waiting for full prefetch
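A hedged sketch of how such a background prefetch thread can be started; `_prefetch_checkpoint_file` is assumed from the earlier sketch, and the helper name `_start_background_prefetch` is illustrative, not taken from the PR:

```python
import threading
from concurrent.futures import ThreadPoolExecutor


def _start_background_prefetch(files: list[str], num_threads: int = 4) -> threading.Thread:
    """Warm the page cache for the given shards without blocking weight loading."""

    def _worker() -> None:
        # A small thread pool reads this rank's shards concurrently; the loader
        # proceeds in the main thread and simply finds more and more pages cached.
        with ThreadPoolExecutor(max_workers=num_threads) as pool:
            list(pool.map(_prefetch_checkpoint_file, files))

    t = threading.Thread(target=_worker, name="checkpoint-prefetch", daemon=True)
    t.start()
    return t
```

Because the thread is a daemon, it never blocks process shutdown if loading finishes (or fails) before the prefetch completes.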
- Make `_PREFETCH_BLOCK_SIZE` configurable via `SGLANG_PREFETCH_BLOCK_SIZE_MB` env var (see the snippet below)
- Print prefetch log messages on every rank (not just rank 0)
- Add docs for `--weight-loader-prefetch-checkpoints` and `--weight-loader-prefetch-num-threads` to docs/advanced_features/server_arguments.md
- Simplify tests: keep only TestPrefetchWeightsIdentical, register with register_cpu_ci(est_time=5, suite="stage-a-test-cpu")
- Remove TestPrefetchReadsAllBytes and TestPrefetchDistributedOnlyReadsSubset per reviewer request (multi-GPU integration test needed separately)
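One plausible way the env-var override could be resolved; the variable name comes from the commit message above, while the lookup logic and the 16 MB default are assumptions for illustration:

```python
import os

# Assumed resolution of the documented SGLANG_PREFETCH_BLOCK_SIZE_MB override;
# the exact lookup in the PR may differ.
_PREFETCH_BLOCK_SIZE = int(os.environ.get("SGLANG_PREFETCH_BLOCK_SIZE_MB", "16")) * 1024 * 1024
```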
Two fixes:
- Use get_world_group().local_rank for prefetch file distribution instead of global rank (see the sketch below). Page cache is per-node, so each node must independently prefetch the full checkpoint. With global rank, multi-node setups would split files across nodes whose page caches are not shared.
- Fix buffered_multi_thread_safetensors_weights_iterator to iterate files in sorted order (matching prefetch order), not original order.
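A sketch of the per-node partitioning described in the first fix; the helper name and signature are illustrative, but the slicing mirrors the `sorted_files[rank::world_size]` scheme from the PR, applied to the node-local rank:

```python
def _files_for_local_rank(files: list[str], local_rank: int, local_world_size: int) -> list[str]:
    """Return the shards this rank should prefetch.

    The page cache is shared only within a node, so the partition uses the
    node-local rank: across one node's ranks the slices cover every shard,
    and every node independently warms the full checkpoint.
    """
    sorted_files = sorted(files)
    return sorted_files[local_rank::local_world_size]
```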
Launches a small MoE model (Qwen1.5-MoE-A2.7B-Chat) with DP-attention and --weight-loader-prefetch-checkpoints on 4 GPUs. Verifies the server starts and produces valid output. Registered as nightly-4-gpu with est_time=300.
@janbernloehr could you let us know when this PR is ready for another round of review?
@nvpohanh Thanks for the ping; I addressed everything. If there is anything else I can do, please let me know.
@Fridge003 Could you review again? Thanks!
/tag-and-rerun-ci
Closes #20842
Motivation
When multiple DP ranks on the same node load the same checkpoint via mmap (e.g. DP-attention), each rank independently page-faults all safetensors files over NFS/Lustre. With N ranks this causes N × checkpoint_size bytes of redundant network I/O.
Note: this is orthogonal to #20332, which addresses the per-rank access pattern problem with TP (striped reads within safetensors). This PR addresses the cross-rank duplication problem with DP (N ranks redundantly reading the same files over NFS). The two optimizations are complementary.
Modifications
- weight_utils.py: Added `_PREFETCH_BLOCK_SIZE` module constant, `_prefetch_checkpoint_file()` (reads a file in blocks to warm page cache), and `_prefetch_all_checkpoints()` (distributes files across ranks via `sorted_files[rank::world_size]`, runs in a background daemon thread with configurable thread count). Added `prefetch` and `prefetch_num_threads` parameters to `safetensors_weights_iterator()` and `buffered_multi_thread_safetensors_weights_iterator()` (a usage sketch follows below).
- server_args.py: Added `--weight-loader-prefetch-checkpoints` flag and `--weight-loader-prefetch-num-threads` (default: 4).
- loader.py: Wired the new flags through to both weight iterator functions.
- test_prefetch_checkpoints.py: Unit tests verifying byte-level read completeness, correct rank-based file partitioning, and bit-identical weights with/without prefetch.

The prefetch runs in a background daemon thread (inspired by vllm-project/vllm#36012) that pipelines I/O with loading — loading starts immediately while the prefetch thread reads ahead into the OS page cache. Each rank prefetches 1/Nth of the checkpoint shards, reducing total NFS I/O from N × checkpoint_size to 1 × checkpoint_size.
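A hedged usage sketch of the extended iterator described above; the new parameter names come from the Modifications list, while the import path, positional argument, and file paths are assumptions for illustration (the real call sites live in loader.py):

```python
# Illustrative only: enable the prefetch path when iterating safetensors shards.
from sglang.srt.model_loader.weight_utils import safetensors_weights_iterator

hf_weights_files = [
    "/models/ckpt/model-00001-of-00002.safetensors",
    "/models/ckpt/model-00002-of-00002.safetensors",
]

for name, tensor in safetensors_weights_iterator(
    hf_weights_files,
    prefetch=True,           # from --weight-loader-prefetch-checkpoints
    prefetch_num_threads=4,  # from --weight-loader-prefetch-num-threads (default: 4)
):
    ...  # hand each (name, tensor) pair to the model's load_weights path
```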
Accuracy Tests
No changes to model forward code or kernels. The prefetch only affects the I/O path during weight loading — the tensors loaded are identical (verified by the unit test `test_weights_match_with_and_without_prefetch`).

Benchmarking and Profiling
DeepSeek R1-671B FP8 on Lustre, multiple platforms and configurations:
Checklist
Review Process
/tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci