
feat: NCCL Xfer based refit integration#2413

Draft
youngeunkwon0405 wants to merge 29 commits into main from youngeunk/new-refit-integration


@youngeunkwon0405

@youngeunkwon0405 youngeunkwon0405 commented May 5, 2026

Contributor

What does this PR do?

Add a one line overview of what this PR aims to accomplish.

Current test results

| Test | Backend | Mapping | Refit time | KL range |
|---|---|---|---|---|
| 4b | Megatron → vLLM | DP8 → DP8 | 0.07s–0.08s | 0.0004–0.0009 |
| 4b | Megatron → vLLM | DP8 → TP8 | 0.07s–0.08s | 0.0005–0.0007 |
| 4b | Megatron → vLLM | PP2×DP4 → TP4×DP2 | 0.09s | 0.0004–0.0007 |
| 4b | Megatron → vLLM | PP4×DP2 → TP4×DP2 | 0.13s–0.15s | 0.0004–0.0007 |
| 4b | Megatron → vLLM | TP2×PP2×DP2 → TP4×DP2 | 0.13s | 0.0005–0.0012 |
| 4b | Megatron → vLLM | TP2×DP4 → TP4×DP2 | 0.11s–0.15s | 0.0005–0.0007 |
| 30b | Megatron → vLLM | EP4×TP2×DP2 → TP4×DP4 | 0.54s–0.67s | 0.0016–0.0031 |
| 30b | Megatron → vLLM | EP4×TP2×PP2 → TP4×DP4 | 0.48s–0.52s | 0.0017–0.0033 |
| 30b | Megatron → vLLM | EP8×PP2 → TP2×DP8 | 0.51s–0.55s | 0.0016–0.0029 |
| 30b | Megatron → vLLM | EP8×PP2 → TP8×DP2 | 0.50s–0.51s | 0.0016–0.0031 |
| nano-v3 | Megatron → vLLM | EP8×TP2 → TP4×DP4 | 0.48s–0.55s | 0.0011–0.0029 |
| super-v3 | Megatron → vLLM | EP8×TP8×PP2×DP2 → TP8×DP2 | 1.66s–1.94s | 0.0042–0.0081 |
| 235b | Megatron → vLLM | TP4×PP8×EP16 → TP8×DP16 | 4.56s–4.85s | 0.0041–0.0076 |
| dsv3 | Megatron → vLLM | PP16×EP16 → TP32×DP8 | 7.65s–9.94s | 0.0014–0.0024 |
| 4b (FP8) | Megatron → vLLM | DP8 → DP8 | 0.06s–0.07s | 0.0028–0.0045 |
| 4b (FP8) | Megatron → vLLM | DP8 → TP4×DP2 | 0.05s–0.10s | 0.0029–0.0046 |
| 4b (FP8) | Megatron → vLLM | PP2×DP4 → TP4×DP2 | 0.06s–0.10s | 0.0029–0.0042 |
| 4b (FP8) | Megatron → vLLM | PP4×DP2 → TP4×DP2 | 0.08s–0.10s | 0.0031–0.0041 |
| 4b (FP8) | Megatron → vLLM | TP2×DP4 → TP4×DP2 | 0.09s–0.15s | 0.0030–0.0050 |
| 30b (FP8) | Megatron → vLLM | EP4×TP2×PP2 → TP2×DP8 | 1.54s–1.69s | 0.0056–0.0107 |
| 30b (FP8) | Megatron → vLLM | EP8×PP2 → TP2×DP8 | 1.34s–1.60s | 0.0056–0.0101 |
| 30b (FP8) | Megatron → vLLM | EP8×DP2 → TP2×DP8 | 1.21s–1.62s | 0.0054–0.0096 |

Issues

List issues that this PR closes (syntax):

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>

fix DTensor nccl_reshard_refit dtype mismatch
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>

code refactoring

Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>

code refactoring

Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
@youngeunkwon0405 youngeunkwon0405 requested review from a team as code owners May 5, 2026 20:19
@copy-pr-bot

copy-pr-bot Bot commented May 5, 2026


This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@youngeunkwon0405 youngeunkwon0405 marked this pull request as draft May 5, 2026 20:19
Build the fused w13_weight in train-side _fuse_expert_params with a
TP-aware permute so contiguous Shard(1) slicing yields vLLM's expected
per-rank [half_gate, half_up] layout (per fused_moe/layer.py::_load_w13).

Drops 30B EP8PP2 refit Generation KL from ~5.2/5.7 to 0.0018, matching
the dense baseline (~0.0004).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
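
The layout argument in the commit above can be illustrated with a small pure-Python sketch. It assumes the fused gate/up dimension is sharded contiguously across `tp` ranks and models rows as strings; `tp_aware_fuse` and `contiguous_shard` are hypothetical names for illustration, not the PR's `_fuse_expert_params`.

```python
def tp_aware_fuse(gate_rows, up_rows, tp):
    """Interleave per-rank gate/up blocks so a contiguous shard is [gate_r, up_r]."""
    assert len(gate_rows) == len(up_rows) and len(gate_rows) % tp == 0
    per_rank = len(gate_rows) // tp
    fused = []
    for rank in range(tp):
        block = slice(rank * per_rank, (rank + 1) * per_rank)
        fused.extend(gate_rows[block])  # this rank's half of gate
        fused.extend(up_rows[block])    # this rank's half of up
    return fused


def contiguous_shard(rows, rank, tp):
    """Equal-sized contiguous slicing along the fused dimension."""
    size = len(rows) // tp
    return rows[rank * size:(rank + 1) * size]


gate = ["g0", "g1", "g2", "g3"]
up = ["u0", "u1", "u2", "u3"]

fused = tp_aware_fuse(gate, up, tp=2)
shard0 = contiguous_shard(fused, 0, 2)      # gate half then up half for rank 0
shard1 = contiguous_shard(fused, 1, 2)      # gate half then up half for rank 1
naive0 = contiguous_shard(gate + up, 0, 2)  # plain concat: rank 0 gets gate only
```

With a plain `gate + up` concatenation, rank 0 would receive only gate rows, which is the kind of layout mismatch the commit describes fixing.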
@youngeunkwon0405

Contributor Author

/claude review

Comment thread nemo_rl/distributed/nccl_reshard_utils.py Outdated
Comment thread nemo_rl/algorithms/grpo.py Outdated
@claude

claude Bot commented May 6, 2026


Review notes

Missing exemplar YAML update: The new nccl_reshard_refit config key was added to PolicyConfig (TypedDict) but is not reflected in any exemplar YAML under examples/configs/*.yaml. Per config conventions, new config keys should be documented and defaulted in the exemplar configs.

Missing test coverage: nemo_rl/distributed/nccl_reshard_utils.py is a 739-line new file with significant pure logic (placement rules, shard-slice computation, MoE expert fusion, mesh construction) that is highly amenable to unit testing. Consider adding tests for at least get_tp_shard_dim, _compute_shard_slices, fuse_expert_params_in_metadata, and build_nccl_reshard_refit_info.

youngeunkwon0405 and others added 3 commits May 5, 2026 18:54
…ridge

Wires up FP8→FP8 weight transfer through Bridge's `build_export_fp8_tasks`
so the train side ships TE Float8BlockwiseQTensor data + scale_inv pairs
without recreating ~1000 lines of TE-coupled extraction logic ourselves.

Train side (megatron_policy_worker.py):
- `_is_fp8_export()` returns True iff fp8_cfg has fp8_param=true and
  fp8_recipe="blockwise" (the only Megatron config TE produces FP8
  storage for).
- `_build_refit_conversion_tasks()` routes between
  `bridge.get_conversion_tasks` (BF16 / fp8_param=false) and
  `bridge._model_bridge.build_export_fp8_tasks` (fp8_param=true). The
  FP8 path emits paired (FP8 weight, scale_inv) tasks.
- `_iter_local_hf_params` recognises Bridge's `_HFNameSuffixMapping`
  scale_inv tasks. For compound mappings the scale tensor is split the
  same way the weight is:
  * QKVMapping → reuse `split_qkv_weights` (Bridge auto-detects
    scale-domain trailing-dim block compression)
  * GatedMLPMapping → chunk by 2 along dim 0 (gate / up halves)
  * simple mapping → yield with `_scale_inv` HF-name suffix
- `_calculate_refit_param_info`: extend `prec_to_bytes` with
  `float8_e4m3fn`, `float8_e5m2`, `uint8` so the size accounting works
  on FP8 weights.

Gen side (vllm_backend.py):
- `_build_hf_to_vllm_mapping` MERGE_RULES extended with `*_scale_inv`
  variants so HF q/k/v_proj.weight_scale_inv merge into vLLM
  qkv_proj.weight_scale_inv (same for gate/up). Direct-match
  scale_inv (e.g. o_proj.weight_scale_inv) flows through the existing
  1:1 path. Existing `_compute_tp_local_slice` handles dim-0 sharding
  for the blockwise scale shape `[output//block, input//block]`.

Test scripts (script/new_refit/):
- `4b_fp8_to_fp8_dp_dp.sh` — Qwen3-4B dense, fp8_param=true train →
  vLLM precision=fp8 use_deep_gemm=true. NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1
  set as direct shell export (Hydra rejects new env_vars keys without
  `+` and rejects non-string values without escape engineering that
  doesn't survive bash double-quoting).
- `4b_bf16_to_fp8_dp_dp.sh` — BF16 train → vLLM FP8. Documented as
  failing today: needs gen-side quantization in finalize.

Verification (FP8→FP8, job 11586681): 10/10 GRPO steps complete,
initial refit + per-step refits all succeed, Generation KL
0.0029-0.0045 (BF16 baseline ~0.0004; ~10× higher is consistent with
FP8 quantization noise), weight_sync ~0.07s/step, clean exit.

Existing BF16 path verified non-regressive (4b TP2→TP4 KL 0.0004).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
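
A rough sketch of the size accounting this commit describes: one byte per element for the FP8 and `uint8` dtypes, plus the blockwise scale shape `[output//block, input//block]`. The table contents, the helper names, the example weight shape, the block size of 128, and the FP32 scale dtype are all assumptions for illustration, not values taken from the PR.

```python
from math import prod

# Hypothetical bytes-per-element table, extended with the FP8 / uint8 entries.
PREC_TO_BYTES = {
    "float32": 4,
    "bfloat16": 2,
    "float16": 2,
    "float8_e4m3fn": 1,
    "float8_e5m2": 1,
    "uint8": 1,
}


def param_nbytes(shape, dtype_name):
    """Bytes needed for a tensor of the given shape and dtype name."""
    return prod(shape) * PREC_TO_BYTES[dtype_name]


def blockwise_scale_shape(out_features, in_features, block=128):
    """Shape of the scale_inv sibling, assuming dims divisible by the block size."""
    return (out_features // block, in_features // block)


weight_shape = (4096, 2560)  # illustrative linear weight
fp8_bytes = param_nbytes(weight_shape, "float8_e4m3fn")
bf16_bytes = param_nbytes(weight_shape, "bfloat16")
scale_shape = blockwise_scale_shape(*weight_shape)
scale_bytes = param_nbytes(scale_shape, "float32")
```

Under these assumptions the FP8 weight is half the BF16 size and the scale tensor is small by comparison, which is why the weight and its `scale_inv` sibling can be accounted for as a pair.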
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
- get_placements: skip TP shard for expert params when EP>1 (current
  scope: expert_tensor_parallel_size=1 only); use len(dim_map) for
  mesh dim count.
- Remove vestigial fields _moe_fused, _moe_num_experts_local,
  _moe_gate_entries, _moe_up_entries, _moe_down_entries; replace with a
  single _moe_kind: "w13" | "w2" discriminator.
- Drop metadata_ep_gathered flag and ep_size parameter; the EP-gathered
  invariant on input metadata is now asserted in the docstring (Megatron
  meets it via export_hf_weights' EP all-gather; DTensor doesn't use EP).

Net diff: -64 / +34 lines.

Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
@youngeunkwon0405 youngeunkwon0405 force-pushed the youngeunk/new-refit-integration branch from dad5cb6 to a543fb2 Compare May 11, 2026 02:37
youngeunkwon0405 and others added 4 commits May 11, 2026 14:39
Makes the nccl_reshard-based refit path work correctly and within memory
for Qwen3-MoE training configs that combine expert and tensor parallelism
(e.g. EP4×TP2).  Three independent issues:

1. OOM during metadata enumeration.  Bridge's export_hf_weights does
   PP/TP/EP gathers to materialize full unsharded tensors, but the
   refit-info builders only need shape+dtype.  Wrap the enumeration in
   a context that redirects empty_like/zeros_like to "meta" and turns
   the collectives into no-ops; peak memory stays at zero extra GiB.
   Lifted to a module-level helper so prepare_refit_info (BCAST path)
   and prepare_nccl_reshard_refit_info share it.

2. OOM during expert fusion.  Previously all layers' fused w13/w2
   tensors were materialized upfront (~14 GiB peak on 30B).  Rebuild
   one fused tensor at a time inside the transfer loop and free
   immediately after; peak drops to one fused param.

3. KL divergence on ETP=1 EP×TP configs.  Megatron-Core with
   expert_tensor_parallel_size=1 gives non-expert and expert params
   *different* rank-to-coord layouts on the same physical ranks
   (e.g. EP4×TP2 has tp=r%2 AND ep=r%4 on the same rank — not a
   single product structure).  build_nccl_reshard_refit_info now
   constructs two src meshes per PP stage — one for non-expert params
   (tp×dp) and one for expert params (ep×edp) — and selects per-param
   based on is_expert_param.

Verified on the full 11-config test matrix.  KL on previously-broken
configs:
- 30b EP4TP2→TP4:    1.46 → 0.002
- 30b EP4TP2PP2→TP4: 1.39 → 0.002
Other configs (1b/4b/30b EP8) unchanged within float tolerance.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
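
Issue 3 above can be checked with a few lines of arithmetic. The sketch below assumes 8 ranks and takes `tp = r % 2`, `ep = r % 4` from the commit message; the `dp = r // 2` and `edp = r // 4` coordinates are an assumption about how the remaining mesh dimension is laid out.

```python
WORLD = 8  # assumed world size for EP4 x TP2
TP, EP = 2, 4


def non_expert_coord(rank):
    """Coordinates on the (tp x dp) mesh used for non-expert params."""
    return (rank % TP, rank // TP)


def expert_coord(rank):
    """Coordinates on the (ep x edp) mesh used for expert params."""
    return (rank % EP, rank // EP)


# Each mesh on its own addresses every rank uniquely.
non_expert_coords = {non_expert_coord(r) for r in range(WORLD)}
expert_coords = {expert_coord(r) for r in range(WORLD)}

# But the (tp, ep) pair does not: a single tp x ep product mesh would need
# TP * EP distinct pairs, and only half of them ever occur.
tp_ep_pairs = {(r % TP, r % EP) for r in range(WORLD)}
```

Because `tp` and `ep` are both derived from the low bits of the rank, they are correlated, so one product mesh cannot describe both layouts. That is consistent with the commit's choice of two source meshes per PP stage.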
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
These per-user SLURM submission scripts are local conveniences and shouldn't
live in the repo.  Keep them on disk (still gitignore-able) but stop tracking.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
`_param_map` (HF-name → Megatron-param view), `_expert_groups` (regex
grouping of expert params by layer/proj), and `_hf_to_vllm` (HF→vLLM
merge-rule mapping) all depend only on model topology, not weight
values.  Build them once in the respective prepare_nccl_reshard_refit_info
calls and reuse for every refit step.  Saves one full Megatron-param
iterator pass + one regex scan + one vLLM named_parameters walk per
refit step.

Also: dedup the per-expert regex (reuse _INDIVIDUAL_EXPERT_RE from
nccl_reshard_utils), drop leftover bringup print()s in
init_pp_comm_groups, and remove a WHAT-narrating comment.

Verified on the full L1 matrix (11 small + dsv3 + qwen3-235b); KL
values unchanged within run-to-run variance.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
@youngeunkwon0405 youngeunkwon0405 force-pushed the youngeunk/new-refit-integration branch from e8a3919 to 96613c3 Compare May 11, 2026 21:43
@youngeunkwon0405 youngeunkwon0405 self-assigned this May 14, 2026
@youngeunkwon0405 youngeunkwon0405 force-pushed the youngeunk/new-refit-integration branch from 946394e to 181457c Compare June 3, 2026 17:13
@youngeunkwon0405

Contributor Author

/claude review

Comment thread nemo_rl/models/generation/vllm/vllm_backend.py
Comment thread nemo_rl/distributed/nccl_xfer_utils.py
Comment thread nemo_rl/models/megatron/data.py
youngeunkwon0405 and others added 8 commits June 9, 2026 22:49
…c paths

Extend the nccl_reshard (xferdtensor) disaggregated refit to FP8 weights
with a hybrid transfer:

- Bulk weights (2D linears, MoE experts) take the fast xferdtensor
  broadcast path, broadcasting every tensor as a uint8 byte-view (nccl4py
  cannot broadcast float8 dtypes natively; a uint8 reinterpret is
  wire-identical for any dtype).
- Misc params go through vLLM's load_weights via packed_broadcast: FP8
  blockwise *_scale_inv siblings, FP8 KV/activation scales, and the MoE
  router.

Routing the MoE router (HF mlp.gate.weight / Megatron mlp.router.weight)
to misc fixes an FP8 refit deadlock: Megatron keeps the router in bf16 but
vLLM blockwise-FP8 quantizes it, so a direct xferdtensor broadcast
mismatched byte counts (bf16 send vs fp8 recv) and hung the collective.
load_weights quantizes it correctly.

Verified on the L0/L1 matrix: BF16 L0 (4b/30b across DP/TP/PP/EP) and FP8
L0 dense 4b complete 10/10 steps; FP8 MoE 30b reaches parity with the
BCAST baseline (refit + generation succeed; the remaining get_logprobs TE
error is a pre-existing upstream FP8 issue); BF16 L1 235B and DSv3 complete
10/10. Gen KL stays in the 4e-4..5e-3 range across all configs.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
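
The two claims in this commit (a byte view is wire-identical, and a bf16 send against an fp8 receive mismatches byte counts) can be shown with the standard library alone. This is only an analogy: `float32` stands in for the real dtypes because Python's stdlib has no bf16 or fp8 type, and the router element count is an arbitrary illustrative number.

```python
from array import array

# Reinterpret a typed buffer as raw bytes, "send" it, and reinterpret it back.
values = array("f", [1.5, -2.0, 3.25])
byte_view = memoryview(values).cast("B")  # same memory, viewed as uint8
wire = bytes(byte_view)

received = array("f")
received.frombytes(wire)  # receiver restores the original dtype


def payload_bytes(numel, bytes_per_elem):
    """Size of a broadcast payload for a given element width."""
    return numel * bytes_per_elem


router_numel = 1024                           # illustrative element count
send_bytes = payload_bytes(router_numel, 2)   # train side keeps bf16
recv_bytes = payload_bytes(router_numel, 1)   # gen side expects fp8
```

The round trip is lossless because no values are converted, only relabeled. The second half shows why the router had to leave the bulk path: sender and receiver disagree on payload size, which a collective cannot reconcile.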
Add disaggregated weight-refit support for NemotronH models (e.g.
Nemotron-3-Nano-30B-A3B): hybrid Mamba2 / attention / MoE-FFN.

Key fix: non-gated (ReLU^2) MoE experts silently dropped their
up-projection. NemotronH MoE has up_proj + down_proj (no gate_proj), and
vLLM builds it as SharedFusedMoE(is_act_and_mul=False) with
w13_weight = [E, intermediate, hidden] (up only). fuse_expert_params_in_metadata
did `if not gate_entries: continue`, so it never created w13 and dropped
up_proj -> vLLM's w13_weight stayed stale -> every MoE layer half-corrupt
-> 100% importance-sampling masking, reward ~0. Now build w13 from up_proj
alone when there is no gate, on both the metadata and value-fusion sides.
The gated SwiGLU path is unchanged (if gate_entries / elif up_entries).

Also:
- is_misc_param: route NemotronH Mamba mixer params (in_proj, conv1d,
  A_log, D, dt_bias, norm) and the MoE router (mixer.gate.weight) to the
  misc/load_weights path; out_proj stays on the bulk xferdtensor path.
- placement: out_proj -> row-parallel; NemotronH token embedding ->
  vocab-parallel.
- vllm_backend: translate backbone.* -> model.* and embeddings ->
  embed_tokens at vLLM-param lookup so bulk mapping resolves.

Validated at baseline parity (reward + KL match the broadcast refit) on
NemotronH (non-gated) and Qwen3-30B-A3B (gated SwiGLU regression).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
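
The `if gate_entries / elif up_entries` branch described above might look roughly like the following. This is a simplified sketch that models each expert's rows as a list of strings and ignores TP-aware ordering; `build_w13` is a hypothetical name, not the PR's function.

```python
def build_w13(gate_entries, up_entries):
    """Fuse per-expert gate/up rows into w13; fall back to up-only when non-gated."""
    if gate_entries:
        # Gated (SwiGLU-style): each expert's w13 is gate rows followed by up rows.
        assert len(gate_entries) == len(up_entries)
        return [gate + up for gate, up in zip(gate_entries, up_entries)]
    elif up_entries:
        # Non-gated: w13 holds the up-projection only.
        return [list(up) for up in up_entries]
    return None  # neither projection present: nothing to build


gated = build_w13([["g0a", "g0b"], ["g1a", "g1b"]],
                  [["u0a", "u0b"], ["u1a", "u1b"]])
non_gated = build_w13([], [["u0a", "u0b"], ["u1a", "u1b"]])
```

The earlier `if not gate_entries: continue` behavior corresponds to returning `None` for the non-gated case, which is how the up-projection was dropped.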
Three related refactors to the nccl_reshard refit path:

1. Tag MoE expert fusion by role + explicit gated flag.
   Replace the overloaded fused_expert_param_type "w13"/"w2" with the
   semantic role "up_proj"/"down_proj" plus a gated:bool (single source of
   truth for the up-projection layout; the value-side branch asserts it
   against gate_proj presence).  Fused param names stay w13_weight/w2_weight
   (vLLM's FusedMoE convention).

2. Move DTensorRef to xferdtensor.py; rename full_tensor() -> local_tensor().
   DTensorRef is the wrapper xferdtensor reads as its src/dst, so it belongs
   next to the kernel.  full_tensor() only ever returned the local shard and
   had no callers (the kernel reads the ._local_tensor attribute).  MeshInfo
   stays in nccl_reshard_utils.

3. Drop the DTensor train backend (Megatron + vLLM only for this initial
   version).  check_nccl_reshard_refit_support states the Megatron-only scope
   intentionally; the grpo.py refit setup collapses to the Megatron branch;
   the unreachable DTensor-worker nccl-refit methods are removed.

Validated at parity across the L0 matrix (4B + 30B parallelism shapes and
NemotronH) on both gated (Qwen3-30B-A3B) and non-gated (NemotronH) MoE.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
Two NemotronH-specific gaps blocked the disaggregated refit of
NVIDIA-Nemotron-3-Super-120B-A12B (Latent-MoE, GQA num_kv=2) at
TP8 EP8 PP4; both are now fixed and validated end-to-end (refit at
reward ~0.50-0.52 / gen-KL ~0.004, parity with the broadcast baseline).

1. GQA QKV at tp_size > num_query_groups.  The bulk (xferdtensor) path
   splits the fused QKV on each TP-local shard, which needs every rank
   to own whole K/V heads.  When TP exceeds the KV-head count Megatron
   channel-splits the KV heads, so no rank holds a whole one and the
   local split is impossible.  is_misc_param now takes a qkv_to_misc
   flag (set when num_query_groups < train tp): QKV then takes the misc
   path, which gathers full heads via export_hf_weights like the
   broadcast refit.  _iter_local_hf_params skips QKV in that case so the
   bulk path never attempts the impossible split.  No-op for every other
   model (num_query_groups >= tp_size keeps QKV on the bulk fast path).

2. PP-stage mapping was Llama/Qwen-only.  _LAYER_RE / _MODEL_PREFIX_RE
   and _build_layer_to_pp_stage only knew model.layers.N naming, so
   NemotronH's backbone.* params all defaulted to stage 0 and PP rank 0
   tried to transfer the whole model.  Both are now model|backbone
   naming-agnostic.

Keeps the super-v3 BF16 L1 test (config + baseline/refit scripts).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
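
A minimal sketch of the naming-agnostic layer lookup from item 2. The regex pattern and the even layers-per-stage split are assumptions (the PR's actual `_LAYER_RE` pattern and stage assignment are not shown in this thread), and the names carry a `_SKETCH` or `sketch_` marker to make that explicit.

```python
import re

# Hypothetical pattern: accept both "model.layers.N." and "backbone.layers.N.".
LAYER_RE_SKETCH = re.compile(r"^(?:model|backbone)\.layers\.(\d+)\.")


def layer_index(param_name):
    """Return the layer number in a param name, or None if it has none."""
    match = LAYER_RE_SKETCH.match(param_name)
    return int(match.group(1)) if match else None


def sketch_layer_to_pp_stage(num_layers, pp_size):
    """Map layer -> PP stage, assuming layers split evenly across stages."""
    assert num_layers % pp_size == 0
    per_stage = num_layers // pp_size
    return {layer: layer // per_stage for layer in range(num_layers)}


stage_of = sketch_layer_to_pp_stage(num_layers=8, pp_size=4)
qwen_layer = layer_index("model.layers.3.self_attn.q_proj.weight")
nemotron_layer = layer_index("backbone.layers.7.mixer.in_proj.weight")
no_layer = layer_index("lm_head.weight")
```

With a `model.`-only pattern, the `backbone.` name would return `None`, which matches the failure described above where every NemotronH param defaulted to stage 0.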
Decouple build_nccl_reshard_refit_info from vLLM-specific layout rules so
another generation backend could supply its own, and simplify the gen-side
merged-param handling.

- MoE-expert fusion is no longer hardcoded in the builder:
  fuse_expert_params_in_metadata -> vllm_fuse_expert_params_in_metadata,
  passed in as fuse_expert_param_in_metadata_fn (None = no fusion).
  VllmGeneration.expert_fusion_fn() supplies it via the driver.

- Route QKV to the misc/load_weights path whenever KV heads can't be cleanly
  1/tp-sharded on EITHER side:
  qkv_to_misc = num_query_groups < max(train_tp, gen_tp). This subsumes the
  gen-side KV-head-replication case, so the vLLM-specific KV-head slice
  mapping (_compute_tp_local_slice) is removed from the bulk path.

- get_dst_dtensor's merged-param path collapses to a single placement-honoring
  block: receive this component's slice of the merged vLLM param (Shard dst ->
  the 1/gen_tp shard; all-Replicate dst / disable_tp -> the full tensor, which
  is the slice) and copy it in. Golden and the real reshard op now agree on
  dst placements with no path-specific branch.

Validated on the BF16 L0+L1 golden matrix, 12/12: 4b dense (TP/PP/DP), 30b MoE
(EP/TP/PP incl. qkv->misc at gen TP8), 235b, and dsv3 (disable_tp).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
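
The routing rule stated in this commit is a one-line predicate. The function below just restates `num_query_groups < max(train_tp, gen_tp)` from the commit message; the example values are illustrative and not tied to a specific model.

```python
def qkv_to_misc(num_query_groups, train_tp, gen_tp):
    """True when either side's TP exceeds the KV-head count, so QKV takes the misc path."""
    return num_query_groups < max(train_tp, gen_tp)


train_side_too_wide = qkv_to_misc(num_query_groups=2, train_tp=8, gen_tp=8)
gen_side_too_wide = qkv_to_misc(num_query_groups=4, train_tp=2, gen_tp=8)
clean_on_both = qkv_to_misc(num_query_groups=8, train_tp=4, gen_tp=8)
```

The second case is the one the earlier train-only check (`num_query_groups < train tp`) would have missed.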
xferdtensor() now dispatches to the real point-to-point reshard
(nccl.xfer.api.XdtensorRedistribute) when it is available and not overridden,
otherwise to the existing broadcast-based golden reference.

- Guarded import: golden-only containers (no nccl.xfer module) auto-fall-back
  to golden; NRL_XFERDTENSOR_GOLDEN=1 forces golden on any container.
- Each rank holds only one side (train ranks own the src shard, gen ranks the
  dst shard); the op takes a one-sided None and derives the absent side from
  the present side's global shape + mesh/placements. PyTorch Shard/Replicate
  placements pass through; meshes go in as nccl.xfer Mesh via _to_xfer_mesh.

Safe to merge ahead of a real-op-capable container thanks to the guarded
import. The xfer-capable nccl4py + matching custom NCCL are supplied by the
container/deps, not by this change. The real op's ~100x per-call latency is a
known nccl_xfer-library issue (handed to the NCCL team); until it's fixed the
golden path stays the default for the async refit loop.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
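
The guarded-import dispatch could be sketched as follows. Only the module path `nccl.xfer.api` and the `NRL_XFERDTENSOR_GOLDEN` variable come from this thread; the function name, the accepted truthy strings, and the injectable `module_name` / `env` parameters (added so the logic can be exercised without the real library) are assumptions.

```python
import importlib
import os


def select_xfer_backend(module_name="nccl.xfer.api", env=None):
    """Return "real" when the xfer module imports and golden is not forced."""
    env = os.environ if env is None else env
    flag = env.get("NRL_XFERDTENSOR_GOLDEN", "").strip().lower()
    if flag in {"1", "true"}:  # assumed set of truthy values
        return "golden"
    try:
        importlib.import_module(module_name)
    except ImportError:
        # Golden-only container: the module is absent, so fall back.
        return "golden"
    return "real"


# Stand-ins: "json" always imports, the other name never does.
available = select_xfer_backend("json", env={})
missing = select_xfer_backend("no_such_module_for_this_sketch", env={})
forced = select_xfer_backend("json", env={"NRL_XFERDTENSOR_GOLDEN": "1"})
```

Checking the override before attempting the import keeps the forced-golden path independent of whatever is installed in the container.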
Validated green across the full golden L0/L1 BF16+FP8 matrix plus Nemotron
Nano-v3 (L0) and Super-v3 (L1); no behavior change on tested configs.

- vllm_backend: size merged-param slices from refit_info["gen_tp_size"]
  instead of torch.distributed.get_world_size(), which equals gen TP only
  while gen PP/EP are pinned to 1 -- robust if gen PP is ever added.
- nccl_reshard_utils: reject unsupported gen precision values up front in
  check_nccl_reshard_refit_support (allow only bf16/bfloat16/auto/fp8/unset),
  so a typo or mismatched precision fails the gate instead of deadlocking the
  bulk collective.
- contiguity guards: .view -> .reshape in the MoE w13 gate/up fusion and
  .contiguous().clone() in the golden all-Replicate fast path, so a
  non-contiguous param view can't crash the reshard.
- rename ne_mesh/ex_mesh -> non_expert_mesh/expert_mesh (and dim_maps),
  unifying three naming styles into one.
- docstrings + worked examples for _compute_shard_slices, _fuse_one_moe_param,
  _build_expert_groups, and the get_dst_dtensor unmapped-discard path
  (clarifying it is collective-symmetry participation, not stale-weight
  corruption).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
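
The up-front precision gate might be as simple as an allow-list check. The allowed values are the ones listed in this commit (bf16/bfloat16/auto/fp8/unset); the function names and the exact-match, case-sensitive comparison are assumptions.

```python
# None models the "unset" case.
ALLOWED_GEN_PRECISIONS = {None, "auto", "bf16", "bfloat16", "fp8"}


def is_supported_gen_precision(precision):
    """True if the gen-side precision is one the refit path handles."""
    return precision in ALLOWED_GEN_PRECISIONS


def check_gen_precision(precision):
    """Fail fast instead of letting a mismatch surface as a hung collective."""
    if not is_supported_gen_precision(precision):
        allowed = sorted(p for p in ALLOWED_GEN_PRECISIONS if p is not None)
        raise ValueError(
            f"unsupported generation precision {precision!r}; "
            f"expected one of {allowed} or unset"
        )


check_gen_precision("fp8")  # passes silently
typo_supported = is_supported_gen_precision("fp-8")
```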
The low-level op is finalized as "NCCL Xfer", so the refit feature is renamed
from nccl_reshard_* to nccl_xfer_* throughout:
- config knob policy.nccl_reshard_refit -> policy.nccl_xfer_refit
- prepare_/build_/check_nccl_reshard_refit_* -> nccl_xfer_* (+ _async variants)
- the nccl_reshard_refit method -> nccl_xfer_refit; nccl_reshard_refit_info /
  _enabled attrs -> nccl_xfer_*
- module nemo_rl/distributed/nccl_reshard_utils.py -> nccl_xfer_utils.py

Also rename init_pp_comm_groups -> init_per_pp_refit_comm_group (+ _async): it
sets up the per-PP-stage refit comm groups, not the PP group itself.

Dead-code cleanup folded in:
- VOCAB suffix matching uses endswith (was substring `in`)
- drop the unreachable scale_inv MERGE_RULES (scale_inv takes the misc path)
- drop the dropped-DTensor-backend to_local() else-branch and the unused
  DTensorRef.local_tensor() accessor

Behavior-preserving; validated on the golden matrix (BF16 L0 10/10, FP8 L0 9/9,
235b L1 full) -- config-knob read, module import, and method dispatch all
confirmed under the new names.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
@youngeunkwon0405 youngeunkwon0405 force-pushed the youngeunk/new-refit-integration branch from a336e84 to 84863cd Compare June 10, 2026 05:51
youngeunkwon0405 and others added 3 commits June 10, 2026 16:19
…olish

megatron_policy_worker.py:
- Rename _iter_local_hf_params -> _iter_local_hf_param_shards (it yields
  TP/EP-local shards, not gathered params) and document that for EP>1
  refit_conversion_tasks already holds only this rank's local experts.
- Collapse the dead FP8 _scale_inv splitting branch into a plain skip: all
  _scale_inv siblings are routed to the misc path by is_misc_param, so they
  never enter the bulk state_dict_metadata and their _param_map entries were
  never read. Document that refit_conversion_tasks is the FULL Bridge list and
  is filtered downstream (misc_meta / _misc_conversion_tasks), not pruned.
- Drop the dead `if self.refit_conversion_tasks is None` guard in
  prepare_nccl_xfer_refit_info (refit_param_info_mcore is always recomputed).
- Comment/docstring polish only.

nccl_xfer_utils.py:
- build_nccl_xfer_refit_info: require layer_to_pp_stage when pp_size > 1
  (assert) instead of silently disabling per-stage handling; clarify the
  state_dict_metadata key is the HF param name.

No behavior change; validated by the golden L0/L1 matrix and the 235B FP8
refit run (job 12711863) which imports exactly this code.

Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…sion gen-side)

Mirror the QKV pattern for MoE experts: the train worker and the shared
nccl_xfer_refit_info now use backend-agnostic grouped HF names
(gate_proj/up_proj/down_proj, each [E, ...]); the vLLM w13/w2 fusion moves
entirely to the gen side.  A new gen backend now needs only a gen-side mapping,
never a train-worker change.

nccl_xfer_utils.py:
- vllm_fuse_expert_params_in_metadata -> agnostic group_expert_params_in_metadata:
  stack per-expert HF params into one [E, *per_expert_shape] entry per projection,
  tagged grouped_expert_proj.  No w13/w2, no gated, no gen_tp.  The builder calls
  it unconditionally (grouping is universal); dropped the
  fuse_expert_param_in_metadata_fn injection param.
- Removed w13_weight/w2_weight from COLUMN/ROW_PARALLEL_SUFFIXES; grouped
  gate/up/down resolve via _get_expert_tp_shard_dim -> Shard(1)/Shard(2).

megatron_policy_worker.py:
- _fuse_one_moe_param -> _group_experts: plain torch.stack of the per-expert
  tensors; deleted the gate/up interleave + gen_tp_size + gated logic.
- _get_src_local_tensor dispatches on grouped_expert_proj; dropped gen_tp_size.
- prepare_nccl_xfer_refit_info: dropped the injection param.

vllm_backend.py:
- Generalized merged_param_slice to a multi-dim index tuple.
- New flag-dispatched grouped-expert branch in _build_hf_to_vllm_mapping: gate/up
  -> w13_weight dim-1 halves ([:, :P]/[:, P:2P]) when gated (has_gate pre-pass),
  up/down -> direct (whole w13 / w2).  get_dst_dtensor's existing buffer +
  post_refit_hook places the received Shard(1)/Shard(2) shard.

vllm_generation.py / grpo.py / lm_policy.py: removed expert_fusion_fn + all
fuse_expert_param_in_metadata_fn threading.

w13_weight/w2_weight no longer appear in the train worker or the shared metadata
— only in vllm_backend (where vLLM-specific layout belongs).

Validated on the golden path (NRL_XFERDTENSOR_GOLDEN=true): byte-identical by
construction (the gen-side slice reproduces the old interleave layout) + reward/KL
parity across Qwen3-30B-A3B gated (gen_tp=8), NemotronH-nano non-gated, FP8 30B
(scale_inv stays misc), and 30B PP2+TP (per-stage) — all with zero grouped-mapping
warnings and KL in the expected BF16/FP8 ranges.

Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
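
The gen-side grouped-expert mapping described above can be sketched as a function from projection name to a destination param and index tuple. Nested lists stand in for a `[E, rows]` tensor; `P` is taken to be the per-rank gate width, and the function names and the `take` helper are hypothetical.

```python
def grouped_expert_mapping(proj, gated, per_rank_width):
    """Return (vLLM param name, index tuple) for one grouped HF projection."""
    everything = slice(None)
    p = per_rank_width
    if proj == "down_proj":
        return "w2_weight", (everything,)             # whole w2
    if proj == "up_proj" and not gated:
        return "w13_weight", (everything,)            # non-gated: whole w13
    if proj == "gate_proj" and gated:
        return "w13_weight", (everything, slice(0, p))      # [:, :P]
    if proj == "up_proj" and gated:
        return "w13_weight", (everything, slice(p, 2 * p))  # [:, P:2P]
    raise ValueError(f"no mapping for {proj!r} (gated={gated})")


def take(tensor, index):
    """Apply an (expert_slice[, row_slice]) index tuple to nested lists."""
    row_slice = index[1] if len(index) > 1 else slice(None)
    return [expert[row_slice] for expert in tensor[index[0]]]


w13 = [["g0", "g1", "u0", "u1"], ["G0", "G1", "U0", "U1"]]
_, gate_index = grouped_expert_mapping("gate_proj", gated=True, per_rank_width=2)
_, up_index = grouped_expert_mapping("up_proj", gated=True, per_rank_width=2)
gate_part = take(w13, gate_index)
up_part = take(w13, up_index)
```

Raising on an unknown combination, rather than returning nothing, keeps a missing rule from turning into silently stale weights.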
…Store workaround)

FP8 KV cache now works on the nccl_xfer refit path: the per-layer k/v(/q)
scales ride the existing misc packed-broadcast as plain scale tensors, instead
of an out-of-band TCPStore scalar channel.  This works because is_misc_param
already routes .k_scale/.v_scale/.q_scale to the misc path and their names are
already in misc_meta (built from _iter_params_with_optional_kv_scales at prepare
time), so producer/gen agree and the scales broadcast like any other misc param.

Train (megatron_policy_worker.py):
- nccl_xfer_refit: drop the NotImplementedError on kv_scales; forward kv_scales
  to _broadcast_misc_params_packed.
- _broadcast_misc_params_packed(kv_scales): forward to
  _iter_params_with_optional_kv_scales(kv_scales=...) so the real calibrated
  scales (grpo's calibrate_qkv_fp8_scales) broadcast instead of the 1.0
  placeholder.

Gen (vllm_backend.py):
- nccl_xfer_refit: after the misc receive, call _maybe_process_fp8_kv_cache()
  (guarded on kv_cache_dtype=fp8) to finalize the per-layer scales — exactly the
  process_weights_after_loading step the prior comment said this work would need.

Test: script/new_refit/fp8/4b_tp2_tp4_fp8kv.sh (FP8 4b TP2DP4->TP4DP2 +
kv_cache_dtype=fp8).  Note FP8 KV cache requires SYNC rollouts
(async_grpo=false, async_engine=false) per grpo.py's guard; disaggregated +
sync is valid (only the inverse — async requires non-colocated — is enforced).

Validated golden (job 12718430): refit reached, recompute_kv_scales runs each
step, Gen KL 0.0086-0.0122 (sane FP8-KV range, ~3x FP8-weights-only as expected
from KV quantization), no grouped-mapping / kv-scale / process_weights errors.
Also confirms sync-mode non-colocated nccl_xfer refit works.

Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@youngeunkwon0405 youngeunkwon0405 added the CI:L1 Run doctests, unit tests, and functional tests label Jun 11, 2026
@youngeunkwon0405

Contributor Author

/okay to test f3834e0

Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
@youngeunkwon0405

Contributor Author

/okay to test 46d09b7

… trim

- _build_expert_groups resolves each per-expert name to its _param_map view
  once at prepare time (index-sorted), so _group_experts only torch.stacks at
  refit time — the hot path no longer does per-expert dict lookups.
- Replace silent fallbacks with loud asserts: _group_experts asserts its expert
  group is non-empty, and nccl_xfer_refit asserts the per-param local tensor is
  present (a None means a layer_to_pp_stage / metadata inconsistency that would
  otherwise surface as a gen-side hang, not a clear error).
- Add [xferd-payload] byte-accounting print (xfer-major vs misc bytes per refit).
- xferdtensor.py: inline `# pyrefly: ignore[...]` for the optional nccl.xfer
  import path (import-error / not-callable / bad-argument-type).
- Comment/docstring cleanup across the refit path, incl. the MLA down-proj
  replication note (TELinear duplicated -> replicated at any TP for the TE spec;
  warns about the local-spec / mla_down_proj_fusion sharding cases).

Behavior is byte-identical; validated on the golden L0 matrix (4b DP/TP/PP x6,
30b EP/PP x4) all 10/10, KL 0.0004-0.0019, asserts never fired.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
@youngeunkwon0405 youngeunkwon0405 force-pushed the youngeunk/new-refit-integration branch from 3708bea to 3288f55 Compare June 12, 2026 03:46
@youngeunkwon0405 youngeunkwon0405 changed the title from "[WIP] New refit integration branch" to "feat: NCCL Xfer based refit integration" Jun 12, 2026
…fail-safe)

Invert the bulk/misc classifier from a blacklist (is_misc_param: route named
exceptions to misc, everything else bulk) to a whitelist (is_nccl_xfer_param:
bulk only for params with a known TP/EP shard rule or a grouped MoE expert
projection; everything else -> misc).  An unrecognized param (new model / layer
type) now routes to the misc/load_weights path — correct-but-slower, same
coverage as the conventional broadcast refit — instead of silently taking the
bulk path with a possibly-wrong uniform column/row/replicate placement.

- nccl_xfer_utils: is_misc_param -> is_nccl_xfer_param (whitelist: qkv carve-out
  + exact expert regex + .experts. guard + explicit column/row/vocab suffix loop).
- megatron_policy_worker: flip the state_dict_metadata/misc_meta split.  Derive
  _misc_conversion_tasks from misc_meta (the HF-name split) rather than
  re-classifying the Megatron global_param_name — those are different namespaces,
  and classifying Megatron compound names (linear_qkv / linear_fc1 / linear_proj)
  with the HF-suffix whitelist mis-routes bulk tasks into misc and desyncs the
  packed_broadcast.
- vllm_backend + comments: reference the new classifier; fix a stale
  "embeddings -> misc" note (embeddings are vocab-parallel bulk).

Validated golden: 4b_dp_dp / 30b_ep8pp2_tp8 (qkv_to_misc=True) / nemotron-nano
(non-gated MoE + Mamba) / fp8 4b_tp2_tp4 (_scale_inv) all 10/10, KL 0.0004-0.0033.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
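
A sketch of the whitelist shape described in this commit: QKV carve-out, exact grouped-expert regex, `.experts.` guard, then an explicit suffix check. Every suffix list and the regex here are hypothetical placeholders, as is the name `is_bulk_param_sketch`; the real `is_nccl_xfer_param` rules are not reproduced in this thread.

```python
import re

# Hypothetical rule tables for illustration only.
QKV_SUFFIXES = ("q_proj.weight", "k_proj.weight", "v_proj.weight")
KNOWN_SHARD_SUFFIXES = QKV_SUFFIXES + (
    "o_proj.weight",        # row-parallel
    "embed_tokens.weight",  # vocab-parallel
)
GROUPED_EXPERT_RE = re.compile(r"\.experts\.(?:gate_proj|up_proj|down_proj)\.weight$")


def is_bulk_param_sketch(name, qkv_to_misc=False):
    """Whitelist: bulk only when a shard rule is known; everything else is misc."""
    if qkv_to_misc and name.endswith(QKV_SUFFIXES):
        return False                       # QKV carve-out
    if GROUPED_EXPERT_RE.search(name):
        return True                        # grouped MoE expert projection
    if ".experts." in name:
        return False                       # any other expert-scoped param
    return name.endswith(KNOWN_SHARD_SUFFIXES)


unknown = is_bulk_param_sketch("model.layers.0.new_block.mystery.weight")
scale = is_bulk_param_sketch("model.layers.0.self_attn.o_proj.weight_scale_inv")
expert = is_bulk_param_sketch("model.layers.0.mlp.experts.gate_proj.weight")
```

The property that matters is the default: a name that matches no rule falls through to `False`, so a new layer type ends up on the slower but correct misc path.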
@youngeunkwon0405 youngeunkwon0405 added the Performance Related to improving performance label Jun 12, 2026
@youngeunkwon0405

Contributor Author

/okay to test 4a93573

Pipeline the gen-side refit receive across CUDA streams to hide the
per-stage broadcast latency when train PP>1 funnels every stage through
PP1 gen ranks (DSv3, 235B).

- Group received params by PP stage and round-robin them over
  NRL_REFIT_NUM_STREAMS streams (default 2; 1 = serial). Each stage's
  collective is pinned to one stream (NCCL-safe), and a per-stage event
  caps the number of stages in flight at num_streams, so peak
  receive-buffer memory stays bounded. weight_sync on DSv3 is ~2x faster
  (PP16->PP1: ~15s -> ~8s); with K streams, K=2 is the throughput knee
  and K=4 over-subscribes the fabric.
- empty_cache() after the refit returns the transient receive buffers to
  the driver so the gen engine's out-of-pool allocations (cuBLAS) have
  headroom; without it, memory-tight gen configs (FP8 235B at gen TP=4,
  ~12 GiB non-pool headroom) OOM on the first post-refit forward.
- Harden the HF->gen mapping: _build_hf_to_gen_backend_mapping now raises
  on an unmapped bulk param (grouped-expert and general branches) instead
  of silently returning (None, None), and get_dst_dtensor raises rather
  than receiving into a discarded buffer. A coverage regression fails
  loud at build time instead of silently dropping weights.
- Extract _fused_param_merge_slice for the qkv/gate_up/fused_qkv_a
  dim-0 sub-slice math; normalize merged slices to index tuples.

Validated on the golden L0+L1 matrix (BF16 + FP8) plus nemotron
nano/super: refit-clean across DP/TP/PP/EP, gated and non-gated MoE,
FP8, 235B, and DSv3 (K=2). FP8 235B confirmed fixed by empty_cache
(0 OOM, KL 0.0110, weight_sync ~6s).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
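
The stream scheduling in the first bullet can be simulated on the CPU to check the bounded-in-flight property. This sketch uses no CUDA or NCCL; `plan_refit_streams` is a hypothetical name, and evicting the oldest stage stands in for waiting on that stage's event.

```python
from collections import deque


def plan_refit_streams(stages, num_streams):
    """Round-robin stages over streams; return (stage -> stream, peak stages in flight)."""
    assignment = {stage: i % num_streams for i, stage in enumerate(stages)}
    in_flight = deque()
    peak = 0
    for stage in stages:
        if len(in_flight) == num_streams:
            # Stand-in for waiting on the oldest stage's event before reusing
            # its stream, which also releases that stage's receive buffers.
            in_flight.popleft()
        in_flight.append(stage)
        peak = max(peak, len(in_flight))
    return assignment, peak


pp16_assignment, pp16_peak = plan_refit_streams(list(range(16)), num_streams=2)
serial_assignment, serial_peak = plan_refit_streams(list(range(4)), num_streams=1)
```

Because the stage that gets evicted is always the previous occupant of the stream about to be reused, each stream carries at most one stage at a time, which is what keeps peak receive-buffer memory proportional to the stream count rather than the PP size.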

Labels

CI:L1 Run doctests, unit tests, and functional tests Performance Related to improving performance
