Preserve truncated tokens in BFD packing#4632
Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
| if isinstance(column, pa.ChunkedArray): | ||
| column = column.combine_chunks() | ||
| if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type): | ||
| column = pc.list_slice(column, 0, seq_length) |
There was a problem hiding this comment.
this was truncating all list columns. We don't want to truncated anymore
| if isinstance(column, pa.ChunkedArray): | ||
| column = column.combine_chunks() |
There was a problem hiding this comment.
In PyArrow, a ChunkedArray is one logical column made of multiple smaller arrays ("chunks") under the hood.
Combining chunks gives one continuous array (leaving it chunked would mean offsets restart in each piece). It allows the code to operate on a single contiguous chunk, which keeps offsets consistent and avoids chunk-boundary surprises.
|
Looking forward to this PR! |
There was a problem hiding this comment.
Thanks for improving this function so no tokens are discarded.
However, I think the new implementation might introduce a correctness issue for the usual use case with multiple list columns.
If I understand correctly, only the first list column is actually split into <= seq_length fragments. All other list columns are simply duplicated at the row level and later re-wrapped using the packed offsets. This would cause misalignment between columns.
A minimal repro: With
examples = {
'input_ids_1': [
[1, 2, 3, 4, 5],
[6, 7],
[8, 9, 10, 11],
[12]
],
'input_ids_2': [
[10, 20, 30, 40, 50],
[60, 70],
[80, 90, 100, 110],
[120]
]
}the packed output with seq_length = 4 becomes:
{
'input_ids_1': [
[1, 2, 3, 4],
[8, 9, 10, 11],
[6, 7, 5, 12]
],
'input_ids_2': [
[10, 20, 30, 40],
[50, 80, 90, 100],
[110, 60, 70, 10]
],
'seq_lengths': [
[4],
[4],
[2, 1, 1]
]
}| frag_slices.append((row_start + split_start, frag_len)) | ||
| expanded_indices.append(row_idx) | ||
|
|
||
| # Rebuild list column with fragments and duplicate non-list columns accordingly. |
There was a problem hiding this comment.
duplicate non-list columns accordingly
Does this function support non-list columns?
There was a problem hiding this comment.
No, but in practice, we only get list columns:
trl/trl/trainer/sft_trainer.py
Lines 1088 to 1102 in d401a42
There was a problem hiding this comment.
I added a check in the function to ensure all columns are list
There was a problem hiding this comment.
I meant: if non-list columns are not supported, then why "duplicating non-list columns accordingly"?
There was a problem hiding this comment.
| seq_length = 4 | ||
| expected_output = { | ||
| "input_ids": [[4, 5, 6, 7], [1, 2, 3, 8]], | ||
| "attention_mask": [[0, 0, 1, 1], [0, 1, 1, 1]], |
There was a problem hiding this comment.
In practice, we never all packing with the attention mask + it doesn't add anything to test it here. So I suggest we remove it.
| def test_with_overlong_two_coluns(self): | ||
| examples = { | ||
| "col1": [[1, -2, 3, -4, 5, -6], [7, -8, 9], [-10, 11, -12], [13, -14, 15, -16]], | ||
| "col2": [[-1, 2, -3, 4, -5, -6], [-7, 8, -9], [10, -11, 12], [-13, 14, -15, 16]], | ||
| } | ||
| dataset = Dataset.from_dict(examples) | ||
| seq_length = 4 | ||
| expected_output = { | ||
| "col1": [[1, -2, 3, -4], [13, -14, 15, -16], [7, -8, 9], [-10, 11, -12], [5, -6]], | ||
| "col2": [[-1, 2, -3, 4], [-13, 14, -15, 16], [-7, 8, -9], [10, -11, 12], [-5, 6]], | ||
| "seq_lengths": [[4], [4], [3], [3], [2]], |
There was a problem hiding this comment.
Following the finding here, this new test ensure we have a consistent packing across columns
albertvillanova
left a comment
There was a problem hiding this comment.
There is a failing test: https://github.com/huggingface/trl/actions/runs/20238475748/job/58099567098?pr=4632
tests/test_data_utils.py::TestPackDatasetBfd::test_with_overlong_two_coluns
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
This PR updates the
bfdpacking strategy so that tokens beyondseq_lengthare not discarded.Instead, truncated fragments are re-queued and packed like any other sequence, preventing unnecessary token loss.
Closes #4554
Before/After
Benchmark
TLDR: a bit slower, but still very fast
Important
This PR was mostly written using Codex. Based on my tests, it works. I think I understand the most of it, but I'm not behind most of the code changes.