Skip to content

[Dev] Paged Stashing#2690

Merged
hxbai merged 74 commits into
NVIDIA:devfrom
vasunvidia:paged_offloading
Apr 13, 2026
Merged

[Dev] Paged Stashing#2690
hxbai merged 74 commits into
NVIDIA:devfrom
vasunvidia:paged_offloading

Conversation

@nanz-nv

@nanz-nv nanz-nv commented Dec 17, 2025

Copy link
Copy Markdown
Contributor

Main PR: #4247

Main contributors (Equal Contribution, sorted alphabetically): Nan Zheng (@nanz-nv), Vasudevan Rengasamy (@vasunvidia)
Other contributors (sorted alphabetically): Dennis Liu(@Victarry), Hongbin Liu(@lhb8125), Qi Zhang(@QiZhangNV), Robin Zhang(@buptzyb), Tong Liu(@Autumn1998), Zijie Yan(@yanring)

Background

In token-dropless MoE training, the number of tokens received by each expert might vary, resulting in dynamic shaped tensors. Dynamic shaped tensors are naturally supported by PyTorch, thanks to its eager mode nature. This is done by creating a tensor lazily when the shape of the tensor is known at run-time. Albeit working well in eager mode, dynamic shaped tensor poses challenges for CUDA graphs because the the size of a tensor cannot be dynamically adjusted at runtime without the intervene of the host. In order to remove the sync and enable CUDA graph, one solution is to oversize the buffer in the expert part. This however causes significantly higher memory consumption compared to the eager-mode baseline through the form of memory fragmentation.

image

Idea overview

To address this problem, paged stashing decouples the need of oversized buffers for compute and the need of a properly sized buffer for storing activations for the backward pass. Paged stashing achieves this through adding one level of indirection: stashing and restoring. The stash operation copies the activation from the oversized static buffer to a pre-allocated stashing buffer after the forward for that module is done, and the restore operation does the reverse operation during the backward pass.

image

The key of saving memory lies in the fact that the stash operation packs the variable-size activation into a contiguous stashing buffer to reduce memory fragmentation. For simple scheduling where the activation allocation and deallocation follows a first-in-last-out pattern, stash and restore can be done easily in a bump-allocation manner. To accommodate complicated scheduling schedules, e.g. pipeline parallel, paging can be used, hence the name paged stashing.

page management

To accomodate complex scheduling such as that needed in pipeline parallelism, activations are partitioned into pages and a light-weight memory management kernel is in charge of allocate and deallocate pages for stashing. Pages are managed by lightweight GPU memory management kernels that can be fused with the stash/restore GPU kernels. It maintains a freelist which is implemented as a circular buffer. Each freelist keeps track of one type of pages.

CPU offloading

Paged stashing naturally supports offloading. When the stashing buffer is a pinned CPU tensor, the activation is offloaded to the host memory during forward and is reloaded to the GPU during backward.
Furthermore, one can easily extend the paging management system to accommodate partial offloading or on-demand offloading. This feature is currently WIP.

scheduling

Overlapping stashing and restore operations with compute can be implemented by inserting two autograd functions before and after the expert compute layer: pre-scheduler and post-scheduler that schedules stash and restore operations. The roles of these autograd functions are enumerated below:

  • Pre-scheduler forward: Wait for previous stash op. to complete, free the max-capacity sized temporary activations for the completed stash op. The wait is performed here instead of Post-scheduler forward to reduce the peak memory usage since the following expert compute layer will allocate another set of max-capacity sized temporary activations.
  • Post-scheduler forward: Since this is after experts compute, stashing operations for the current layer activations are scheduled here. If the next layer in the execution is a backward pass layer, schedule restore operations for the next layer.
    Additionally, in case of pipeline parallelism, this can be used to record the pipeline schedule during the first iteration.
  • Post-scheduler backward: Wait for previous stash op. to complete, free the max-capacity sized temporary activations for the completed stash op. The wait is performed here instead of Pre-scheduler backward to reduce the peak memory usage since the following expert compute BPROP layer will allocate another set of max-capacity sized temporary activations.
    Wait for restore operation for the current layer to complete. Additionally, in case of pipeline parallelism, this can be used to record the pipeline schedule during the first iteration.
  • Pre-scheduler backward: If the next layer in the execution is a backward pass layer, schedule restore operations for the next layer.

