Trainer.compute_loss: fix loss over-counting under TP and EP-as-TP #45994

Merged
vasqu merged 2 commits into huggingface:main from AmineDiro:fix-fsdp-ep-loss-scale on May 26, 2026

@AmineDiro

Member

What does this PR do?

With DP + TP, or with the DP + EP configuration set by the FSDP+EP branch in `_build_accelerator_args`, the same batch is replicated across `tp_size` ranks. The model's per-rank loss is already `per_rank_token_sum / global_num_items_in_batch`, so multiplying by the full `num_processes` over-counts by a factor of `tp_size`.
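To make the arithmetic concrete, here is a minimal pure-Python sketch; it is not the Trainer code. It assumes the logged loss is the mean of the scaled per-rank losses over all ranks, and `reported_loss` with its arguments are hypothetical names chosen for illustration.

```python
import math


def reported_loss(token_losses_per_dp_group, tp_size, loss_scale):
    """Mean over all ranks of the scaled per-rank loss (assumed aggregation)."""
    dp_size = len(token_losses_per_dp_group)
    # Unique tokens only: each DP group contributes its batch once.
    global_num_items = sum(len(group) for group in token_losses_per_dp_group)

    scaled_per_rank = []
    for group in token_losses_per_dp_group:
        # Per-rank loss is per_rank_token_sum / global_num_items_in_batch.
        rank_loss = sum(group) / global_num_items
        # Every TP rank in this DP group sees the same (replicated) batch.
        scaled_per_rank.extend([rank_loss * loss_scale] * tp_size)

    return sum(scaled_per_rank) / (dp_size * tp_size)


dp_size, tp_size = 2, 4
num_processes = dp_size * tp_size
per_token = math.log(151936)  # loss of a uniform prediction over the vocabulary
batches = [[per_token] * 16 for _ in range(dp_size)]

# Scaling by every process counts each replicated batch tp_size times.
pre_fix = reported_loss(batches, tp_size, loss_scale=num_processes)
# Scaling only by the data-parallel degree recovers the true mean loss.
post_fix = reported_loss(batches, tp_size, loss_scale=num_processes // tp_size)

print(f"pre-fix: {pre_fix:.2f}, post-fix: {post_fix:.2f}")
```

Under these assumptions the pre-fix value is `tp_size` times the post-fix value, which is the inflation described above.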

Test

  • Model: Random-init Qwen3-MoE (4L, 8E, Hidden=256)
  • Hardware: 1 node × 8 H100
  • Hyperparameters: Context=2k, LR=0, Seed=42
  • Expected Loss: $\log(151936) \approx 11.93$
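
| Row | Backend | DP × EP | Pre-fix | Post-fix | Job |
| --- | --- | --- | --- | --- | --- |
| A | fsdp2 | 8 × 1 | 11.97 | 11.97 | 22153595 |
| B | fsdp2 | 2 × 4 | 47.88 | 11.97 | 22153596 |
| C | DS-Z3 | 8 × 1 | 11.97 | 11.97 | 22152580 |
| D | DS-Z2 | 8 × 1 | 11.97 | 11.97 | 22152581 |
| E | DS-Z2 | 1 × 8 | 11.97 | 11.97 | 22152578 |
| F | DS-Z2 | 2 × 4 | 11.97 | 11.97 | 22153597 |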
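
The Pre-fix and Post-fix columns can be cross-checked with a few lines of Python. This is only a sanity check on the reported numbers; the `rows` dictionary simply transcribes the table, and the variable names are illustrative.

```python
import math

# Expected loss for a random-init model: ln(vocab_size).
expected = math.log(151936)

# (backend, DP, EP, pre-fix loss, post-fix loss), copied from the table.
rows = {
    "A": ("fsdp2", 8, 1, 11.97, 11.97),
    "B": ("fsdp2", 2, 4, 47.88, 11.97),
    "C": ("DS-Z3", 8, 1, 11.97, 11.97),
    "D": ("DS-Z2", 8, 1, 11.97, 11.97),
    "E": ("DS-Z2", 1, 8, 11.97, 11.97),
    "F": ("DS-Z2", 2, 4, 11.97, 11.97),
}

# Ratio of pre-fix to post-fix loss; 1.0 means the row was never inflated.
ratios = {name: pre / post for name, (_, _, _, pre, post) in rows.items()}
inflated = [name for name, ratio in ratios.items() if not math.isclose(ratio, 1.0)]

for name in inflated:
    _, dp, ep, _, _ = rows[name]
    print(f"row {name}: pre-fix / post-fix = {ratios[name]:.2f}, EP size = {ep}")
print(f"ln(151936) = {expected:.2f}")
```

Only row B (fsdp2, 2 × 4) differs between the two columns, and its pre-fix loss is 4 times the post-fix loss, matching its EP size.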

Code Agent Policy

  • I confirm that this is not a pure code agent PR.

Who can review?

@3outeille @ArthurZucker

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Rocketknight1

Member

cc @SunMarc maybe, not sure if it's code agent slop!

@vasqu

vasqu commented May 15, 2026

Contributor

Definitely not code agent slop! There is currently a lot of work going on re FSDP+TP+EP

@AmineDiro

Member Author

> cc @SunMarc maybe, not sure if it's code agent slop!

@Rocketknight1 haha maybe the PR desc made it look like AI slop, we can't have nice PR descs anymore 🤣

This is a real issue, for the reported loss at least: it gets inflated by a factor of `tp_size`.

@vasqu

vasqu commented May 25, 2026

Contributor

I think we can merge @AmineDiro? Just please sync first and sanity check before 🫡

@AmineDiro

Member Author

@vasqu Updated and ready!

@vasqu vasqu added this pull request to the merge queue May 26, 2026
@github-merge-queue github-merge-queue Bot removed this pull request from the merge queue due to failed status checks May 26, 2026
@vasqu vasqu added this pull request to the merge queue May 26, 2026
@vasqu

vasqu commented May 26, 2026

Contributor

Neat, merging!

@github-merge-queue github-merge-queue Bot removed this pull request from the merge queue due to no response for status checks May 26, 2026
@vasqu vasqu merged commit b131ec1 into huggingface:main May 26, 2026
48 of 95 checks passed
@vasqu

vasqu commented May 26, 2026

Contributor

Force merged, because GitHub Actions is not working properly. Multiple runs already showed that any failures were only flaky tests.

@ArthurZucker ArthurZucker left a comment

Collaborator

ty, seems like a test is welcome no?

# TP and EP-as-TP ranks see replicated batches; `num_processes` over-counts
# them by `tp_size`. Mirror the divisor used in `_get_num_items_in_batch`.
loss_scale = self.accelerator.num_processes
if (pc := getattr(self.accelerator, "parallelism_config", None)) is not None:

Collaborator

Can you give it a more meaningful name than `pc`, please?

yuchenxie4645 pushed a commit to yuchenxie4645/transformers that referenced this pull request May 28, 2026
kashif pushed a commit to kashif/transformers that referenced this pull request Jun 1, 2026
khushali9 pushed a commit to khushali9/transformers that referenced this pull request Jun 8, 2026