feat: NCCL Xfer based refit integration#2413
Draft
youngeunkwon0405 wants to merge 29 commits into
Draft
Conversation
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com> fix DTensor nccl_reshard_refit dtype mismatch
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com> code refactoring Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com> code refactoring Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
Build the fused w13_weight in train-side _fuse_expert_params with a TP-aware permute so contiguous Shard(1) slicing yields vLLM's expected per-rank [half_gate, half_up] layout (per fused_moe/layer.py::_load_w13). Drops 30B EP8PP2 refit Generation KL from ~5.2/5.7 to 0.0018, matching the dense baseline (~0.0004). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
Contributor
Author
|
/claude review |
Review notesMissing exemplar YAML update: The new Missing test coverage: |
…ridge
Wires up FP8→FP8 weight transfer through Bridge's `build_export_fp8_tasks`
so the train side ships TE Float8BlockwiseQTensor data + scale_inv pairs
without recreating ~1000 lines of TE-coupled extraction logic ourselves.
Train side (megatron_policy_worker.py):
- `_is_fp8_export()` returns True iff fp8_cfg has fp8_param=true and
fp8_recipe="blockwise" (the only Megatron config TE produces FP8
storage for).
- `_build_refit_conversion_tasks()` routes between
`bridge.get_conversion_tasks` (BF16 / fp8_param=false) and
`bridge._model_bridge.build_export_fp8_tasks` (fp8_param=true). The
FP8 path emits paired (FP8 weight, scale_inv) tasks.
- `_iter_local_hf_params` recognises Bridge's `_HFNameSuffixMapping`
scale_inv tasks. For compound mappings the scale tensor is split the
same way the weight is:
* QKVMapping → reuse `split_qkv_weights` (Bridge auto-detects
scale-domain trailing-dim block compression)
* GatedMLPMapping → chunk by 2 along dim 0 (gate / up halves)
* simple mapping → yield with `_scale_inv` HF-name suffix
- `_calculate_refit_param_info`: extend `prec_to_bytes` with
`float8_e4m3fn`, `float8_e5m2`, `uint8` so the size accounting works
on FP8 weights.
Gen side (vllm_backend.py):
- `_build_hf_to_vllm_mapping` MERGE_RULES extended with `*_scale_inv`
variants so HF q/k/v_proj.weight_scale_inv merge into vLLM
qkv_proj.weight_scale_inv (same for gate/up). Direct-match
scale_inv (e.g. o_proj.weight_scale_inv) flows through the existing
1:1 path. Existing `_compute_tp_local_slice` handles dim-0 sharding
for the blockwise scale shape `[output//block, input//block]`.
Test scripts (script/new_refit/):
- `4b_fp8_to_fp8_dp_dp.sh` — Qwen3-4B dense, fp8_param=true train →
vLLM precision=fp8 use_deep_gemm=true. NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1
set as direct shell export (Hydra rejects new env_vars keys without
`+` and rejects non-string values without escape engineering that
doesn't survive bash double-quoting).
- `4b_bf16_to_fp8_dp_dp.sh` — BF16 train → vLLM FP8. Documented as
failing today: needs gen-side quantization in finalize.
Verification (FP8→FP8, job 11586681): 10/10 GRPO steps complete,
initial refit + per-step refits all succeed, Generation KL
0.0029-0.0045 (BF16 baseline ~0.0004; ~10× higher is consistent with
FP8 quantization noise), weight_sync ~0.07s/step, clean exit.
Existing BF16 path verified non-regressive (4b TP2→TP4 KL 0.0004).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
- get_placements: skip TP shard for expert params when EP>1 (current scope: expert_tensor_parallel_size=1 only); use len(dim_map) for mesh dim count. - Remove vestigial fields _moe_fused, _moe_num_experts_local, _moe_gate_entries, _moe_up_entries, _moe_down_entries; replace with a single _moe_kind: "w13" | "w2" discriminator. - Drop metadata_ep_gathered flag and ep_size parameter; the EP-gathered invariant on input metadata is now asserted in the docstring (Megatron meets it via export_hf_weights' EP all-gather; DTensor doesn't use EP). Net diff: -64 / +34 lines. Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
dad5cb6 to
a543fb2
Compare
Makes the nccl_reshard-based refit path work correctly and within memory for Qwen3-MoE training configs that combine expert and tensor parallelism (e.g. EP4×TP2). Three independent issues: 1. OOM during metadata enumeration. Bridge's export_hf_weights does PP/TP/EP gathers to materialize full unsharded tensors, but the refit-info builders only need shape+dtype. Wrap the enumeration in a context that redirects empty_like/zeros_like to "meta" and turns the collectives into no-ops; peak memory stays at zero extra GiB. Lifted to a module-level helper so prepare_refit_info (BCAST path) and prepare_nccl_reshard_refit_info share it. 2. OOM during expert fusion. Previously all layers' fused w13/w2 tensors were materialized upfront (~14 GiB peak on 30B). Rebuild one fused tensor at a time inside the transfer loop and free immediately after; peak drops to one fused param. 3. KL divergence on ETP=1 EP×TP configs. Megatron-Core with expert_tensor_parallel_size=1 gives non-expert and expert params *different* rank-to-coord layouts on the same physical ranks (e.g. EP4×TP2 has tp=r%2 AND ep=r%4 on the same rank — not a single product structure). build_nccl_reshard_refit_info now constructs two src meshes per PP stage — one for non-expert params (tp×dp) and one for expert params (ep×edp) — and selects per-param based on is_expert_param. Verified on the full 11-config test matrix. KL on previously-broken configs: - 30b EP4TP2→TP4: 1.46 → 0.002 - 30b EP4TP2PP2→TP4: 1.39 → 0.002 Other configs (1b/4b/30b EP8) unchanged within float tolerance. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
These per-user SLURM submission scripts are local conveniences and shouldn't live in the repo. Keep them on disk (still gitignore-able) but stop tracking. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
`_param_map` (HF-name → Megatron-param view), `_expert_groups` (regex grouping of expert params by layer/proj), and `_hf_to_vllm` (HF→vLLM merge-rule mapping) all depend only on model topology, not weight values. Build them once in the respective prepare_nccl_reshard_refit_info calls and reuse for every refit step. Saves one full Megatron-param iterator pass + one regex scan + one vLLM named_parameters walk per refit step. Also: dedup the per-expert regex (reuse _INDIVIDUAL_EXPERT_RE from nccl_reshard_utils), drop leftover bringup print()s in init_pp_comm_groups, and remove a WHAT-narrating comment. Verified on the full L1 matrix (11 small + dsv3 + qwen3-235b); KL values unchanged within run-to-run variance. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
e8a3919 to
96613c3
Compare
946394e to
181457c
Compare
Contributor
Author
|
/claude review |
…c paths Extend the nccl_reshard (xferdtensor) disaggregated refit to FP8 weights with a hybrid transfer: - Bulk weights (2D linears, MoE experts) take the fast xferdtensor broadcast path, broadcasting every tensor as a uint8 byte-view (nccl4py cannot broadcast float8 dtypes natively; a uint8 reinterpret is wire-identical for any dtype). - Misc params go through vLLM's load_weights via packed_broadcast: FP8 blockwise *_scale_inv siblings, FP8 KV/activation scales, and the MoE router. Routing the MoE router (HF mlp.gate.weight / Megatron mlp.router.weight) to misc fixes an FP8 refit deadlock: Megatron keeps the router in bf16 but vLLM blockwise-FP8 quantizes it, so a direct xferdtensor broadcast mismatched byte counts (bf16 send vs fp8 recv) and hung the collective. load_weights quantizes it correctly. Verified on the L0/L1 matrix: BF16 L0 (4b/30b across DP/TP/PP/EP) and FP8 L0 dense 4b complete 10/10 steps; FP8 MoE 30b reaches parity with the BCAST baseline (refit + generation succeed; the remaining get_logprobs TE error is a pre-existing upstream FP8 issue); BF16 L1 235B and DSv3 complete 10/10. Gen KL stays in the 4e-4..5e-3 range across all configs. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
Add disaggregated weight-refit support for NemotronH models (e.g. Nemotron-3-Nano-30B-A3B): hybrid Mamba2 / attention / MoE-FFN. Key fix: non-gated (ReLU^2) MoE experts silently dropped their up-projection. NemotronH MoE has up_proj + down_proj (no gate_proj), and vLLM builds it as SharedFusedMoE(is_act_and_mul=False) with w13_weight = [E, intermediate, hidden] (up only). fuse_expert_params_in_metadata did `if not gate_entries: continue`, so it never created w13 and dropped up_proj -> vLLM's w13_weight stayed stale -> every MoE layer half-corrupt -> 100% importance-sampling masking, reward ~0. Now build w13 from up_proj alone when there is no gate, on both the metadata and value-fusion sides. The gated SwiGLU path is unchanged (if gate_entries / elif up_entries). Also: - is_misc_param: route NemotronH Mamba mixer params (in_proj, conv1d, A_log, D, dt_bias, norm) and the MoE router (mixer.gate.weight) to the misc/load_weights path; out_proj stays on the bulk xferdtensor path. - placement: out_proj -> row-parallel; NemotronH token embedding -> vocab-parallel. - vllm_backend: translate backbone.* -> model.* and embeddings -> embed_tokens at vLLM-param lookup so bulk mapping resolves. Validated at baseline parity (reward + KL match the broadcast refit) on NemotronH (non-gated) and Qwen3-30B-A3B (gated SwiGLU regression). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
Three related refactors to the nccl_reshard refit path: 1. Tag MoE expert fusion by role + explicit gated flag. Replace the overloaded fused_expert_param_type "w13"/"w2" with the semantic role "up_proj"/"down_proj" plus a gated:bool (single source of truth for the up-projection layout; the value-side branch asserts it against gate_proj presence). Fused param names stay w13_weight/w2_weight (vLLM's FusedMoE convention). 2. Move DTensorRef to xferdtensor.py; rename full_tensor() -> local_tensor(). DTensorRef is the wrapper xferdtensor reads as its src/dst, so it belongs next to the kernel. full_tensor() only ever returned the local shard and had no callers (the kernel reads the ._local_tensor attribute). MeshInfo stays in nccl_reshard_utils. 3. Drop the DTensor train backend (Megatron + vLLM only for this initial version). check_nccl_reshard_refit_support states the Megatron-only scope intentionally; the grpo.py refit setup collapses to the Megatron branch; the unreachable DTensor-worker nccl-refit methods are removed. Validated at parity across the L0 matrix (4B + 30B parallelism shapes and NemotronH) on both gated (Qwen3-30B-A3B) and non-gated (NemotronH) MoE. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
Two NemotronH-specific gaps blocked the disaggregated refit of NVIDIA-Nemotron-3-Super-120B-A12B (Latent-MoE, GQA num_kv=2) at TP8 EP8 PP4; both are now fixed and validated end-to-end (refit at reward ~0.50-0.52 / gen-KL ~0.004, parity with the broadcast baseline). 1. GQA QKV at tp_size > num_query_groups. The bulk (xferdtensor) path splits the fused QKV on each TP-local shard, which needs every rank to own whole K/V heads. When TP exceeds the KV-head count Megatron channel-splits the KV heads, so no rank holds a whole one and the local split is impossible. is_misc_param now takes a qkv_to_misc flag (set when num_query_groups < train tp): QKV then takes the misc path, which gathers full heads via export_hf_weights like the broadcast refit. _iter_local_hf_params skips QKV in that case so the bulk path never attempts the impossible split. No-op for every other model (num_query_groups >= tp_size keeps QKV on the bulk fast path). 2. PP-stage mapping was Llama/Qwen-only. _LAYER_RE / _MODEL_PREFIX_RE and _build_layer_to_pp_stage only knew model.layers.N naming, so NemotronH's backbone.* params all defaulted to stage 0 and PP rank 0 tried to transfer the whole model. Both are now model|backbone naming-agnostic. Keeps the super-v3 BF16 L1 test (config + baseline/refit scripts). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
Decouple build_nccl_reshard_refit_info from vLLM-specific layout rules so another generation backend could supply its own, and simplify the gen-side merged-param handling. - MoE-expert fusion is no longer hardcoded in the builder: fuse_expert_params_in_metadata -> vllm_fuse_expert_params_in_metadata, passed in as fuse_expert_param_in_metadata_fn (None = no fusion). VllmGeneration.expert_fusion_fn() supplies it via the driver. - Route QKV to the misc/load_weights path whenever KV heads can't be cleanly 1/tp-sharded on EITHER side: qkv_to_misc = num_query_groups < max(train_tp, gen_tp). This subsumes the gen-side KV-head-replication case, so the vLLM-specific KV-head slice mapping (_compute_tp_local_slice) is removed from the bulk path. - get_dst_dtensor's merged-param path collapses to a single placement-honoring block: receive this component's slice of the merged vLLM param (Shard dst -> the 1/gen_tp shard; all-Replicate dst / disable_tp -> the full tensor, which is the slice) and copy it in. Golden and the real reshard op now agree on dst placements with no path-specific branch. Validated on the BF16 L0+L1 golden matrix, 12/12: 4b dense (TP/PP/DP), 30b MoE (EP/TP/PP incl. qkv->misc at gen TP8), 235b, and dsv3 (disable_tp). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
xferdtensor() now dispatches to the real point-to-point reshard (nccl.xfer.api.XdtensorRedistribute) when it is available and not overridden, otherwise to the existing broadcast-based golden reference. - Guarded import: golden-only containers (no nccl.xfer module) auto-fall-back to golden; NRL_XFERDTENSOR_GOLDEN=1 forces golden on any container. - Each rank holds only one side (train ranks own the src shard, gen ranks the dst shard); the op takes a one-sided None and derives the absent side from the present side's global shape + mesh/placements. PyTorch Shard/Replicate placements pass through; meshes go in as nccl.xfer Mesh via _to_xfer_mesh. Safe to merge ahead of a real-op-capable container thanks to the guarded import. The xfer-capable nccl4py + matching custom NCCL are supplied by the container/deps, not by this change. The real op's ~100x per-call latency is a known nccl_xfer-library issue (handed to the NCCL team); until it's fixed the golden path stays the default for the async refit loop. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
Validated green across the full golden L0/L1 BF16+FP8 matrix plus Nemotron Nano-v3 (L0) and Super-v3 (L1); no behavior change on tested configs. - vllm_backend: size merged-param slices from refit_info["gen_tp_size"] instead of torch.distributed.get_world_size(), which equals gen TP only while gen PP/EP are pinned to 1 -- robust if gen PP is ever added. - nccl_reshard_utils: reject unsupported gen precision values up front in check_nccl_reshard_refit_support (allow only bf16/bfloat16/auto/fp8/unset), so a typo or mismatched precision fails the gate instead of deadlocking the bulk collective. - contiguity guards: .view -> .reshape in the MoE w13 gate/up fusion and .contiguous().clone() in the golden all-Replicate fast path, so a non-contiguous param view can't crash the reshard. - rename ne_mesh/ex_mesh -> non_expert_mesh/expert_mesh (and dim_maps), unifying three naming styles into one. - docstrings + worked examples for _compute_shard_slices, _fuse_one_moe_param, _build_expert_groups, and the get_dst_dtensor unmapped-discard path (clarifying it is collective-symmetry participation, not stale-weight corruption). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
The low-level op is finalized as "NCCL Xfer", so the refit feature is renamed from nccl_reshard_* to nccl_xfer_* throughout: - config knob policy.nccl_reshard_refit -> policy.nccl_xfer_refit - prepare_/build_/check_nccl_reshard_refit_* -> nccl_xfer_* (+ _async variants) - the nccl_reshard_refit method -> nccl_xfer_refit; nccl_reshard_refit_info / _enabled attrs -> nccl_xfer_* - module nemo_rl/distributed/nccl_reshard_utils.py -> nccl_xfer_utils.py Also rename init_pp_comm_groups -> init_per_pp_refit_comm_group (+ _async): it sets up the per-PP-stage refit comm groups, not the PP group itself. Dead-code cleanup folded in: - VOCAB suffix matching uses endswith (was substring `in`) - drop the unreachable scale_inv MERGE_RULES (scale_inv takes the misc path) - drop the dropped-DTensor-backend to_local() else-branch and the unused DTensorRef.local_tensor() accessor Behavior-preserving; validated on the golden matrix (BF16 L0 10/10, FP8 L0 9/9, 235b L1 full) -- config-knob read, module import, and method dispatch all confirmed under the new names. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
a336e84 to
84863cd
Compare
…olish megatron_policy_worker.py: - Rename _iter_local_hf_params -> _iter_local_hf_param_shards (it yields TP/EP-local shards, not gathered params) and document that for EP>1 refit_conversion_tasks already holds only this rank's local experts. - Collapse the dead FP8 _scale_inv splitting branch into a plain skip: all _scale_inv siblings are routed to the misc path by is_misc_param, so they never enter the bulk state_dict_metadata and their _param_map entries were never read. Document that refit_conversion_tasks is the FULL Bridge list and is filtered downstream (misc_meta / _misc_conversion_tasks), not pruned. - Drop the dead `if self.refit_conversion_tasks is None` guard in prepare_nccl_xfer_refit_info (refit_param_info_mcore is always recomputed). - Comment/docstring polish only. nccl_xfer_utils.py: - build_nccl_xfer_refit_info: require layer_to_pp_stage when pp_size > 1 (assert) instead of silently disabling per-stage handling; clarify the state_dict_metadata key is the HF param name. No behavior change; validated by the golden L0/L1 matrix and the 235B FP8 refit run (job 12711863) which imports exactly this code. Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com> Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…sion gen-side) Mirror the QKV pattern for MoE experts: the train worker and the shared nccl_xfer_refit_info now use backend-agnostic grouped HF names (gate_proj/up_proj/down_proj, each [E, ...]); the vLLM w13/w2 fusion moves entirely to the gen side. A new gen backend now needs only a gen-side mapping, never a train-worker change. nccl_xfer_utils.py: - vllm_fuse_expert_params_in_metadata -> agnostic group_expert_params_in_metadata: stack per-expert HF params into one [E, *per_expert_shape] entry per projection, tagged grouped_expert_proj. No w13/w2, no gated, no gen_tp. The builder calls it unconditionally (grouping is universal); dropped the fuse_expert_param_in_metadata_fn injection param. - Removed w13_weight/w2_weight from COLUMN/ROW_PARALLEL_SUFFIXES; grouped gate/up/down resolve via _get_expert_tp_shard_dim -> Shard(1)/Shard(2). megatron_policy_worker.py: - _fuse_one_moe_param -> _group_experts: plain torch.stack of the per-expert tensors; deleted the gate/up interleave + gen_tp_size + gated logic. - _get_src_local_tensor dispatches on grouped_expert_proj; dropped gen_tp_size. - prepare_nccl_xfer_refit_info: dropped the injection param. vllm_backend.py: - Generalized merged_param_slice to a multi-dim index tuple. - New flag-dispatched grouped-expert branch in _build_hf_to_vllm_mapping: gate/up -> w13_weight dim-1 halves ([:, :P]/[:, P:2P]) when gated (has_gate pre-pass), up/down -> direct (whole w13 / w2). get_dst_dtensor's existing buffer + post_refit_hook places the received Shard(1)/Shard(2) shard. vllm_generation.py / grpo.py / lm_policy.py: removed expert_fusion_fn + all fuse_expert_param_in_metadata_fn threading. w13_weight/w2_weight no longer appear in the train worker or the shared metadata — only in vllm_backend (where vLLM-specific layout belongs). Validated on the golden path (NRL_XFERDTENSOR_GOLDEN=true): byte-identical by construction (the gen-side slice reproduces the old interleave layout) + reward/KL parity across Qwen3-30B-A3B gated (gen_tp=8), NemotronH-nano non-gated, FP8 30B (scale_inv stays misc), and 30B PP2+TP (per-stage) — all with zero grouped-mapping warnings and KL in the expected BF16/FP8 ranges. Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com> Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…Store workaround) FP8 KV cache now works on the nccl_xfer refit path: the per-layer k/v(/q) scales ride the existing misc packed-broadcast as plain scale tensors, instead of an out-of-band TCPStore scalar channel. This works because is_misc_param already routes .k_scale/.v_scale/.q_scale to the misc path and their names are already in misc_meta (built from _iter_params_with_optional_kv_scales at prepare time), so producer/gen agree and the scales broadcast like any other misc param. Train (megatron_policy_worker.py): - nccl_xfer_refit: drop the NotImplementedError on kv_scales; forward kv_scales to _broadcast_misc_params_packed. - _broadcast_misc_params_packed(kv_scales): forward to _iter_params_with_optional_kv_scales(kv_scales=...) so the real calibrated scales (grpo's calibrate_qkv_fp8_scales) broadcast instead of the 1.0 placeholder. Gen (vllm_backend.py): - nccl_xfer_refit: after the misc receive, call _maybe_process_fp8_kv_cache() (guarded on kv_cache_dtype=fp8) to finalize the per-layer scales — exactly the process_weights_after_loading step the prior comment said this work would need. Test: script/new_refit/fp8/4b_tp2_tp4_fp8kv.sh (FP8 4b TP2DP4->TP4DP2 + kv_cache_dtype=fp8). Note FP8 KV cache requires SYNC rollouts (async_grpo=false, async_engine=false) per grpo.py's guard; disaggregated + sync is valid (only the inverse — async requires non-colocated — is enforced). Validated golden (job 12718430): refit reached, recompute_kv_scales runs each step, Gen KL 0.0086-0.0122 (sane FP8-KV range, ~3x FP8-weights-only as expected from KV quantization), no grouped-mapping / kv-scale / process_weights errors. Also confirms sync-mode non-colocated nccl_xfer refit works. Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com> Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Contributor
Author
|
/okay to test f3834e0 |
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
Contributor
Author
|
/okay to test 46d09b7 |
… trim - _build_expert_groups resolves each per-expert name to its _param_map view once at prepare time (index-sorted), so _group_experts only torch.stacks at refit time — the hot path no longer does per-expert dict lookups. - Replace silent fallbacks with loud asserts: _group_experts asserts its expert group is non-empty, and nccl_xfer_refit asserts the per-param local tensor is present (a None means a layer_to_pp_stage / metadata inconsistency that would otherwise surface as a gen-side hang, not a clear error). - Add [xferd-payload] byte-accounting print (xfer-major vs misc bytes per refit). - xferdtensor.py: inline `# pyrefly: ignore[...]` for the optional nccl.xfer import path (import-error / not-callable / bad-argument-type). - Comment/docstring cleanup across the refit path, incl. the MLA down-proj replication note (TELinear duplicated -> replicated at any TP for the TE spec; warns about the local-spec / mla_down_proj_fusion sharding cases). Behavior is byte-identical; validated on the golden L0 matrix (4b DP/TP/PP x6, 30b EP/PP x4) all 10/10, KL 0.0004-0.0019, asserts never fired. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
3708bea to
3288f55
Compare
…fail-safe) Invert the bulk/misc classifier from a blacklist (is_misc_param: route named exceptions to misc, everything else bulk) to a whitelist (is_nccl_xfer_param: bulk only for params with a known TP/EP shard rule or a grouped MoE expert projection; everything else -> misc). An unrecognized param (new model / layer type) now routes to the misc/load_weights path — correct-but-slower, same coverage as the conventional broadcast refit — instead of silently taking the bulk path with a possibly-wrong uniform column/row/replicate placement. - nccl_xfer_utils: is_misc_param -> is_nccl_xfer_param (whitelist: qkv carve-out + exact expert regex + .experts. guard + explicit column/row/vocab suffix loop). - megatron_policy_worker: flip the state_dict_metadata/misc_meta split. Derive _misc_conversion_tasks from misc_meta (the HF-name split) rather than re-classifying the Megatron global_param_name — those are different namespaces, and classifying Megatron compound names (linear_qkv / linear_fc1 / linear_proj) with the HF-suffix whitelist mis-routes bulk tasks into misc and desyncs the packed_broadcast. - vllm_backend + comments: reference the new classifier; fix a stale "embeddings -> misc" note (embeddings are vocab-parallel bulk). Validated golden: 4b_dp_dp / 30b_ep8pp2_tp8 (qkv_to_misc=True) / nemotron-nano (non-gated MoE + Mamba) / fp8 4b_tp2_tp4 (_scale_inv) all 10/10, KL 0.0004-0.0033. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
Contributor
Author
|
/okay to test 4a93573 |
Pipeline the gen-side refit receive across CUDA streams to hide the per-stage broadcast latency when train PP>1 funnels every stage through PP1 gen ranks (DSv3, 235B). - Group received params by PP stage and round-robin them over NRL_REFIT_NUM_STREAMS streams (default 2; 1 = serial). Each stage's collective is pinned to one stream (NCCL-safe), and a per-stage event gates to num_streams stages in flight so peak receive-buffer memory stays bounded. ~2x weight_sync on DSv3 (PP16->PP1: ~15s -> ~8s); K=2 is the throughput knee, K=4 over-subscribes the fabric. - empty_cache() after the refit returns the transient receive buffers to the driver so the gen engine's out-of-pool allocations (cuBLAS) have headroom; without it, memory-tight gen configs (FP8 235B at gen TP=4, ~12 GiB non-pool headroom) OOM on the first post-refit forward. - Harden the HF->gen mapping: _build_hf_to_gen_backend_mapping now raises on an unmapped bulk param (grouped-expert and general branches) instead of silently returning (None, None), and get_dst_dtensor raises rather than receiving into a discarded buffer. A coverage regression fails loud at build time instead of silently dropping weights. - Extract _fused_param_merge_slice for the qkv/gate_up/fused_qkv_a dim-0 sub-slice math; normalize merged slices to index tuples. Validated on the golden L0+L1 matrix (BF16 + FP8) plus nemotron nano/super: refit-clean across DP/TP/PP/EP, gated and non-gated MoE, FP8, 235B, and DSv3 (K=2). FP8 235B confirmed fixed by empty_cache (0 OOM, KL 0.0110, weight_sync ~6s). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What does this PR do ?
Add a one line overview of what this PR aims to accomplish.
Current test results
Issues
List issues that this PR closes (syntax):
Usage
# Add a code snippet demonstrating how to use thisBefore your PR is "Ready for review"
Pre checks:
Additional Information