@copy-pr-bot

copy-pr-bot Bot commented Dec 17, 2025

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@Victarry

Copy link
Copy Markdown
Contributor

/ok to test 3e8c042

@github-actions

Copy link
Copy Markdown
Contributor

Thank you for your contribution!

NVIDIA Megatron-LM is currently transitioning to development on Github. We will aim to review your PR after we complete our transition and stabilize our Github development process.

Thank you for your understanding.

Comment thread megatron/core/transformer/moe/experts.py Outdated
Comment thread megatron/core/transformer/moe/experts.py Outdated
Comment thread megatron/core/transformer/transformer_config.py Outdated
Comment thread megatron/training/arguments.py Outdated
Comment thread megatron/training/utils.py Outdated
Comment thread megatron/core/transformer/transformer_config.py Outdated
Comment thread megatron/core/transformer/moe/experts.py Outdated
Comment thread megatron/core/transformer/moe/token_dispatcher.py
Comment thread megatron/training/utils.py Outdated
Comment thread megatron/core/transformer/moe/paged_stash.py Outdated
@nanz-nv nanz-nv force-pushed the paged_offloading branch 3 times, most recently from 3cd7a47 to b5b19b0 Compare March 23, 2026 05:46
@yanring yanring requested a review from buptzyb March 24, 2026 07:20
Comment thread megatron/training/arguments.py Outdated
@hxbai

hxbai commented Apr 13, 2026

Copy link
Copy Markdown
Contributor

/ok to test 1c9b6aa

@nanz-nv

nanz-nv commented Apr 13, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 807f963

@hxbai hxbai added this pull request to the merge queue Apr 13, 2026
@svcnvidia-nemo-ci

Copy link
Copy Markdown

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/24329522123

Merged via the queue into NVIDIA:dev with commit c0c4fdc Apr 13, 2026
62 checks passed
cursor Bot pushed a commit to AMD-AGI/Primus that referenced this pull request Apr 17, 2026
…dule

This change completes the paged-stashing feature work:

1. Ports NVIDIA/Megatron-LM PR #2690 (Paged Stashing) as a runtime patch so
   that the feature can be enabled on top of stock Megatron-LM without
   keeping a fork of third_party/Megatron-LM.

   - Adds primus/backends/megatron/core/transformer/moe/paged_stash.py, a
     copy of the upstream paged_stash.py module that is installed into
     megatron.core.transformer.moe.paged_stash at runtime.
   - Adds primus/backends/megatron/patches/moe_patches/paged_stash_patches.py,
     a before_train patch (gated on --moe_paged_stash) which wires all the
     integration points: TransformerConfig fields, FullCudaGraphWrapper
     extensions, _HybridEPManager over-budget tracking, MoEFlexTokenDispatcher
     check/reset_over_budget helpers, TEGroupedMLP.forward paged-stash
     context, pipeline-schedule paged_stash_reset calls,
     GPTModel.preprocess_for_paged_stash, and PagedStashRunner injection via
     get_forward_backward_func.

2. Resets third_party/Megatron-LM to d3528a2 (Primus main baseline) and
   redirects the submodule back to NVIDIA/Megatron-LM upstream, matching the
   Primus main branch configuration.

Reference: NVIDIA/Megatron-LM#2690

Co-authored-by: zhenhuang12 <zhenhuang12@users.noreply.github.com>
nanz-nv added a commit to vasunvidia/Megatron-LM that referenced this pull request May 19, 2026
Co-authored-by: Qi Zhang <qizhang@nvidia.com>
Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: a <a>
Co-authored-by: tongliu <tongliu@nvidia.com>
nanz-nv added a commit to vasunvidia/Megatron-LM that referenced this pull request May 20, 2026
Co-authored-by: Qi Zhang <qizhang@nvidia.com>
Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: a <a>
Co-authored-by: tongliu <tongliu@nvidia.com>
@jiemingz jiemingz mentioned this pull request Jun 4, 2026
6 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.