Skip to content

Preserve truncated tokens in BFD packing#4632

Merged
qgallouedec merged 11 commits into
mainfrom
keep-overlong-packing
Dec 16, 2025
Merged

Preserve truncated tokens in BFD packing#4632
qgallouedec merged 11 commits into
mainfrom
keep-overlong-packing

Conversation

@qgallouedec

@qgallouedec qgallouedec commented Dec 5, 2025

Copy link
Copy Markdown
Member

This PR updates the bfd packing strategy so that tokens beyond seq_length are not discarded.
Instead, truncated fragments are re-queued and packed like any other sequence, preventing unnecessary token loss.

Closes #4554

Untitled-2025-07-22-1600

Before/After

from datasets import Dataset
from trl import pack_dataset

examples = {
    "input_ids": [[1, 2, 3, 4, 5], [6, 7], [8, 9, 10], [11]],
}
dataset = Dataset.from_dict(examples)
packed_dataset = pack_dataset(dataset, seq_length=4, strategy="bfd")
print(packed_dataset[:])
- {'input_ids': [[1, 2, 3, 4], [8, 9, 10, 11], [6, 7]], 'seq_lengths': [[4], [3, 1], [2]]}
+ {'input_ids': [[1, 2, 3, 4], [8, 9, 10, 5], [6, 7, 11]], 'seq_lengths': [[4], [3, 1], [2, 1]]}

Benchmark

TLDR: a bit slower, but still very fast

import random
import time
from datasets import Dataset
from trl.data_utils import pack_dataset

total_tokens = 10_000_000
seq_length = 2048  # packing target
min_seq_len, max_seq_len = 1024, 3072  # arbitrary input lengths

input_ids = []
tokens_left = total_tokens
while tokens_left > 0:
    n = min(tokens_left, random.randint(min_seq_len, max_seq_len))
    tokens_left -= n
    input_ids.append(list(range(n)))

dataset = Dataset.from_dict({"input_ids": input_ids})

start = time.perf_counter()
packed = pack_dataset(dataset, seq_length=seq_length)
elapsed = time.perf_counter() - start

print(f"Packed {total_tokens} tokens into {len(packed)} examples in {elapsed:.3f}s")
# Before: Packed 10000000 tokens into 4848 examples in 0.189s
# After:  Packed 10000000 tokens into 4952 examples in 0.255s

Important

This PR was mostly written using Codex. Based on my tests, it works. I think I understand the most of it, but I'm not behind most of the code changes.

@HuggingFaceDocBuilderDev

Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec qgallouedec linked an issue Dec 5, 2025 that may be closed by this pull request
Comment thread trl/data_utils.py
if isinstance(column, pa.ChunkedArray):
column = column.combine_chunks()
if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type):
column = pc.list_slice(column, 0, seq_length)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this was truncating all list columns. We don't want to truncated anymore

Comment thread trl/data_utils.py
Comment on lines +650 to +651
if isinstance(column, pa.ChunkedArray):
column = column.combine_chunks()

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In PyArrow, a ChunkedArray is one logical column made of multiple smaller arrays ("chunks") under the hood.
Combining chunks gives one continuous array (leaving it chunked would mean offsets restart in each piece). It allows the code to operate on a single contiguous chunk, which keeps offsets consistent and avoids chunk-boundary surprises.

@jiosephlee

Copy link
Copy Markdown

Looking forward to this PR!

@albertvillanova albertvillanova left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for improving this function so no tokens are discarded.

However, I think the new implementation might introduce a correctness issue for the usual use case with multiple list columns.

If I understand correctly, only the first list column is actually split into <= seq_length fragments. All other list columns are simply duplicated at the row level and later re-wrapped using the packed offsets. This would cause misalignment between columns.

A minimal repro: With

examples = {
  'input_ids_1': [
    [1, 2, 3, 4, 5], 
    [6, 7], 
    [8, 9, 10, 11], 
    [12]
  ],
  'input_ids_2': [
    [10, 20, 30, 40, 50], 
    [60, 70], 
    [80, 90, 100, 110], 
    [120]
  ]
}

