
MaxTokenBucketizer does not guarantee that max_token_count holds if padding is used #788

@nikvaessen

Description

🐛 Describe the bug

The MaxTokenBucketizer data pipe should probably be aware of whether users will pad each resulting bucket to equal length. Its documentation gives an example from the audio domain, where padding is frequently used to build batches of equal-length sequences, so users will probably run into edge cases where the current behavior causes out-of-memory (OOM) errors.
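To make "aware of padding" concrete, here is a minimal sketch in plain Python, independent of torchdata. It assumes the cost of a batch is its padded size (batch size times the longest sequence) and greedily closes a batch when adding the next sequence would push that cost over the budget. The helpers `padded_cost` and `bucketize_with_padding` are hypothetical names, not part of the torchdata API, and the sketch keeps input order with no buffering or sorting. It is run on the lengths from the example below.

```python
from typing import Iterable, Iterator, List


def padded_cost(lengths: List[int]) -> int:
    """Hypothetical helper: tokens a batch occupies after padding to its longest sequence."""
    return len(lengths) * max(lengths) if lengths else 0


def bucketize_with_padding(lengths: Iterable[int], max_token_count: int) -> Iterator[List[int]]:
    """Hypothetical helper: group lengths in input order so each padded batch fits the budget.

    A single sequence longer than max_token_count is still emitted as its own batch.
    """
    batch: List[int] = []
    for length in lengths:
        # Close the current batch if padding it with the new sequence would exceed the budget.
        if batch and padded_cost(batch + [length]) > max_token_count:
            yield batch
            batch = []
        batch.append(length)
    if batch:
        yield batch


print(list(bucketize_with_padding([100, 10, 10, 10, 100], max_token_count=150)))
# [[100], [10, 10, 10], [100]]
```

With a 150-token budget this produces the three batches described as the desired behavior, each with a padded size of at most 150.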

For example, assume a dataset of 5 sequences with lengths [100, 10, 10, 10, 100], and the requirement that no batch has more than 150 tokens.

The desired behavior would be 3 batches: [100], [10, 10, 10], and [100]. The current behavior leads to the following:

import torch as t

from torchdata.datapipes.iter import (
    MaxTokenBucketizer,
    IterableWrapper,
    Mapper,
)


class Sample:
    def __init__(self, length: int):
        self.seq = t.ones((length,))

    def __len__(self):
        return self.seq.shape[0]

    def __lt__(self, other):
        # order samples by sequence length
        return len(self) < len(other)


def sample_to_tensor(bucket):
    return [x.seq for x in bucket]


def pad_bucket(bucket):
    # zero-pad every sequence in the bucket to the length of the longest one
    return t.nn.utils.rnn.pad_sequence(bucket, batch_first=True)


data_lengths = [100, 10, 10, 10, 100]
data = [Sample(x) for x in data_lengths]

dp = IterableWrapper(data)
dp = MaxTokenBucketizer(dp, max_token_count=150)
dp = Mapper(dp, sample_to_tensor)
dp = Mapper(dp, pad_bucket)

for x in dp:
    print(x.shape, f"num_tokens:{x.shape[0] * x.shape[1]}")

prints:

torch.Size([4, 100]) num_tokens:400
torch.Size([1, 100]) num_tokens:100
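The first batch can be reconstructed by arithmetic: among these samples, the only four whose unpadded lengths sum to at most 150 are the three length-10 samples plus one length-100 sample. The sketch below assumes, consistent with the output above, that MaxTokenBucketizer budgets the sum of unpadded lengths; the variable names are illustrative only.

```python
# The only four samples from [100, 10, 10, 10, 100] whose lengths sum to <= 150.
first_batch = [10, 10, 10, 100]

raw_tokens = sum(first_batch)                        # unpadded sum: within the 150-token budget
padded_tokens = len(first_batch) * max(first_batch)  # size of the tensor pad_sequence returns

print(raw_tokens, padded_tokens)  # 130 400
```

The unpadded sum passes the 150-token check, while the padded batch holds 400 tokens, which is the 400 reported in the first output line.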

Versions

PyTorch version: 1.12.1+cu113
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Arch Linux (x86_64)
GCC version: (GCC) 12.2.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.36

Python version: 3.10.0 (default, Sep 21 2022, 22:36:10) [GCC 12.2.0] (64-bit runtime)
Python platform: Linux-5.19.10-arch1-1-x86_64-with-glibc2.36
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3070
Nvidia driver version: 515.76
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.23.3
[pip3] pytorch-lightning==1.7.7
[pip3] torch==1.12.1+cu113
[pip3] torchaudio==0.12.1+cu113
[pip3] torchdata==0.4.1
[pip3] torchmetrics==0.9.3
[pip3] torchvision==0.13.1+cu113
[conda] Could not collect
