[training] feat: forward MoE/MTP metrics to MLFlow and Comet #3647

Merged
cuichenx merged 2 commits into NVIDIA-NeMo:main from lonexreb:training/moe-comet-fanout-2989 on May 5, 2026

@lonexreb lonexreb commented May 4, 2026

Contributor

Summary

MCore's track_moe_metrics and track_mtp_metrics only forward metrics to TensorBoard and W&B. Users wiring up Comet (or MLFlow) never see MoE auxiliary losses (load_balancing_loss, seq_load_balancing_loss, global_load_balancing_loss, z_loss) or MTP per-layer losses on those backends.

Per maintainer guidance on the issue ("Megatron-Bridge side can monkey-patch track_moe_metrics to avoid a cross-repo dependency"), this PR wraps the TB writer with a small SummaryWriter-shaped adapter that fans out every add_scalar(name, value, iteration) call to MLFlow and Comet, using the same per-step value TB receives. W&B is unaffected: the underlying MCore functions still receive wandb_writer directly.

Refs #2989.

Why a wrapper, not total_loss_dict

The original report tried reading from total_loss_dict for the Comet path. That's wrong because total_loss_dict accumulates with += across iterations, so values monotonically grow (the issue records 1.24 → 6.14 instead of the correct 1.20–1.24 per-step range). The wrapper approach captures the exact per-step averaged value that TB receives (loss_list.sum() / num_moe_layers) without any further bookkeeping.
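
The difference between the two views can be illustrated with a small, self-contained sketch. The numbers below are hypothetical (chosen to sit in the 1.20–1.24 range mentioned above) and the variable names are not taken from the code base; the point is only that a `+=`-accumulated value keeps growing while the value passed to each `add_scalar` call stays in the per-step range.

```python
# Hypothetical per-step MoE aux-loss values (illustrative numbers only).
per_step_values = [1.24, 1.22, 1.20, 1.23, 1.21]

# What a reader of a "+="-accumulated dict would see at each iteration.
running_total = 0.0
accumulated_view = []
for value in per_step_values:
    running_total += value
    accumulated_view.append(running_total)

# What a writer-shaped wrapper sees: one value per add_scalar call.
per_step_view = list(per_step_values)

for step, (acc, cur) in enumerate(zip(accumulated_view, per_step_view), start=1):
    print(f"step {step}: accumulated={acc:.2f} per_step={cur:.2f}")
```

Reading the accumulated column as if it were a per-step metric reproduces the kind of monotonic growth described in the issue.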

Implementation

src/megatron/bridge/training/utils/train_utils.py:

  • New _MoeMetricFanoutWriter adapter — add_scalar(name, value, iteration) forwards to TB (if any), MLFlow (log_metrics(..., step=iteration)), and Comet (log_metrics(..., step=iteration)).
  • New _build_moe_metric_writer(tb_writer, comet_logger, mlflow_logger) factory:
    • Returns the real TB writer unchanged when neither Comet nor MLFlow is wired up — zero overhead, no behavior change.
    • Returns the wrapper when at least one of Comet / MLFlow is wired up. The wrapper is returned even if tb_writer is None, which is required to surface MoE / MTP metrics in Comet / MLFlow when the user hasn't enabled TensorBoard.
  • Wired into both call sites:
    • track_moe_metrics(..., writer=moe_metric_writer, ...)
    • MTPLossLoggingHelper.track_mtp_metrics(..., mtp_metric_writer, ...)
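
The adapter and factory described in this list could look roughly like the sketch below. This is not the merged code: the class and function names carry a `Sketch` suffix to mark them as illustrative, the `Recording*` classes are hypothetical stand-ins for the real TB, Comet, and MLFlow objects, and tensor handling is left out to keep the example short. The only assumptions about the sinks are the ones stated above, namely `add_scalar(name, value, iteration)` for TB and `log_metrics(..., step=iteration)` for the other two.

```python
class MoeMetricFanoutWriterSketch:
    """SummaryWriter-shaped adapter (illustrative, not the merged implementation)."""

    def __init__(self, tb_writer=None, comet_logger=None, mlflow_logger=None):
        self.tb_writer = tb_writer
        self.comet_logger = comet_logger
        self.mlflow_logger = mlflow_logger

    def add_scalar(self, name, value, iteration):
        # Forward to every configured sink; any of them may be None.
        if self.tb_writer is not None:
            self.tb_writer.add_scalar(name, value, iteration)
        if self.mlflow_logger is not None:
            self.mlflow_logger.log_metrics({name: value}, step=iteration)
        if self.comet_logger is not None:
            self.comet_logger.log_metrics({name: value}, step=iteration)


def build_moe_metric_writer_sketch(tb_writer, comet_logger, mlflow_logger):
    # Bypass: with no extra sinks, hand back the original writer object.
    if comet_logger is None and mlflow_logger is None:
        return tb_writer
    # Otherwise wrap, even when tb_writer itself is None.
    return MoeMetricFanoutWriterSketch(tb_writer, comet_logger, mlflow_logger)


class RecordingTB:
    """Hypothetical stand-in for a TensorBoard SummaryWriter."""

    def __init__(self):
        self.calls = []

    def add_scalar(self, name, value, iteration):
        self.calls.append((name, value, iteration))


class RecordingLogger:
    """Hypothetical stand-in for a Comet or MLFlow logger."""

    def __init__(self):
        self.calls = []

    def log_metrics(self, metrics, step=None):
        self.calls.append((dict(metrics), step))


tb, comet, mlflow_logger = RecordingTB(), RecordingLogger(), RecordingLogger()
writer = build_moe_metric_writer_sketch(tb, comet, mlflow_logger)
writer.add_scalar("load_balancing_loss", 1.22, 10)

bypass = build_moe_metric_writer_sketch(tb, None, None)
tb_less = build_moe_metric_writer_sketch(None, RecordingLogger(), None)
```

Because the bypass case returns the very same object, callers that never configure Comet or MLFlow cannot observe any difference from the previous code path.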

Tensor sanitation: 0-d torch tensors are converted to Python scalars with .item() before being handed to MLFlow / Comet — TB tolerates 0-d tensors, MLFlow / Comet do not. The TB writer still receives the original value untouched.
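
A minimal sketch of that conversion step, under the assumption that anything exposing a callable `.item()` should be unwrapped. The helper name `to_python_scalar` and the `FakeZeroDimTensor` class are hypothetical; the fake stands in for a 0-d torch tensor so the example runs without torch installed.

```python
from numbers import Number


class FakeZeroDimTensor:
    """Stand-in for a 0-d torch tensor: exposes only .item()."""

    def __init__(self, value):
        self._value = value

    def item(self):
        return self._value


def to_python_scalar(value):
    """Return a plain Python number for tensor-like inputs (hypothetical helper)."""
    # Plain Python numbers pass through unchanged.
    if isinstance(value, Number):
        return value
    # Tensor-like objects are unwrapped via .item().
    item = getattr(value, "item", None)
    if callable(item):
        return item()
    return value


tensor_like = FakeZeroDimTensor(0.5)
for_tb = tensor_like                       # TB receives the original object
for_others = to_python_scalar(tensor_like)  # MLFlow / Comet receive a float
```

Keeping the original object for TB and converting only for the other sinks mirrors the split described above.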

Per-layer logging (per_layer_logging=True) is handled naturally: each writer.add_scalar(f"moe/{name}_layer_{i}", ...) call inside track_moe_metrics fans out individually.
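
To show why no per-layer special-casing is needed, the sketch below drives a writer-shaped object through a per-layer loop. `emit_per_layer` is a hypothetical stand-in for the loop inside track_moe_metrics (only the metric-name pattern is taken from the description above), and `ListWriter` is a hypothetical recorder playing the role of the wrapper.

```python
class ListWriter:
    """Minimal add_scalar-shaped recorder (hypothetical stand-in for the wrapper)."""

    def __init__(self):
        self.records = []

    def add_scalar(self, name, value, iteration):
        self.records.append((name, value, iteration))


def emit_per_layer(writer, name, layer_values, iteration):
    """Hypothetical stand-in for the per-layer loop inside track_moe_metrics."""
    for i, value in enumerate(layer_values):
        writer.add_scalar(f"moe/{name}_layer_{i}", value, iteration)


writer = ListWriter()
emit_per_layer(writer, "load_balancing_loss", [1.20, 1.22, 1.24], iteration=7)
names = [record[0] for record in writer.records]
```

Each layer produces its own `add_scalar` call, so a wrapper that forwards every call forwards every layer at the same step.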

Test plan

  • python3 -m ast parse of changed files
  • ruff check clean on changed files (auto-fixed import order)
  • ruff format --check clean on changed files
  • 8 new unit tests under TestMoeMetricFanoutWriter:
    • _build_moe_metric_writer returns the original writer when no Comet / MLFlow → bypass contract
    • returns wrapper when only Comet present
    • returns wrapper when only MLFlow present
    • add_scalar fans out to TB + Comet + MLFlow when all three are present
    • add_scalar works when TB writer is None
    • 0-d tensors sanitized to Python floats for Comet / MLFlow but TB receives the tensor untouched
    • Plain Python scalars pass through unchanged
    • Per-layer loop logs each layer individually with the right step
  • CI: existing test_moe_logging* tests continue to pass — those tests do not provide a comet_logger / mlflow_logger, so the helper returns the original writer unchanged (the existing assertions on writer.add_scalar are unaffected).
  • CI: L0 / L1 functional tests on H100 / GB200.
  • Manual: with mlflow_experiment or comet_experiment configured and a MoE model, observe load_balancing_loss (and friends) appearing in MLFlow / Comet at the same per-step values as TensorBoard.

Risk

Low.

  • No public API change.
  • _build_moe_metric_writer returns the original writer object when no fanout target is configured — bit-for-bit equivalent to the previous code path.
  • The wrapper only adds log_metrics(...) calls on the rank that already has those loggers configured (rank N-1 only).
  • W&B path is untouched — wandb_writer is still passed directly to MCore.
  • No change to total_loss_dict semantics.

Notes for reviewers

  • An alternative was forking the helper signature in MCore (comet_logger parameter), but the maintainer explicitly requested a Bridge-side workaround to avoid the cross-repo dependency.
  • The wrapper deliberately mimics the full add_scalar(name, value, iteration) shape rather than monkey-patching track_moe_metrics, so future MCore changes that add new metric emissions through the same writer interface get fanned out automatically.

MCore's `track_moe_metrics` and `track_mtp_metrics` only forward metrics
to TensorBoard and W&B. Users wiring up Comet (or MLFlow) never see MoE
auxiliary losses (load balancing, sequence aux loss, global aux loss,
z_loss) or MTP per-layer losses on those backends — see issue NVIDIA-NeMo#2989.

Per maintainer guidance ("Megatron-Bridge side can monkey-patch
track_moe_metrics to avoid a cross-repo dependency"), this change wraps
the TB writer with a small SummaryWriter-shaped adapter that fans out
every `add_scalar(name, value, iteration)` call to MLFlow and Comet
using the same per-step value. W&B is left untouched — the underlying
MCore functions still receive `wandb_writer` directly so their
dict-based per-layer logging stays unchanged.

When neither MLFlow nor Comet is configured, the helper returns the
real TB writer unchanged — zero overhead and no behavior change.

When at least one of MLFlow / Comet is configured, the wrapper is
returned even if TB itself is None. This is intentional: it surfaces
MoE / MTP metrics in Comet / MLFlow on rank N-1 even when the user
hasn't enabled TensorBoard.

Tensors are sanitized with `.item()` before being handed to MLFlow /
Comet (TB tolerates 0-d tensors; MLFlow / Comet do not). Per-layer
logging fans out one `add_scalar` per layer naturally.

Adds 8 unit tests covering: bypass when no fanout targets, wrapping
when only Comet or only MLFlow is present, fan-out across all sinks,
operation when TB is None, tensor sanitation, plain-scalar passthrough,
and per-layer loop fan-out.

Refs issue NVIDIA-NeMo#2989.

Signed-off-by: lonexreb <reach2shubhankar@gmail.com>

@cuichenx cuichenx left a comment

LGTM

cuichenx added the ready-to-merge label (PR is approved, current, and only waiting for CI to pass before merge) on May 5, 2026

cuichenx commented May 5, 2026

/ok to test 6de6750

svcnvidia-nemo-ci added the waiting-on-customer label (Waiting on the original author to respond) on May 5, 2026
cuichenx merged commit 05312d5 into NVIDIA-NeMo:main on May 5, 2026
59 of 61 checks passed
svcnvidia-nemo-ci removed the waiting-on-customer label (Waiting on the original author to respond) on May 6, 2026
gautham-kollu pushed a commit that referenced this pull request May 12, 2026
Signed-off-by: lonexreb <reach2shubhankar@gmail.com>
Co-authored-by: Chen Cui <chcui@nvidia.com>
vasunvidia pushed a commit to vasunvidia/Megatron-Bridge that referenced this pull request Jun 10, 2026
…NeMo#3647)

Signed-off-by: lonexreb <reach2shubhankar@gmail.com>
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Labels

community-request, ready-to-merge (PR is approved, current, and only waiting for CI to pass before merge)

Development

Successfully merging this pull request may close these issues.

track_moe_metrics() does not forward MoE metrics to Comet ML

3 participants