Chunked cross-entropy loss for SFT (up to –50% VRAM) #5575
Numerical equivalence:
| # | Config | Tests | Status | max \|Δloss\| | Δmem | Δtime | Trackio |
|---|---|---|---|---|---|---|---|
| 1 | Baseline (Qwen3-1.7B) | tied embeddings, bf16, default settings | ✅ PASS | 1.22e-01 | -33.0% | -2.6% | link |
| 2 | Gradient accumulation (`--gradient_accumulation_steps 8`) | `num_items_in_batch` plumbing across micro-batches | ✅ PASS | 1.32e-02 | -33.4% | -1.4% | link |
| 3 | Packing (`--packing`) | variable-length packed sequences | ✅ PASS | 3.22e-02 | -24.9% | +13.4% | link |
| 4 | No gradient checkpointing (`--no-gradient_checkpointing`) | chunked CE's per-chunk checkpoint without outer grad-ckpt | ✅ PASS | 1.22e-01 | -22.0% | -9.4% | link |
| 5 | Flash attention 2 (`--attn_implementation kernels-community/flash-attn2`) | attention kernel orthogonality | ✅ PASS | 5.11e-02 | -33.0% | -9.5% | link |
| 6 | Long context (`--max_length 8192 --batch_size 1`) | N_valid >> chunk_size | ✅ PASS | 1.33e-01 | +4.3% | +6.2% | link |
| 7 | Untied embeddings (microsoft/Phi-3-mini-4k-instruct) | separate lm_head weight | ✅ PASS | 8.21e-02 | -0.0% | -0.0% | link |
| 8 | DDP × 4 GPUs (`--gres=gpu:4`) | DistributedDataParallel wrapping, gradient all-reduce | ✅ PASS | 2.24e-02 | -28.5% | -23.2% | link |
| 9 | FSDP2 × 4 GPUs | sharded params + grads + optimizer | ✅ PASS | 9.70e-03 | -53.1% | +49.7% | link |
| 10 | FSDP2 × 4 GPUs, `--fsdp_reshard_after_forward false` | same as 9, but keeps the unwrapped lm_head resident | ✅ PASS | 7.80e-03 | -47.3% | -6.3% | link |
| 11 | fp32 (`--no-bf16`) | tight tolerance; catches any real correctness drift | ✅ PASS | 9.50e-03 | -30.6% | -5.7% | link |
| 12 | ALST / Ulysses SP × 4 GPUs (sp_size=2 × dp_shard_size=2) | shift_labels path under sequence parallelism | ✅ PASS | 1.50e-01 | -0.1% | -1.4% | link |
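For readers who want to reproduce the summary columns from their own pair of runs, here is a minimal sketch. It assumes (this is not stated in the PR) that max |Δloss| is the largest per-step absolute difference between the two loss curves, and that Δmem and Δtime are relative changes of the `chunked_nll` run against the `nll` run. The helper names are hypothetical.

```python
def max_abs_loss_diff(baseline_losses, chunked_losses):
    """Largest per-step |loss_nll - loss_chunked| (assumed definition of max |Δloss|)."""
    if len(baseline_losses) != len(chunked_losses):
        raise ValueError("both runs must log the same number of steps")
    return max(abs(a - b) for a, b in zip(baseline_losses, chunked_losses))


def relative_change_pct(baseline, candidate):
    """Relative change in percent; negative means the candidate run used less."""
    return 100.0 * (candidate - baseline) / baseline


# Hypothetical values, for illustration only (not taken from the benchmark).
print(max_abs_loss_diff([1.0, 0.5], [1.0, 0.75]))
print(relative_change_pct(20.0, 10.0))
```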
Note on row 9 (FSDP2 default): the +49.7% time cost comes from chunked CE triggering a per-chunk all-gather of lm_head.weight during backward — with chunk_size=256, that's ~32 redundant gathers per step. Setting --fsdp_reshard_after_forward false (row 10) keeps the (unwrapped) lm_head resident across forward/backward, eliminating the repeated gathers; memory savings drop slightly (-47% vs -53%) but time flips from +50% to ~neutral. SFTTrainer now emits a logger.warning when it detects loss_type='chunked_nll' + FSDP2 + reshard_after_forward=True so users don't hit this unknowingly.
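A back-of-the-envelope check of the "~32 gathers" figure, as a sketch. It assumes the benchmark script's defaults (`--batch_size 8`, `--max_length 1024`), that every position carries a valid label (an upper bound), and one `lm_head.weight` all-gather per chunk. The function name is hypothetical.

```python
import math


def lm_head_gathers_per_step(n_valid_tokens: int, chunk_size: int = 256) -> int:
    """Number of chunks, i.e. one lm_head.weight all-gather per chunk (assumption)."""
    return math.ceil(n_valid_tokens / chunk_size)


# 8 sequences x 1024 tokens, all assumed valid.
print(lm_head_gathers_per_step(8 * 1024))  # 32
```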
The loss curves all look mostly the same (almost perfect overlap); check out the Trackio links for detailed per-step loss comparisons.
benchmark_chunked_nll.py
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# /// script
# dependencies = ["trl", "trackio"]
# ///
"""
Train one SFT run with a given `loss_type` and report peak memory + wall time.
Invoke twice (once with `--loss_type nll`, once with `--loss_type chunked_nll`) using the same
`--benchmark_id` to produce a comparable pair. Both runs log to the same trackio project under
different run names (`{benchmark_id}-nll` vs `{benchmark_id}-chunked_nll`). Running each loss
type as a separate process avoids cross-contamination from residual CUDA allocations.
"""
import argparse
import os
import time
import torch
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer


def main() -> None:
    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument("--loss_type", choices=["nll", "chunked_nll"], required=True)
    parser.add_argument("--benchmark_id", required=True, help="Shared id for the nll + chunked_nll pair.")
    parser.add_argument("--model", default="Qwen/Qwen3-1.7B")
    parser.add_argument("--dataset", default="trl-lib/Capybara")
    parser.add_argument("--max_steps", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--max_length", type=int, default=1024)
    parser.add_argument("--bf16", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--gradient_checkpointing", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--packing", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--attn_implementation", default=None, help="e.g. 'sdpa', 'eager', 'flash_attention_2'")
    parser.add_argument("--output_dir", default="/tmp/chunked_nll_bench")
    args = parser.parse_args()

    dataset = load_dataset(args.dataset)
    # Keep only single-turn rows → prompt-completion format so `completion_only_loss` kicks in.
    dataset = dataset.filter(lambda ex: len(ex["messages"]) == 2)
    dataset = dataset.map(
        lambda ex: {"prompt": [ex["messages"][0]], "completion": [ex["messages"][1]]},
        remove_columns=["messages"],
    )

    training_args = SFTConfig(
        output_dir=f"{args.output_dir}/{args.benchmark_id}/{args.loss_type}",
        loss_type=args.loss_type,
        run_name=f"{args.benchmark_id}-{args.loss_type}",
        max_length=args.max_length,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        max_steps=args.max_steps,
        logging_steps=1,
        save_strategy="no",
        report_to="trackio",
        trackio_space_id=os.getenv("TRACKIO_SPACE_ID"),
        seed=1234,
        data_seed=1234,
        bf16=args.bf16,
        gradient_checkpointing=args.gradient_checkpointing,
        packing=args.packing,
        model_init_kwargs={"attn_implementation": args.attn_implementation} if args.attn_implementation else None,
    )
    trainer = SFTTrainer(model=args.model, args=training_args, train_dataset=dataset["train"])

    # Measure peak memory and wall time for the training loop only.
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.perf_counter()
    trainer.train()
    torch.cuda.synchronize()
    train_time = time.perf_counter() - start
    peak_gb = torch.cuda.max_memory_allocated() / 1024**3

    # In distributed runs, report the max peak across ranks and print from rank 0 only.
    if torch.distributed.is_initialized():
        t = torch.tensor(peak_gb, device="cuda")
        torch.distributed.all_reduce(t, op=torch.distributed.ReduceOp.MAX)
        peak_gb = t.item()
        if torch.distributed.get_rank() != 0:
            raise SystemExit(0)

    print(
        f"benchmark_id={args.benchmark_id} loss_type={args.loss_type}"
        f" peak={peak_gb:.2f} GB time={train_time:.2f} s"
    )


if __name__ == "__main__":
    main()

benchmark_chunked_nll.slurm (single-GPU / DDP)
#!/bin/bash
#SBATCH --job-name=chunked-nll-bench
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:1
#SBATCH --partition=hopper-prod
#SBATCH --time=2:00:00
#SBATCH --output=/fsx/qgallouedec/logs/%x-%j.out
#SBATCH --qos=normal
set -x
source ~/.bashrc
conda activate trl
echo "START TIME: $(date)"
cd /fsx/qgallouedec/trl
export TRACKIO_SPACE_ID="qgallouedec/chunked-nll-benchmark-2"
accelerate launch --main_process_port $((29500 + RANDOM % 1000)) benchmark_chunked_nll.py "$@"
echo "END TIME: $(date)"

benchmark_chunked_nll_fsdp.slurm (FSDP2)
#!/bin/bash
#SBATCH --job-name=chunked-nll-bench-fsdp
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:4
#SBATCH --partition=hopper-prod
#SBATCH --time=2:00:00
#SBATCH --output=/fsx/qgallouedec/logs/%x-%j.out
#SBATCH --qos=normal
set -x
source ~/.bashrc
conda activate trl
echo "START TIME: $(date)"
cd /fsx/qgallouedec/trl
export TRACKIO_SPACE_ID="qgallouedec/chunked-nll-benchmark-2"
accelerate launch \
--main_process_port $((29500 + RANDOM % 1000)) \
--mixed_precision bf16 \
--use_fsdp \
--fsdp_version 2 \
--fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP \
--fsdp_transformer_layer_cls_to_wrap Qwen3DecoderLayer \
benchmark_chunked_nll.py "$@"
echo "END TIME: $(date)"
lewtun left a comment
Very nice feature! Implementation-wise it looks good to me, but I'd like to see some small experiments done with:
- MoEs
- ZeRO-3 + Ulysses and FSDP2 + CP
to make sure we understand whether there's any material difference for these architectures and long-context plugins. They can be fixed in follow-up PRs, but in that case we should add a note in the docs that they are not currently supported.
self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
# Under FSDP2 with `reshard_after_forward=True` (accelerate's default), the chunked CE path triggers a
More of a question for accelerate, but is their default needed in general? If not, we could set reshard_after_forward=True in our example FSDP2 config too.
Accelerate's FSDP2 default is reshard_after_forward=True, applied uniformly to every fully_shard() call. Interestingly, it's overriding PyTorch's None default (which would leave the root unit's params, including lm_head, resident). Any idea why @SunMarc?
In my understanding, in general, True is fine: the backward pass re-gathers each unit's params exactly once, overlappable with compute. It only bites when the same unsharded params are used in multiple forward passes within a step, which is what chunked_nll does (one lm_head.weight all-gather per chunk).
I think this was mostly to copy the FSDP setup, but the FSDPv2 API is indeed a bit more flexible since you can call fully_shard directly on the modules you want. A quick fix, which will probably match what torchtitan does, is to just set reshard_after_forward=None when sharding the remaining modules that we have.
One caveat is that usually we should reshard the embedding unless it is tied with the lm_head. But for now, let's just not reshard it and we can deal with this case later if needed. The new FSDPv2 impl from @3outeille should have better defaults.
Way better, though it seems like reshard=False is still faster, which is expected in my opinion.
| Config | Δmem | Δtime |
|---|---|---|
| Row 9 — old default (reshard=True) | −53.1 % | +49.7 % |
| Row 10 — explicit reshard=False override | −47.3 % | −6.3 % |
| accelerate#4015 | −53.1 % | +16.4 % |
Yes, there is a tradeoff between memory and communication. We can still advise users to set it to False if they want faster FSDP.
# wraps the patched forward.
if self._is_vlm:
    raise NotImplementedError("`loss_type='chunked_nll'` is not supported for VLM models yet.")
if peft_config is not None or is_peft_model(model):
Nice to see this in trl. I don't understand why it doesn't work with PEFT though, is there anything we can do to enable it @qgallouedec?
Probably nothing is blocking. But the PR was already big, so I wanted to separate the concerns.
As with VLM support, I will dedicate a separate PR to PEFT x chunked loss.
In other words, it's not "not working with PEFT", it's just not tested, and I'll do it soon.
Ah nice, looking forward to it. After a first glance, I don't think anything should conflict with PEFT, but LMK if you find any issues.
- Note transformers #45433 (sonic-moe CuteDSL kernel integration)
- Highlight TRL-side contribution to #45621 (wrapper-side masked_fill pair in grouped_mm_experts_forward)
- Credit @qgallouedec on TRL #5575 (chunked CE) and explain why it is load-bearing for the long-context recipe

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
some questions:

(1) yes, chunked and fused

@qgallouedec Thanks for the quick response. Following up on a few things:

(1) In my own ablations, Liger shows a 14 GB improvement (144 GB vs 158 GB) on B200s (length: 4096, FA2 padding-free) over chunked_nll on Qwen3-8B. Maybe this warrants further examination so we can inform users more accurately about when chunked_nll does better and when it does not?

(3) Looking at the code now, it doesn't seem to do this: https://github.com/linkedin/Liger-Kernel/blob/3bb3b3fae6d0b2356116034a7f0ee1dde0ea71ea/src/liger_kernel/ops/fused_linear_cross_entropy.py#L99 But it's my first pass through the code, so I might have missed something.

What does this PR do?
Choosing the right chunk size -> 256
We care more about peak memory than wall time (the whole point of chunked CE is to fit workloads that wouldn't otherwise fit). So IMO "optimal" = the smallest `chunk_size` whose time stays within 50% of the sweep minimum.
Fortunately, the sweet spot is consistent across model size and sequence length: `chunk_size = 256` wins in all 5 configs tested.
Picking 512 or 1024 would trim wall time by only a few percent at the cost of meaningfully higher memory. Not worth it given the goal. Setting `_CHUNKED_LM_HEAD_CHUNK_SIZE = 256`.
Benchmarking script
eg:
Plotting script
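The selection rule stated above ("the smallest `chunk_size` whose time stays within 50% of the sweep minimum") can be sketched as follows. The function name and the sweep numbers are hypothetical and made up purely for illustration; they are not the PR's measurements. Picking the smallest eligible chunk size is used here as a proxy for lowest peak memory, which is an assumption.

```python
def pick_chunk_size(sweep, tolerance=0.5):
    """Return the smallest chunk size whose wall time is within `tolerance` of the fastest.

    sweep: dict mapping chunk_size -> wall time in seconds.
    """
    fastest = min(sweep.values())
    eligible = [cs for cs, t in sweep.items() if t <= fastest * (1 + tolerance)]
    return min(eligible)


# Made-up sweep, for illustration only.
hypothetical_sweep = {64: 200.0, 256: 110.0, 1024: 100.0}
print(pick_chunk_size(hypothetical_sweep))  # 256
```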
Note
Medium Risk
Adds a new `loss_type='chunked_nll'` path that monkey-patches `model.forward` and changes how loss/metrics are computed, which could affect training correctness or performance across model families and distributed configs. Guardrails and extensive numerical-equivalence tests reduce risk, but this touches the core SFT training loop.
Overview
Adds a new SFT loss mode, `loss_type="chunked_nll"`, that reduces peak activation memory by avoiding full `[batch×seq×vocab]` logits: ignored-label tokens are dropped before the `lm_head` matmul and the remaining tokens' cross-entropy is computed in checkpointed chunks (default chunk size `256`).
Implements this by patching the model's `forward` to run the decoder directly and return loss + aggregated `num_correct_tokens`/`entropy_sum` (and MoE `aux_loss` when `output_router_logits=True`), while falling back to the original forward when labels aren't provided to keep generation/eval behavior unchanged.
Updates `SFTConfig`/docs to document the new option and its constraints (not compatible with Liger, PEFT, or VLM; FSDP2 guidance), adds an FSDP2 performance warning, and introduces tests covering numerical equivalence (forward/backward), ignore-index edge cases, and end-to-end training with `chunked_nll`.
Reviewed by Cursor Bugbot for commit d87972c.
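To make the idea in the overview concrete, here is a minimal PyTorch sketch of chunked cross-entropy: drop ignored tokens before the `lm_head` matmul, then compute the loss over checkpointed chunks. This is an illustration under stated assumptions, not the PR's implementation: the function names are hypothetical, `hidden` is assumed to be already shifted and aligned with `labels`, and metrics such as `num_correct_tokens` or `entropy_sum` are omitted.

```python
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


def _chunk_nll_sum(hidden_chunk, weight, labels_chunk):
    # Logits exist only for this chunk: [chunk, vocab] instead of [batch*seq, vocab].
    logits = hidden_chunk @ weight.T
    return F.cross_entropy(logits.float(), labels_chunk, reduction="sum")


def chunked_nll(hidden, weight, labels, chunk_size=256, ignore_index=-100):
    """Mean NLL over non-ignored tokens, computed chunk by chunk.

    hidden: [batch, seq, dim] final hidden states (assumed aligned with `labels`)
    weight: [vocab, dim] lm_head weight
    labels: [batch, seq] target ids, `ignore_index` on masked positions
    """
    hidden = hidden.reshape(-1, hidden.size(-1))
    labels = labels.reshape(-1)
    keep = labels != ignore_index
    hidden, labels = hidden[keep], labels[keep]  # drop ignored tokens before the matmul
    total = hidden.new_zeros((), dtype=torch.float32)
    for start in range(0, labels.numel(), chunk_size):
        end = start + chunk_size
        # Checkpointing discards each chunk's logits after forward and recomputes them in backward.
        total = total + checkpoint(
            _chunk_nll_sum, hidden[start:end], weight, labels[start:end], use_reentrant=False
        )
    return total / max(labels.numel(), 1)
```

Because each chunk's logits are recomputed during backward, the `lm_head` weight is touched once per chunk in both passes, which is consistent with the per-chunk all-gather cost discussed for FSDP2 with `reshard_after_forward=True`.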