Chunked cross-entropy loss for SFT (up to –50% VRAM) #5575

Merged: qgallouedec merged 32 commits into main from chunked_ce on Apr 27, 2026

@qgallouedec commented Apr 17, 2026
Member

What does this PR do?

[Image: sft_memory (1)]

Choosing the right chunk size -> 256

We care more about peak memory than wall time (the whole point of chunked CE is to fit workloads that wouldn't otherwise fit). So IMO "optimal" = the smallest chunk_size whose time stays within 50% of the sweep minimum.

[Image: chunk_size_pareto]

Fortunately, the sweet spot is consistent across model size and sequence length: chunk_size = 256 wins in all 5 configs tested.

| shape | ref memory | knee memory | ref time | knee time |
|---|---|---|---|---|
| small model | 8.29 GB | 2.07 GB (-×4.0) | 82 ms | 151 ms (×1.84) |
| small model, long sequence | 30.25 GB | 2.14 GB (-×14.1) | 323 ms | 604 ms (×1.87) |
| medium model | 9.70 GB | 4.84 GB (-×2.0) | 193 ms | 337 ms (×1.74) |
| big model | 12.04 GB | 9.47 GB (-×1.3) | 376 ms | 658 ms (×1.75) |
| big model, long sequence | 34.29 GB | 9.84 GB (-×3.5) | 1456 ms | 2642 ms (×1.81) |

Picking 512 or 1024 would trim wall time by only a few percent at the cost of meaningfully higher memory. Not worth it given the goal. Setting _CHUNKED_LM_HEAD_CHUNK_SIZE = 256.
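
A rough back-of-the-envelope sketch of why chunk size drives peak memory: the dominant transient tensor is the `[rows, vocab]` logits block, so its size grows linearly with the number of rows materialized at once. The helper name `logits_bytes` and the 4-bytes-per-element (fp32 logits) figure are assumptions made for illustration, not details taken from the TRL implementation; the vocabulary size is the one used in the benchmarking script.

```python
VOCAB_SIZE = 151936  # same vocabulary size as the benchmarking script


def logits_bytes(n_rows, vocab_size=VOCAB_SIZE, bytes_per_element=4):
    """Bytes needed to hold one [n_rows, vocab_size] logits tensor.

    bytes_per_element=4 assumes fp32 logits (an assumption for this sketch).
    """
    return n_rows * vocab_size * bytes_per_element


def logits_gib(n_rows, **kwargs):
    """Same quantity expressed in GiB."""
    return logits_bytes(n_rows, **kwargs) / 1024**3


n_valid = 4096  # hypothetical number of non-ignored positions
print(f"un-chunked ({n_valid} rows): {logits_gib(n_valid):.3f} GiB")
for chunk_size in (256, 512, 1024):
    print(f"chunk_size={chunk_size}: {logits_gib(chunk_size):.3f} GiB")
```

Under these assumptions a 256-row chunk holds one sixteenth of the logits of an un-chunked 4096-row pass. Real peak memory also includes weights, gradients, and softmax intermediates, which this sketch ignores.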

Benchmarking script

e.g.:

python benchmark_chunked_nll_chunk_size.py \
    --hidden_size 1024 --n_valid 4096 --output /tmp/bench.json
# benchmark_chunked_nll_chunk_size.py
"""
Micro-benchmark `_chunked_cross_entropy_loss` across chunk sizes.

Runs forward + backward on synthetic `(hidden, weight, labels)` tensors and reports peak GPU
memory + wall time. The un-chunked baseline (full `[N_valid, V]` logits) is also timed as an
upper bound.
"""

import argparse
import json
import time

import torch
import torch.nn.functional as F

from trl.trainer.sft_trainer import _chunked_cross_entropy_loss


VOCAB_SIZE = 151936  # Qwen2.5 / Qwen3 family
CHUNK_SIZES = [64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384]
ITERS = 10
DTYPE = torch.bfloat16


def make_inputs(n_valid, hidden_size):
    # Shift means `n_valid + 1` tokens yield exactly `n_valid` valid (non-ignored) positions.
    hidden = torch.randn(1, n_valid + 1, hidden_size, device="cuda", dtype=DTYPE, requires_grad=True)
    weight = torch.randn(VOCAB_SIZE, hidden_size, device="cuda", dtype=DTYPE, requires_grad=True)
    labels = torch.randint(0, VOCAB_SIZE, (1, n_valid + 1), device="cuda")
    return hidden, weight, labels


def bench(fn, n_valid, hidden_size):
    hidden, weight, labels = make_inputs(n_valid, hidden_size)
    try:
        fn(hidden, weight, labels).backward()  # warmup
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
        start = time.perf_counter()
        for _ in range(ITERS):
            hidden.grad = weight.grad = None
            fn(hidden, weight, labels).backward()
        torch.cuda.synchronize()
        # Report MiB (bytes / 1024**2) so that the plotting script's `/ 1024` conversion yields GiB.
        return torch.cuda.max_memory_allocated() / 1024**2, (time.perf_counter() - start) * 1000 / ITERS
    except torch.cuda.OutOfMemoryError:
        return None, None
    finally:
        del hidden, weight, labels
        torch.cuda.empty_cache()


def reference_loss(h, w, lbl):
    """What the standard nll path does: materialize full logits."""
    shift_h = h[..., :-1, :].reshape(-1, h.size(-1))
    shift_l = lbl[..., 1:].reshape(-1)
    return F.cross_entropy(shift_h.float() @ w.float().t(), shift_l, ignore_index=-100)


def chunked_loss(chunk_size):
    return lambda h, w, lbl: _chunked_cross_entropy_loss(h, w, lbl, chunk_size=chunk_size)[0]


def main():
    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument("--hidden_size", type=int, required=True)
    parser.add_argument("--n_valid", type=int, required=True)
    parser.add_argument("--output", type=str, default=None)
    args = parser.parse_args()

    torch.manual_seed(0)
    print(f"N_valid={args.n_valid}, vocab={VOCAB_SIZE}, hidden={args.hidden_size}, dtype=bfloat16, iters={ITERS}")
    print(f"{'chunk_size':>12} {'peak_MB':>10} {'time_ms':>10}")

    ref = None
    peak, ms = bench(reference_loss, args.n_valid, args.hidden_size)
    if peak is None:
        print(f"{'reference':>12} {'OOM':>10} {'—':>10}")
    else:
        print(f"{'reference':>12} {peak:>10.1f} {ms:>10.2f}")
        ref = {"peak_mb": peak, "time_ms": ms}

    results = []
    for cs in CHUNK_SIZES:
        if cs > args.n_valid:
            continue
        peak, ms = bench(chunked_loss(cs), args.n_valid, args.hidden_size)
        if peak is None:
            print(f"{cs:>12} {'OOM':>10} {'—':>10}")
            results.append({"chunk_size": cs, "oom": True})
        else:
            print(f"{cs:>12} {peak:>10.1f} {ms:>10.2f}")
            results.append({"chunk_size": cs, "peak_mb": peak, "time_ms": ms})

    if args.output:
        shape = {"vocab_size": VOCAB_SIZE, "hidden_size": args.hidden_size, "n_valid": args.n_valid, "dtype": "bfloat16"}
        payload = {"shape": shape, "iters": ITERS, "reference": ref, "results": results}
        with open(args.output, "w") as f:
            json.dump(payload, f, indent=2)
        print(f"\nwrote {args.output}")


if __name__ == "__main__":
    main()
Plotting script
# /// script
# dependencies = ["matplotlib", "numpy"]
# ///
import argparse
import json
import pathlib

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D


def find_knee(chunk_sizes, time_ms, tol=0.50):
    """Smallest chunk_size whose time is within `tol` of the minimum (memory-priority)."""
    threshold = time_ms.min() * (1 + tol)
    for cs, t in zip(chunk_sizes, time_ms):
        if t <= threshold:
            return int(cs), float(t)
    return int(chunk_sizes[-1]), float(time_ms[-1])


def describe(shape, all_shapes):
    """Human-readable label with 'small/medium/big model' and 'long sequence' tags."""
    h, n = shape["hidden_size"], shape["n_valid"]
    hs = {s["hidden_size"] for s in all_shapes}
    ns = {s["n_valid"] for s in all_shapes}
    tags = []
    if len(hs) > 1:
        tags.append({max(hs): "big model", min(hs): "small model"}.get(h, "medium model"))
    if len(ns) > 1 and n == max(ns):
        tags.append("long sequence")
    return f"H={h}, N_valid={n}" + (f"  ({', '.join(tags)})" if tags else "")


def load_runs(dir_path):
    datas = []
    for p in sorted(pathlib.Path(dir_path).glob("*.json")):
        data = json.loads(p.read_text())
        rows = [r for r in data["results"] if not r.get("oom")]
        if rows:
            datas.append((data, rows))
    datas.sort(key=lambda d: (d[0]["shape"]["hidden_size"], d[0]["shape"]["n_valid"]))
    shapes = [d[0]["shape"] for d in datas]
    return [
        {
            "label": describe(data["shape"], shapes),
            "reference": data.get("reference"),
            "chunk_sizes": np.array([r["chunk_size"] for r in rows]),
            "peak_gb": np.array([r["peak_mb"] for r in rows]) / 1024,
            "time_ms": np.array([r["time_ms"] for r in rows]),
        }
        for data, rows in datas
    ]


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dir", default="/tmp/chunk_size_bench")
    parser.add_argument("--out", default="chunk_size_pareto.png")
    args = parser.parse_args()

    runs = load_runs(args.dir)
    if not runs:
        raise SystemExit(f"No JSON files found in {args.dir}")

    colors = [plt.get_cmap("tab10")(i) for i in range(len(runs))]
    fig, ax = plt.subplots(figsize=(9.5, 7))
    fig.suptitle("Chunked CE: memory ↔ time trade-off", fontsize=14)

    for r, c in zip(runs, colors):
        ax.plot(r["time_ms"], r["peak_gb"], "-o", color=c, markersize=6, linewidth=2, label=r["label"])
        for cs, mem, t in zip(r["chunk_sizes"], r["peak_gb"], r["time_ms"]):
            ax.annotate(str(int(cs)), (t, mem), textcoords="offset points", xytext=(6, 4), fontsize=8, color=c)

        knee_cs, knee_t = find_knee(r["chunk_sizes"], r["time_ms"])
        knee_mem = r["peak_gb"][r["chunk_sizes"] == knee_cs][0]
        ax.scatter([knee_t], [knee_mem], color=c, marker="*", s=320, edgecolor="black", linewidth=1.2, zorder=6)

        if r["reference"] is not None:
            ref_t, ref_mem = r["reference"]["time_ms"], r["reference"]["peak_mb"] / 1024
            ax.scatter([ref_t], [ref_mem], color=c, marker="s", s=100, edgecolor="black", linewidth=1, zorder=5)
            ax.annotate("ref", (ref_t, ref_mem), textcoords="offset points", xytext=(8, -3),
                        fontsize=9, fontweight="bold", color=c)

    ax.set(xscale="log", yscale="log", xlabel="time per step (ms)", ylabel="peak memory (GB)")
    ax.set_yscale("log", base=2)
    ax.grid(True, which="both", alpha=0.3)

    shape_handles = [Line2D([0], [0], marker="o", linestyle="-", color=c, label=r["label"])
                     for r, c in zip(runs, colors)]
    glyph_handles = [
        Line2D([0], [0], marker="*", color="gray", linestyle="", markeredgecolor="black", markersize=14,
               label="knee (within 50% of min time)"),
        Line2D([0], [0], marker="s", color="gray", linestyle="", markeredgecolor="black", markersize=9,
               label="un-chunked reference"),
    ]
    ax.add_artist(ax.legend(handles=shape_handles, loc="upper left", fontsize=10, title="shape"))
    ax.legend(handles=glyph_handles, loc="upper right", fontsize=10, title="markers")

    print("\nknee per shape:")
    for r in runs:
        knee_cs, knee_t = find_knee(r["chunk_sizes"], r["time_ms"])
        knee_mem = r["peak_gb"][r["chunk_sizes"] == knee_cs][0]
        print(f"  {r['label']}  →  chunk_size = {knee_cs}  ({knee_mem:.2f} GB, {knee_t:.1f} ms)")

    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.savefig(args.out, dpi=140, bbox_inches="tight")
    print(f"\nwrote {args.out}")


if __name__ == "__main__":
    main()

Note

Medium Risk
Adds a new loss_type='chunked_nll' path that monkey-patches model.forward and changes how loss/metrics are computed, which could affect training correctness or performance across model families and distributed configs. Guardrails and extensive numerical-equivalence tests reduce risk, but this touches the core SFT training loop.

Overview
Adds a new SFT loss mode, loss_type="chunked_nll", that reduces peak activation memory by avoiding full [batch×seq×vocab] logits: ignored-label tokens are dropped before the lm_head matmul and the remaining tokens’ cross-entropy is computed in checkpointed chunks (default chunk size 256).

Implements this by patching the model’s forward to run the decoder directly and return loss + aggregated num_correct_tokens/entropy_sum (and MoE aux_loss when output_router_logits=True), while falling back to the original forward when labels aren’t provided to keep generation/eval behavior unchanged.

Updates SFTConfig/docs to document the new option and its constraints (not compatible with Liger, PEFT, or VLM; FSDP2 guidance), adds an FSDP2 performance warning, and introduces tests covering numerical equivalence (forward/backward), ignore-index edge cases, and end-to-end training with chunked_nll.
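
The core idea in the overview (drop ignored positions before the `lm_head` matmul, then compute cross-entropy chunk by chunk under activation checkpointing) can be sketched as below. This is a minimal illustration, not TRL's actual `_chunked_cross_entropy_loss`: the names `chunked_ce_sketch` and `_chunk_nll_sum` are hypothetical, and the fp32 upcast of the logits is an assumption borrowed from the reference path in the benchmarking script.

```python
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


def _chunk_nll_sum(hidden_chunk, weight, label_chunk):
    # Logits exist only for this chunk: shape [chunk, vocab].
    logits = hidden_chunk.float() @ weight.float().t()
    return F.cross_entropy(logits, label_chunk, reduction="sum")


def chunked_ce_sketch(hidden, weight, labels, chunk_size=256, ignore_index=-100):
    # Shift so that position t predicts token t + 1, then flatten.
    h = hidden[..., :-1, :].reshape(-1, hidden.size(-1))
    y = labels[..., 1:].reshape(-1)

    # Drop ignored positions before the lm_head matmul.
    keep = y != ignore_index
    h, y = h[keep], y[keep]
    n_valid = y.numel()
    if n_valid == 0:
        return h.sum() * 0.0  # zero loss that stays connected to the graph

    total = h.new_zeros((), dtype=torch.float32)
    for start in range(0, n_valid, chunk_size):
        end = start + chunk_size
        # Checkpointing discards the chunk's logits after forward and
        # recomputes them during backward, so only one chunk is alive at a time.
        total = total + checkpoint(
            _chunk_nll_sum, h[start:end], weight, y[start:end], use_reentrant=False
        )
    return total / n_valid  # mean over non-ignored tokens
```

On small fp32 inputs this should agree with `F.cross_entropy(..., ignore_index=-100)` on the full logits up to floating-point rounding; the sketch omits the extra metrics (`num_correct_tokens`, `entropy_sum`) and the MoE auxiliary loss mentioned in the overview.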

Reviewed by Cursor Bugbot for commit d87972c. Bugbot is set up for automated code reviews on this repo.

@qgallouedec commented Apr 20, 2026
Member Author

Numerical equivalence: chunked_nll vs nll under realistic SFT configs

End-to-end SFT training (bf16), two sequential runs from the same seed, one with loss_type="nll" and one with loss_type="chunked_nll", comparing per-step losses over 100 steps.

At 100 steps on Capybara + Qwen3-1.7B, the bf16 noise floor produces occasional outlier batches where the two paths' accumulated weight drift manifests as ~1e-1 loss spikes (the median step diff is ~2e-3; max is dominated by 1–2 outliers). Tolerance of 5e-2 catches the median behavior; the max excursion is ~1.2e-1 for baseline-like configs and is not specific to chunked_nll.
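
To make the median-versus-max distinction concrete, here is a small sketch of how two per-step loss curves could be summarized. The function name `compare_loss_curves` and the sample losses are hypothetical and are not taken from the benchmark runs; only the 5e-2 tolerance comes from the text above.

```python
import statistics


def compare_loss_curves(losses_a, losses_b, tol=5e-2):
    """Summarize the per-step absolute loss difference between two runs."""
    if len(losses_a) != len(losses_b):
        raise ValueError("both runs must log the same number of steps")
    diffs = [abs(a - b) for a, b in zip(losses_a, losses_b)]
    return {
        "median": statistics.median(diffs),
        "max": max(diffs),
        "steps_over_tol": [i for i, d in enumerate(diffs) if d > tol],
    }


# Hypothetical per-step losses, made up for illustration.
nll = [1.000, 0.900, 0.800, 0.700]
chunked = [1.001, 0.902, 0.920, 0.699]
summary = compare_loss_curves(nll, chunked)
print(summary)
```

A single outlier step dominates the maximum while the median stays small, which is the pattern described for the bf16 runs.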

Status

| # | Config | Tests | Status | max \|Δloss\| | Δmem | Δtime | Trackio |
|---|---|---|---|---|---|---|---|
| 1 | Baseline (Qwen3-1.7B) | tied embeddings, bf16, default settings | ✅ PASS | 1.22e-01 | -33.0% | -2.6% | link |
| 2 | Gradient accumulation (--gradient_accumulation_steps 8) | num_items_in_batch plumbing across micro-batches | ✅ PASS | 1.32e-02 | -33.4% | -1.4% | link |
| 3 | Packing (--packing) | variable-length packed sequences | ✅ PASS | 3.22e-02 | -24.9% | +13.4% | link |
| 4 | No gradient checkpointing (--no-gradient_checkpointing) | chunked CE's per-chunk checkpoint without outer grad-ckpt | ✅ PASS | 1.22e-01 | -22.0% | -9.4% | link |
| 5 | Flash attention 2 (--attn_implementation kernels-community/flash-attn2) | attention kernel orthogonality | ✅ PASS | 5.11e-02 | -33.0% | -9.5% | link |
| 6 | Long context (--max_length 8192 --batch_size 1) | N_valid >> chunk_size | ✅ PASS | 1.33e-01 | +4.3% | +6.2% | link |
| 7 | Untied embeddings (microsoft/Phi-3-mini-4k-instruct) | separate lm_head weight | ✅ PASS | 8.21e-02 | -0.0% | -0.0% | link |
| 8 | DDP × 4 GPUs (--gres=gpu:4) | DistributedDataParallel wrapping, gradient all-reduce | ✅ PASS | 2.24e-02 | -28.5% | -23.2% | link |
| 9 | FSDP2 × 4 GPUs | sharded params + grads + optimizer | ✅ PASS | 9.70e-03 | -53.1% | +49.7% | link |
| 10 | FSDP2 × 4 GPUs, --fsdp_reshard_after_forward false | same as 9, but keeps the un-wrapped lm_head resident | ✅ PASS | 7.80e-03 | -47.3% | -6.3% | link |
| 11 | fp32 (--no-bf16) | tight tolerance: catches any real correctness drift | ✅ PASS | 9.50e-03 | -30.6% | -5.7% | link |
| 12 | ALST / Ulysses SP × 4 GPUs (sp_size=2 × dp_shard_size=2) | shift_labels path under sequence parallelism | ✅ PASS | 1.50e-01 | -0.1% | -1.4% | link |

Note on row 9 (FSDP2 default): the +49.7% time cost comes from chunked CE triggering a per-chunk all-gather of lm_head.weight during backward — with chunk_size=256, that's ~32 redundant gathers per step. Setting --fsdp_reshard_after_forward false (row 10) keeps the (unwrapped) lm_head resident across forward/backward, eliminating the repeated gathers; memory savings drop slightly (-47% vs -53%) but time flips from +50% to ~neutral. SFTTrainer now emits a logger.warning when it detects loss_type='chunked_nll' + FSDP2 + reshard_after_forward=True so users don't hit this unknowingly.
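
The ~32 figure is consistent with a simple chunk count. The sketch below assumes about 8192 non-ignored tokens per device per step (for example, the benchmark defaults of batch size 8 and `max_length` 1024 with every position contributing to the loss); that token count is an assumption, not a number reported in this thread, and `num_chunks` is a hypothetical helper.

```python
import math


def num_chunks(n_valid, chunk_size=256):
    """Number of lm_head chunks, i.e. one weight gather per chunk in the scenario above."""
    return math.ceil(n_valid / chunk_size)


# Assumed token count: 8 sequences x 1024 positions, all non-ignored.
print(num_chunks(8 * 1024))
```

With completion-only loss, fewer positions are non-ignored, so the actual number of chunks per step would be lower than this upper estimate.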

They all look mostly like this (almost perfect overlap); check out the Trackio links for detailed per-step loss comparisons:

[Image: Screenshot 2026-04-20 at 11 16 34 PM]

benchmark_chunked_nll.py
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0

# /// script
# dependencies = ["trl", "trackio"]
# ///

"""
Train one SFT run with a given `loss_type` and report peak memory + wall time.

Invoke twice (once with `--loss_type nll`, once with `--loss_type chunked_nll`) using the same
`--benchmark_id` to produce a comparable pair. Both runs log to the same trackio project under
different run names (`{benchmark_id}-nll` vs `{benchmark_id}-chunked_nll`). Running each loss
type as a separate process avoids cross-contamination from residual CUDA allocations.
"""

import argparse
import os
import time

import torch
from datasets import load_dataset

from trl import SFTConfig, SFTTrainer


def main() -> None:
    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument("--loss_type", choices=["nll", "chunked_nll"], required=True)
    parser.add_argument("--benchmark_id", required=True, help="Shared id for the nll + chunked_nll pair.")
    parser.add_argument("--model", default="Qwen/Qwen3-1.7B")
    parser.add_argument("--dataset", default="trl-lib/Capybara")
    parser.add_argument("--max_steps", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--max_length", type=int, default=1024)
    parser.add_argument("--bf16", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--gradient_checkpointing", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--packing", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--attn_implementation", default=None, help="e.g. 'sdpa', 'eager', 'flash_attention_2'")
    parser.add_argument("--output_dir", default="/tmp/chunked_nll_bench")
    args = parser.parse_args()

    dataset = load_dataset(args.dataset)
    # Keep only single-turn rows → prompt-completion format so `completion_only_loss` kicks in.
    dataset = dataset.filter(lambda ex: len(ex["messages"]) == 2)
    dataset = dataset.map(
        lambda ex: {"prompt": [ex["messages"][0]], "completion": [ex["messages"][1]]},
        remove_columns=["messages"],
    )

    training_args = SFTConfig(
        output_dir=f"{args.output_dir}/{args.benchmark_id}/{args.loss_type}",
        loss_type=args.loss_type,
        run_name=f"{args.benchmark_id}-{args.loss_type}",
        max_length=args.max_length,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        max_steps=args.max_steps,
        logging_steps=1,
        save_strategy="no",
        report_to="trackio",
        trackio_space_id=os.getenv("TRACKIO_SPACE_ID"),
        seed=1234,
        data_seed=1234,
        bf16=args.bf16,
        gradient_checkpointing=args.gradient_checkpointing,
        packing=args.packing,
        model_init_kwargs={"attn_implementation": args.attn_implementation} if args.attn_implementation else None,
    )

    trainer = SFTTrainer(model=args.model, args=training_args, train_dataset=dataset["train"])

    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.perf_counter()
    trainer.train()
    torch.cuda.synchronize()
    train_time = time.perf_counter() - start

    peak_gb = torch.cuda.max_memory_allocated() / 1024**3
    if torch.distributed.is_initialized():
        t = torch.tensor(peak_gb, device="cuda")
        torch.distributed.all_reduce(t, op=torch.distributed.ReduceOp.MAX)
        peak_gb = t.item()
        if torch.distributed.get_rank() != 0:
            raise SystemExit(0)

    print(
        f"benchmark_id={args.benchmark_id} loss_type={args.loss_type}"
        f" peak={peak_gb:.2f} GB time={train_time:.2f} s"
    )


if __name__ == "__main__":
    main()
benchmark_chunked_nll.slurm (single-GPU / DDP)
#!/bin/bash
#SBATCH --job-name=chunked-nll-bench
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:1
#SBATCH --partition=hopper-prod
#SBATCH --time=2:00:00
#SBATCH --output=/fsx/qgallouedec/logs/%x-%j.out
#SBATCH --qos=normal

set -x
source ~/.bashrc
conda activate trl
echo "START TIME: $(date)"
cd /fsx/qgallouedec/trl
export TRACKIO_SPACE_ID="qgallouedec/chunked-nll-benchmark-2"

accelerate launch --main_process_port $((29500 + RANDOM % 1000)) benchmark_chunked_nll.py "$@"

echo "END TIME: $(date)"
benchmark_chunked_nll_fsdp.slurm (FSDP2)
#!/bin/bash
#SBATCH --job-name=chunked-nll-bench-fsdp
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:4
#SBATCH --partition=hopper-prod
#SBATCH --time=2:00:00
#SBATCH --output=/fsx/qgallouedec/logs/%x-%j.out
#SBATCH --qos=normal

set -x
source ~/.bashrc
conda activate trl
echo "START TIME: $(date)"
cd /fsx/qgallouedec/trl
export TRACKIO_SPACE_ID="qgallouedec/chunked-nll-benchmark-2"

accelerate launch \
    --main_process_port $((29500 + RANDOM % 1000)) \
    --mixed_precision bf16 \
    --use_fsdp \
    --fsdp_version 2 \
    --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP \
    --fsdp_transformer_layer_cls_to_wrap Qwen3DecoderLayer \
    benchmark_chunked_nll.py "$@"

echo "END TIME: $(date)"

@qgallouedec changed the title from "Chunked Cross-Entropy" to "Chunked Cross-Entropy: Up to 50% reduced VRAM" on Apr 21, 2026
@qgallouedec qgallouedec marked this pull request as ready for review April 21, 2026 03:23
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec changed the title from "Chunked Cross-Entropy: Up to 50% reduced VRAM" to "Chunked cross-entropy loss for SFT (up to –50% VRAM)" on Apr 21, 2026

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: c06f3416ef

@lewtun lewtun left a comment

Member

Very nice feature! Implementation-wise it looks good to me, but I'd like to see some small experiments done with:

  • MoEs
  • ZeRO-3 + Ulysses and FSDP2 + CP

to make sure we understand whether there's any material difference for these architectures and long-context plugins. They can be fixed in follow-up PRs, but in that case we should add a note in the docs that they are not currently supported.

Comment thread trl/trainer/sft_trainer.py

self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)

# Under FSDP2 with `reshard_after_forward=True` (accelerate's default), the chunked CE path triggers a

Member

More of a question for accelerate, but is their default needed in general? If not, we could set reshard_after_forward=False in our example FSDP2 config too.

Member Author

Accelerate's FSDP2 default is reshard_after_forward=True, applied uniformly to every fully_shard() call. Interestingly, it's overriding PyTorch's None default (which would leave the root unit's params, including lm_head, resident). Any idea why @SunMarc?

In my understanding, in general, True is fine: the backward pass re-gathers each unit's params exactly once, overlappable with compute. It only bites when the same unsharded params are used in multiple forward passes within a step, which is what chunked_nll does (one lm_head.weight all-gather per chunk).

Member

I think this was mostly to copy the FSDP setup, but the FSDPv2 API is indeed a bit more flexible, since you can apply fully_shard directly to the modules you want. A quick fix, which will probably match what torchtitan does, is to just set reshard_after_forward=None when sharding the remaining modules that we have.

One caveat is that usually we should reshard the embedding unless it is tied with the lm_head. But for now, let's just not reshard it and we can deal with this case later if needed. The new FSDPv2 impl from @3outeille should have better defaults.

Member

Can you try this? huggingface/accelerate#4015

@qgallouedec qgallouedec Apr 22, 2026

Member Author

Way better, though it seems like reshard=False is still faster, which is expected in my opinion.

| Config | Δmem | Δtime |
|---|---|---|
| Row 9: old default (reshard=True) | −53.1 % | +49.7 % |
| Row 10: explicit reshard=False override | −47.3 % | −6.3 % |
| accelerate#4015 | −53.1 % | +16.4 % |

Member

Yes, there is a tradeoff between memory and communication. We can still advise users to set it to False if they want faster FSDP.

@albertvillanova albertvillanova left a comment

Member

Thanks.

@cursor cursor Bot left a comment

Cursor Bugbot has reviewed your changes and found 1 potential issue.

There are 2 total unresolved issues (including 1 from previous review).

Reviewed by Cursor Bugbot for commit 3025899.

Comment thread trl/trainer/sft_trainer.py
@qgallouedec qgallouedec merged commit 9bcf729 into main Apr 27, 2026
8 of 13 checks passed
@qgallouedec qgallouedec deleted the chunked_ce branch April 27, 2026 20:09
# wraps the patched forward.
if self._is_vlm:
    raise NotImplementedError("`loss_type='chunked_nll'` is not supported for VLM models yet.")
if peft_config is not None or is_peft_model(model):

Member

Nice to see this in trl. I don't understand why it doesn't work with PEFT though, is there anything we can do to enable it @qgallouedec?

Member Author

Probably nothing is blocking. But the PR was already big, so I wanted to separate the concerns.

As with VLM, I will dedicate a separate PR to PEFT x chunked loss.

Member Author

In other words, it's not "not working with PEFT"; it's just not tested, and I'll do it soon.

Member

Ah nice, looking forward to it. After a first glance, I don't think anything should conflict with PEFT, but LMK if you find any issues.

AmineDiro added a commit that referenced this pull request May 6, 2026
- Note transformers #45433 (sonic-moe CuteDSL kernel integration)
- Highlight TRL-side contribution to #45621 (wrapper-side masked_fill
  pair in grouped_mm_experts_forward)
- Credit @qgallouedec on TRL #5575 (chunked CE) and explain why it is
  load-bearing for the long-context recipe

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@jiosephlee

some questions:
(1) is the liger baseline the fused variant?
(2) since the optimization comes from dropping ignored labels, I imagine the benefits are highly dependent on the ratio between masked and unmasked tokens?
(3) wouldn't this approach be compatible with liger kernels which optimize the CE (or lm_head + CE for fused) kernels?

@qgallouedec

Member Author

(1) yes, chunked and fused
(2) partially yes. actually, two sources for the optimization: 1. chunking reduces the peak (even with no mask token) and 2. no forward / backward on masked tokens
(3) I think this is already integrated in liger. the reason to have it here is for the user to benefit from this optimization without having to install an additional dep
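
The two sources of savings in answer (2) can be separated with a simple row count. This is an illustrative sketch only: `lm_head_rows` is a hypothetical helper, the token counts are made up, and it counts rows per forward pass without modeling the recomputation that checkpointing adds in backward.

```python
def lm_head_rows(total_tokens, masked_fraction, chunk_size=None):
    """Rows sent through lm_head per forward pass, and rows of logits alive at once.

    chunk_size=None models the un-chunked path, which keeps every position.
    """
    if chunk_size is None:
        return {"rows_total": total_tokens, "rows_at_once": total_tokens}
    n_valid = round(total_tokens * (1.0 - masked_fraction))
    return {"rows_total": n_valid, "rows_at_once": min(chunk_size, n_valid)}


# No masked tokens: chunking alone still caps the rows held in memory at once.
print(lm_head_rows(8192, masked_fraction=0.0, chunk_size=256))
# 75% masked tokens: additionally, far fewer rows go through lm_head at all.
print(lm_head_rows(8192, masked_fraction=0.75, chunk_size=256))
```

In this model the peak (`rows_at_once`) depends only on the chunk size once there are at least `chunk_size` valid tokens, while the mask ratio mainly changes how much work is done (`rows_total`).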

@jiosephlee

@qgallouedec Thanks for the quick response. Following-up on a few things:

(1) On my own ablations, liger is showing a +14GB improvement (144GB vs 158GB) on B200s (length: 4096, FA2 padding-free) over chunked_nll on Qwen3-8B. Maybe this warrants further examination so we can inform users more accurately about when chunked_nll does better and when it does not?

(3) Looking at the code now, it doesn't seem to do this https://github.com/linkedin/Liger-Kernel/blob/3bb3b3fae6d0b2356116034a7f0ee1dde0ea71ea/src/liger_kernel/ops/fused_linear_cross_entropy.py#L99

But it's my first pass of the code so I might have looked past something.

8 participants