the packed output with seq_length = 4 becomes:

{
  'input_ids_1': [
    [1, 2, 3, 4], 
    [8, 9, 10, 11], 
    [6, 7, 5, 12]
  ],
  'input_ids_2': [
    [10, 20, 30, 40], 
    [50, 80, 90, 100], 
    [110, 60, 70, 10]
  ],
  'seq_lengths': [
    [4], 
    [4], 
    [2, 1, 1]
  ]
}

Comment thread trl/data_utils.py Outdated
frag_slices.append((row_start + split_start, frag_len))
expanded_indices.append(row_idx)

# Rebuild list column with fragments and duplicate non-list columns accordingly.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

duplicate non-list columns accordingly

Does this function support non-list columns?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, but in practice, we only get list columns:

trl/trl/trainer/sft_trainer.py

Lines 1088 to 1102 in d401a42

columns = ["input_ids"]
if "completion_mask" in get_dataset_column_names(dataset):
columns.append("completion_mask")
if "assistant_masks" in get_dataset_column_names(dataset):
columns.append("assistant_masks")
dataset = dataset.select_columns(columns)
# Shuffle the dataset before packing. When using wrapped packing, it's important to shuffle before
# packing as well to avoid correlations between sequences packed together.
if args.shuffle_dataset:
dataset = dataset.shuffle(seed=args.seed)
# Packing adds new column "seq_lengths" needed for document aware FlashAttention
dataset = pack_dataset(dataset, args.max_length, args.packing_strategy, map_kwargs)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a check in the function to ensure all columns are list

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

trl/trl/data_utils.py

Lines 650 to 654 in ad82b13

for idx, column in enumerate(examples.columns):
if isinstance(column, pa.ChunkedArray):
column = column.combine_chunks()
if not (pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type)):
raise TypeError("pack_dataset(bfd) requires all columns to be list-like.")

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant: if non-list columns are not supported, then why "duplicating non-list columns accordingly"?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment thread tests/test_data_utils.py
seq_length = 4
expected_output = {
"input_ids": [[4, 5, 6, 7], [1, 2, 3, 8]],
"attention_mask": [[0, 0, 1, 1], [0, 1, 1, 1]],

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In practice, we never all packing with the attention mask + it doesn't add anything to test it here. So I suggest we remove it.

Comment thread tests/test_data_utils.py
Comment on lines +1040 to +1050
def test_with_overlong_two_coluns(self):
examples = {
"col1": [[1, -2, 3, -4, 5, -6], [7, -8, 9], [-10, 11, -12], [13, -14, 15, -16]],
"col2": [[-1, 2, -3, 4, -5, -6], [-7, 8, -9], [10, -11, 12], [-13, 14, -15, 16]],
}
dataset = Dataset.from_dict(examples)
seq_length = 4
expected_output = {
"col1": [[1, -2, 3, -4], [13, -14, 15, -16], [7, -8, 9], [-10, 11, -12], [5, -6]],
"col2": [[-1, 2, -3, 4], [-13, 14, -15, 16], [-7, 8, -9], [10, -11, 12], [-5, 6]],
"seq_lengths": [[4], [4], [3], [3], [2]],

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Following the finding here, this new test ensure we have a consistent packing across columns

@albertvillanova albertvillanova left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a failing test: https://github.com/huggingface/trl/actions/runs/20238475748/job/58099567098?pr=4632

tests/test_data_utils.py::TestPackDatasetBfd::test_with_overlong_two_coluns

Comment thread tests/test_data_utils.py Outdated
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>

@albertvillanova albertvillanova left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks.

Comment thread trl/data_utils.py Outdated
qgallouedec and others added 2 commits December 16, 2025 13:09
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
@qgallouedec qgallouedec merged commit ec70ef2 into main Dec 16, 2025
10 of 11 checks passed
@qgallouedec qgallouedec deleted the keep-overlong-packing branch December 16, 2025 20:37
songhappy pushed a commit to songhappy/trl that referenced this pull request Apr 20, 2026
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Better packing of data with best-fit decrease strategy

4 participants