Trainer.compute_loss: fix loss over-counting under TP and EP-as-TP#45994
Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
|
cc @SunMarc maybe, not sure if it's code agent slop! |
|
Definitely not code agent slop! There is currently a lot of work going on re FSDP+TP+EP |
This is a real issue for the reported loss at least, it gets inflated by tp_sizex factor |
|
I think we can merge @AmineDiro? Just please sync first and sanity check before 🫡 |
|
@vasqu Updated and ready ! |
|
Neat, merging! |
|
Force merged, because gh actions are not working properly. Multiple runs already showed that those were only flaky tests (if any failed) |
ArthurZucker
left a comment
There was a problem hiding this comment.
ty, seems like a test is welcome no?
| # TP and EP-as-TP ranks see replicated batches; `num_processes` over-counts | ||
| # them by `tp_size`. Mirror the divisor used in `_get_num_items_in_batch`. | ||
| loss_scale = self.accelerator.num_processes | ||
| if (pc := getattr(self.accelerator, "parallelism_config", None)) is not None: |
There was a problem hiding this comment.
can you give a more meaningful name than PC please
What does this PR do?
When using DP + TP or DP+ EP set by the FSDP+EP branch in
_build_accelerator_argsreplicates the same batch acrosstp_sizeranks, the model's per-rank loss is alreadyper_rank_token_sum / global_num_items_in_batch; multiplying by the fullnum_processesover-counts bytp_size.Test
Code Agent Policy
Who can review?
@3outeille @ArthurZucker