🐛 Describe the bug
The `MaxTokenBucketizer` data pipe should probably be aware of whether users will pad the resulting list of data to equal length. The example in its documentation comes from the audio domain, where padding is frequently used to build a batch of equal-length sequences, so users are likely to run into edge cases where the current behavior causes out-of-memory (OOM) errors.
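In terms of the lengths $\ell_i$ of the sequences in a batch $B$, the printed output further down suggests that the bucketizer budgets the unpadded sum on the left, while a batch padded to equal length (e.g. with `pad_sequence`) actually holds the count on the right:

$$
\text{tokens}_{\text{unpadded}}(B) = \sum_{i \in B} \ell_i
\qquad\text{vs.}\qquad
\text{tokens}_{\text{padded}}(B) = |B| \cdot \max_{i \in B} \ell_i
$$

Because every $\ell_j \le \max_{i \in B} \ell_i$, the padded count is never smaller than the unpadded sum, and the gap is largest when one long sequence shares a batch with many short ones.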
For example, assume a dataset of 5 sequences with lengths [100, 10, 10, 10, 100], and the requirement that no padded batch contains more than 150 tokens.
The desired behavior would be 3 batches: [100], [10, 10, 10], and [100]. The current behavior leads to the following:
```python
import torch as t
from torchdata.datapipes.iter import (
    MaxTokenBucketizer,
    IterableWrapper,
    Mapper,
)


class Sample:
    def __init__(self, length: int):
        self.seq = t.ones((length,))

    def __len__(self):
        return self.seq.shape[0]

    def __lt__(self, other):
        # Order samples by sequence length (must return a bool)
        return len(self) < len(other)


def sample_to_tensor(bucket):
    # Unwrap each Sample into its 1-D tensor
    return [x.seq for x in bucket]


def pad_bucket(bucket):
    # Pad every sequence to the longest one -> shape (batch, max_len)
    return t.nn.utils.rnn.pad_sequence(bucket, batch_first=True)


data_lengths = [100, 10, 10, 10, 100]
data = [Sample(x) for x in data_lengths]
dp = IterableWrapper(data)
dp = MaxTokenBucketizer(dp, max_token_count=150)
dp = Mapper(dp, sample_to_tensor)
dp = Mapper(dp, pad_bucket)
for x in dp:
    print(x.shape, f"num_tokens:{x.shape[0] * x.shape[1]}")
```

This prints:
```
torch.Size([4, 100]) num_tokens:400
torch.Size([1, 100]) num_tokens:100
```
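The first batch holds 10 + 10 + 10 + 100 = 130 unpadded tokens, which fits the 150-token budget, but after padding it occupies 4 × 100 = 400 tokens. A minimal sketch of padding-aware grouping follows. It is not torchdata code: it assumes the per-sample lengths are already known (in a real pipeline they would come from something like `len_fn`), keeps the input order without any length-based reordering, and uses the hypothetical helpers `padded_token_count` and `pad_aware_batches`.

```python
from typing import Iterable, Iterator, List, Sequence


def padded_token_count(lengths: Sequence[int]) -> int:
    """Hypothetical helper: tokens a batch occupies after padding to its longest sequence."""
    return len(lengths) * max(lengths) if lengths else 0


def pad_aware_batches(lengths: Iterable[int], max_token_count: int) -> Iterator[List[int]]:
    """Hypothetical sketch: greedily group lengths in input order so that each
    *padded* batch stays within max_token_count. A single sequence longer than
    the budget is still emitted as a batch of its own (an assumed policy)."""
    batch: List[int] = []
    for length in lengths:
        candidate = batch + [length]
        # Close the current batch if adding this sequence would exceed the padded budget
        if batch and padded_token_count(candidate) > max_token_count:
            yield batch
            batch = [length]
        else:
            batch = candidate
    if batch:
        yield batch


if __name__ == "__main__":
    data_lengths = [100, 10, 10, 10, 100]
    for b in pad_aware_batches(data_lengths, max_token_count=150):
        print(b, "padded tokens:", padded_token_count(b))
```

Applied to the lengths from this example, the sketch yields the three desired batches [100], [10, 10, 10], and [100], with padded sizes of 100, 30 and 100 tokens.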
Versions
```
PyTorch version: 1.12.1+cu113
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A
OS: Arch Linux (x86_64)
GCC version: (GCC) 12.2.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.36
Python version: 3.10.0 (default, Sep 21 2022, 22:36:10) [GCC 12.2.0] (64-bit runtime)
Python platform: Linux-5.19.10-arch1-1-x86_64-with-glibc2.36
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3070
Nvidia driver version: 515.76
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.23.3
[pip3] pytorch-lightning==1.7.7
[pip3] torch==1.12.1+cu113
[pip3] torchaudio==0.12.1+cu113
[pip3] torchdata==0.4.1
[pip3] torchmetrics==0.9.3
[pip3] torchvision==0.13.1+cu113
[conda] Could not collect
```