Add MFU helpers #5698
Add three pure helpers to trl/trainer/utils.py:
- compute_flops_per_token(config, seq_len): training FLOPs per token for a causal LM. Handles dense and MoE (Mixtral, Qwen3-MoE, DeepSeek-V2). Uses the non-causal attention convention (PaLM / Megatron / nanoGPT).
- compute_mfu(flops_per_token, tps, world_size, peak_flops): MFU as a percentage. Caller is responsible for correcting tps for cp/sp/tp over-counting.
- adjusted_mfu(mfu, config, seq_len): convert non-causal MFU to causal-corrected MFU (Llama / DS Ulysses convention).

No integration with SFTTrainer in this PR: these are standalone helpers usable from any training loop. A follow-up PR can wire them into SFTTrainer.log.
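As a rough illustration of the MFU calculation described above, here is a minimal sketch. It is not the PR's implementation: the function name `mfu_percent_sketch` is hypothetical, `tps` is assumed to be the global (whole-job) token rate already corrected for cp/sp/tp over-counting, and the H100 peak is an assumed dense bf16 figure of about 989 TFLOP/s per device.

```python
# Assumption: dense bf16 peak of one H100 SXM, roughly 989 TFLOP/s.
H100_BF16_PEAK_FLOPS = 989e12


def mfu_percent_sketch(
    flops_per_token: float,
    tps: float,
    world_size: int,
    peak_flops: float = H100_BF16_PEAK_FLOPS,
) -> float:
    """Hypothetical stand-in for compute_mfu; assumes tps is the global token rate."""
    achieved_flops = flops_per_token * tps     # FLOP/s the whole job actually performs
    available_flops = world_size * peak_flops  # FLOP/s the hardware could deliver at peak
    return 100.0 * achieved_flops / available_flops


# Example: a 7e9-parameter dense model at roughly 6 training FLOPs per parameter
# per token (attention term ignored), on 8 devices at 80k tokens/s overall.
mfu = mfu_percent_sketch(flops_per_token=6 * 7e9, tps=80_000, world_size=8)
print(f"{mfu:.1f}%")  # 42.5%
```

If `tps` were measured per device instead, the `world_size` factor would have to be dropped, which is why the caller-side correction matters.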
dense_mlp_flops = 2 * 3 * h * config.intermediate_size  # interspersed dense layers
sparse_step = config.decoder_sparse_step
total_layer_flops = sum(
    attn_flops + (moe_mlp_flops if layer_idx % sparse_step == 0 else dense_mlp_flops) for layer_idx in range(L)
MoE branch crashes for Mixtral configs
High Severity
The MoE branch unconditionally accesses config.moe_intermediate_size and config.decoder_sparse_step, but Mixtral (explicitly listed as supported in the docstring) has neither attribute. Mixtral uses intermediate_size for its expert FFN dimension and has all-MoE layers with no dense/sparse interleaving. This causes an AttributeError at runtime for any Mixtral config.
Reviewed by Cursor Bugbot for commit 2b92bab.
Is that correct, @AmineDiro?
I think this branch would benefit from some unit tests.
Yes, that's true. I originally had a TODO for Mixtral models; this MFU computation was specific to the Qwen family of models.
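One way the Mixtral case raised in this thread could be handled is to dispatch on `config.model_type` before touching Qwen-only fields. The sketch below is an assumption-laden illustration, not the fix that landed: `mlp_flops_per_layer` is a hypothetical helper, the MoE cost is assumed to be the dense 2 * 3 * h * d formula scaled by `num_experts_per_tok`, router FLOPs are ignored, and the configs are stand-ins built with `SimpleNamespace` rather than real Transformers configs.

```python
from types import SimpleNamespace


def mlp_flops_per_layer(config) -> list[int]:
    """Forward MLP FLOPs per token for each layer (hypothetical helper, not TRL code)."""
    h = config.hidden_size
    k = config.num_experts_per_tok
    if config.model_type == "mixtral":
        # Mixtral: every layer is MoE and the experts use `intermediate_size`.
        moe = 2 * 3 * h * config.intermediate_size * k
        return [moe] * config.num_hidden_layers
    # Qwen-style MoE: experts use `moe_intermediate_size`, dense layers are interleaved.
    moe = 2 * 3 * h * config.moe_intermediate_size * k
    dense = 2 * 3 * h * config.intermediate_size
    step = config.decoder_sparse_step
    # Same layer predicate as the diff under review.
    return [moe if i % step == 0 else dense for i in range(config.num_hidden_layers)]


# Tiny stand-in configs; real configs carry many more fields.
mixtral_cfg = SimpleNamespace(
    model_type="mixtral", hidden_size=8, intermediate_size=16,
    num_experts_per_tok=2, num_hidden_layers=3,
)
qwen_cfg = SimpleNamespace(
    model_type="qwen3_moe", hidden_size=8, intermediate_size=16,
    moe_intermediate_size=4, decoder_sparse_step=2,
    num_experts_per_tok=2, num_hidden_layers=4,
)

print(mlp_flops_per_layer(mixtral_cfg))  # [1536, 1536, 1536]
print(mlp_flops_per_layer(qwen_cfg))     # [384, 768, 384, 768]
```

Small synthetic configs like these also make natural fixtures for the unit tests requested above, since the expected values can be checked by hand.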
# MoE dispatch: `num_experts_per_tok` is the canonical MoE marker: present on Mixtral,
# Qwen3-MoE, DeepSeek-V2, etc.; absent on dense configs.
num_experts_per_tok = getattr(config, "num_experts_per_tok", None)
Usage of getattr violates project rules
Low Severity
getattr(config, "num_experts_per_tok", None) violates the AGENTS.md rule that explicitly says to avoid hasattr and getattr, describing their use as "almost always a symptom of overly defensive programming." The rule recommends expressing checks explicitly or dropping the conditional entirely.
Triggered by project rule: ../.ai/AGENTS.md
Reviewed by Cursor Bugbot for commit 2b92bab.
I think it's the only way here.
Check #5716.
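For readers unfamiliar with the dispatch being debated in this thread, the following small sketch shows how the `getattr` probe separates dense from MoE configs. The `is_moe` helper and the `SimpleNamespace` stand-ins are hypothetical and only mimic the relevant attributes.

```python
from types import SimpleNamespace


def is_moe(config) -> bool:
    # Same probe as in the diff: dense configs simply do not define the attribute.
    return getattr(config, "num_experts_per_tok", None) is not None


dense_cfg = SimpleNamespace(hidden_size=4096, intermediate_size=11008)
moe_cfg = SimpleNamespace(hidden_size=4096, intermediate_size=14336, num_experts_per_tok=2)

print(is_moe(dense_cfg), is_moe(moe_cfg))  # False True
```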



What does this PR do?
Add three pure helpers to trl/trainer/utils.py:
- compute_flops_per_token(config, seq_len): training FLOPs per token for a causal LM. Handles dense and MoE (Mixtral, Qwen3-MoE, DeepSeek-V2). Uses the non-causal attention convention (PaLM / Megatron / nanoGPT).
- compute_mfu(flops_per_token, tps, world_size, peak_flops): MFU as a percentage. Caller is responsible for correcting tps for cp/sp/tp over-counting.
- adjusted_mfu(mfu, config, seq_len): convert non-causal MFU to causal-corrected MFU (Llama / DS Ulysses convention).
NOTE: for now this defaults to the cluster H100 bf16 FLOPs for peak_flops_per_device. We'll probably add a lookup table of peak FLOPs by dtype and hardware for the general case.

AI writing disclosure
We welcome the use of AI tools to help with contributions. For transparency and to help us improve our review process, please indicate the level of AI involvement in this PR.
Note
Low Risk
Adds new pure math helpers and tests without changing training control flow; main risk is incorrect FLOPs/MFU estimation due to assumptions about model config fields or Transformers version differences.
Overview
Adds three new utility functions in trl/trainer/utils.py to estimate training FLOPs per token (compute_flops_per_token) for dense and MoE causal LMs, compute Model FLOPs Utilization (compute_mfu), and apply a causal-attention correction (adjusted_mfu).

Extends tests/test_utils.py with focused unit tests that validate scaling behavior, tied vs untied embeddings, MoE expert-count deltas, and MFU formula correctness.

Reviewed by Cursor Bugbot for commit 676a378.
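The causal-attention correction summarized in the overview can be sketched roughly as follows. This is an illustration under stated assumptions, not the PR's `adjusted_mfu`: it uses the PaLM-style estimate of 6N + 12·L·h·s training FLOPs per token for the non-causal convention, assumes the causal convention halves only the attention term, and takes explicit sizes instead of a config object. Both function names are hypothetical.

```python
def flops_per_token_sketch(n_params, num_layers, hidden_size, seq_len, causal=False):
    """PaLM-style estimate: 6N for the weights plus an attention term (assumed formula)."""
    attn = 12 * num_layers * hidden_size * seq_len
    if causal:
        attn //= 2  # a causal mask skips about half of the score matrix
    return 6 * n_params + attn


def causal_adjusted_mfu_sketch(mfu, n_params, num_layers, hidden_size, seq_len):
    """Rescale a non-causal MFU by the ratio of causal to non-causal FLOPs per token."""
    non_causal = flops_per_token_sketch(n_params, num_layers, hidden_size, seq_len)
    causal = flops_per_token_sketch(n_params, num_layers, hidden_size, seq_len, causal=True)
    return mfu * causal / non_causal


# Example with made-up sizes: the correction grows with sequence length because
# the attention share of total FLOPs grows.
short = causal_adjusted_mfu_sketch(40.0, 7_000_000_000, 32, 4096, 2048)
long = causal_adjusted_mfu_sketch(40.0, 7_000_000_000, 32, 4096, 32768)
print(short, long)
```

Under these assumptions the adjusted value is always at most the non-causal MFU, and the gap widens at long context lengths.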