Skip to content

Commit bea0519

Browse files
ejguan authored and facebook-github-bot committed
[WIP][DataLoader] Implement BucketBatchIterableDataset (#51126)
Summary: Pull Request resolved: #51126 BucketBatch: Get a chunk of data as a bucket, sort the bucket by the specified key, then batch. If no sort key is specified, fall back directly to BatchIterableDataset. 1. Implement BucketBatch for bucket sampling 2. Improve BatchDS tests Test Plan: Imported from OSS Reviewed By: H-Huang Differential Revision: D26209890 Pulled By: ejguan fbshipit-source-id: 8519e2e49da158b3fe32913c8f3cadfa6f3ff1fc
1 parent 14ee63f commit bea0519

4 files changed

Lines changed: 145 additions & 38 deletions

File tree

test/test_dataset.py

Lines changed: 67 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import pickle
2+
import random
23
import tempfile
34
import warnings
45

56
import torch
67
from torch.testing._internal.common_utils import (TestCase, run_tests)
78
from torch.utils.data import IterableDataset, RandomSampler
89
from torch.utils.data.datasets import \
9-
(CallableIterableDataset, CollateIterableDataset, BatchIterableDataset,
10+
(CallableIterableDataset, CollateIterableDataset, BatchIterableDataset, BucketBatchIterableDataset,
1011
ListDirFilesIterableDataset, LoadFilesFromDiskIterableDataset, SamplerIterableDataset)
1112
from typing import List, Tuple, Dict, Any, Type
1213

@@ -146,46 +147,81 @@ def _collate_fn(batch):
146147
self.assertEqual(x, torch.tensor(y))
147148

148149
def test_batch_dataset(self):
    arrs = list(range(10))
    ds = IterDatasetWithLen(arrs)
    # A zero batch size is rejected up front.
    with self.assertRaises(AssertionError):
        BatchIterableDataset(ds, batch_size=0)

    # By default the trailing partial batch is kept.
    bs = 3
    batch_ds1 = BatchIterableDataset(ds, batch_size=bs)
    self.assertEqual(len(batch_ds1), 4)
    for i, batch in enumerate(batch_ds1):
        expected_len = 1 if i == 3 else bs
        self.assertEqual(len(batch), expected_len)
        self.assertEqual(batch, arrs[i * bs: i * bs + len(batch)])

    # With drop_last=True the trailing partial batch is discarded.
    bs = 4
    batch_ds2 = BatchIterableDataset(ds, batch_size=bs, drop_last=True)
    self.assertEqual(len(batch_ds2), 2)
    for i, batch in enumerate(batch_ds2):
        self.assertEqual(len(batch), bs)
        self.assertEqual(batch, arrs[i * bs: i * bs + len(batch)])

    # A source without __len__ makes len() of the batch dataset undefined.
    ds_nl = IterDatasetWithoutLen(range(10))
    batch_ds_nl = BatchIterableDataset(ds_nl, batch_size=2)
    with self.assertRaises(NotImplementedError):
        len(batch_ds_nl)
def test_bucket_batch_dataset(self):
    ds = IterDatasetWithLen(range(20))
    # A zero batch size is rejected up front.
    with self.assertRaises(AssertionError):
        BucketBatchIterableDataset(ds, batch_size=0)

    # A source without __len__ makes len() of the bucket dataset undefined.
    ds_nl = IterDatasetWithoutLen(range(20))
    bucket_ds_nl = BucketBatchIterableDataset(ds_nl, batch_size=7)
    with self.assertRaises(NotImplementedError):
        len(bucket_ds_nl)

    def _helper(**kwargs):
        arrs = list(range(100))
        random.shuffle(arrs)
        ds = IterDatasetWithLen(arrs)
        bucket_ds = BucketBatchIterableDataset(ds, **kwargs)
        if kwargs["sort_key"] is None:
            # Without a sort key, bucketing degenerates to plain batching:
            # compare against BatchIterableDataset as the reference.
            ref_ds = BatchIterableDataset(ds, batch_size=kwargs['batch_size'], drop_last=kwargs['drop_last'])
            for batch, rbatch in zip(bucket_ds, ref_ds):
                self.assertEqual(batch, rbatch)
        else:
            # With a sort key, each bucket's batches, concatenated, must be
            # the sorted slice of the shuffled input covered by that bucket.
            bucket_size = bucket_ds.bucket_size
            bucket_num = (len(ds) - 1) // bucket_size + 1
            it = iter(bucket_ds)
            for i in range(bucket_num):
                ref = sorted(arrs[i * bucket_size: (i + 1) * bucket_size])
                bucket: List = []
                while len(bucket) < len(ref):
                    try:
                        bucket += next(it)
                    except StopIteration:
                        # drop_last may end the stream before the final
                        # bucket is fully reassembled.
                        break
                if len(bucket) != len(ref):
                    ref = ref[:len(bucket)]
                self.assertEqual(bucket, ref)

    # Bucket batching without a sort key.
    _helper(batch_size=7, drop_last=False, sort_key=None)
    _helper(batch_size=7, drop_last=True, bucket_size_mul=5, sort_key=None)

    # Bucket batching with a sort key (identity keeps verification simple).
    def _sort_fn(data):
        return data

    _helper(batch_size=7, drop_last=False, bucket_size_mul=5, sort_key=_sort_fn)
    _helper(batch_size=7, drop_last=True, bucket_size_mul=5, sort_key=_sort_fn)
189225

190226
def test_sampler_dataset(self):
191227
arrs = range(10)

torch/utils/data/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44
from .dataset import IterableDataset as IterDataPipe
55
from .distributed import DistributedSampler
66
from .dataloader import DataLoader, _DatasetKind, get_worker_info
7-
from .datasets import (BatchIterableDataset, CallableIterableDataset, CollateIterableDataset, SamplerIterableDataset)
7+
from .datasets import (BatchIterableDataset, BucketBatchIterableDataset, CallableIterableDataset, CollateIterableDataset,
8+
SamplerIterableDataset)
89

910
__all__ = ['Sampler', 'SequentialSampler', 'RandomSampler',
1011
'SubsetRandomSampler', 'WeightedRandomSampler', 'BatchSampler',
1112
'DistributedSampler', 'Dataset', 'IterableDataset', 'TensorDataset',
1213
'ConcatDataset', 'ChainDataset', 'BufferedShuffleDataset', 'Subset',
1314
'random_split', 'DataLoader', '_DatasetKind', 'get_worker_info',
14-
'BatchIterableDataset', 'CallableIterableDataset', 'CollateIterableDataset',
15-
'SamplerIterableDataset', 'IterDataPipe']
15+
'BatchIterableDataset', 'BucketBatchIterableDataset', 'CallableIterableDataset',
16+
'CollateIterableDataset', 'SamplerIterableDataset', 'IterDataPipe']
Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from .batchdataset import BatchIterableDataset
1+
from .batchdataset import BatchIterableDataset, BucketBatchIterableDataset
22
from .callabledataset import CallableIterableDataset, CollateIterableDataset
33
from .samplerdataset import SamplerIterableDataset
44
from .listdirfilesdataset import ListDirFilesIterableDataset
55
from .loadfilesfromdiskdataset import LoadFilesFromDiskIterableDataset
66

7-
__all__ = ['BatchIterableDataset', 'CallableIterableDataset', 'CollateIterableDataset',
8-
'ListDirFilesIterableDataset', 'LoadFilesFromDiskIterableDataset', 'SamplerIterableDataset']
7+
__all__ = ['BatchIterableDataset', 'BucketBatchIterableDataset', 'CallableIterableDataset',
8+
'CollateIterableDataset', 'ListDirFilesIterableDataset', 'LoadFilesFromDiskIterableDataset',
9+
'SamplerIterableDataset']

torch/utils/data/datasets/batchdataset.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
import warnings
12
from torch.utils.data import IterableDataset
2-
from typing import TypeVar, Optional, Iterator, List, Sized
3+
from typing import TypeVar, Optional, Iterator, List, Sized, Callable
34

45
T_co = TypeVar('T_co', covariant=True)
56

@@ -55,3 +56,71 @@ def __len__(self) -> int:
5556
self.length = (len(self.dataset) + self.batch_size - 1) // self.batch_size
5657
return self.length
5758
raise NotImplementedError
59+
60+
61+
class BucketBatchIterableDataset(IterableDataset[List[T_co]]):
    r""" :class:`BucketBatchIterableDataset`.

    IterableDataset to create mini-batches of data from a sorted bucket.
    Elements are first grouped into buckets of ``batch_size * bucket_size_mul``
    elements; each bucket is sorted by ``sort_key`` and then split into
    batches of ``batch_size``. If ``sort_key`` is ``None`` this degenerates to
    plain batching in source order. The last batch holds
    ``length % batch_size`` elements unless ``drop_last`` is set to ``True``.

    args:
        dataset: IterableDataset being batched
        batch_size: The size of each batch
        drop_last: Option to drop the last batch if it's not full
        bucket_size_mul: The multiplier to specify the size of bucket
        sort_key: Callable to specify the comparison key for sorting within bucket
    """
    dataset: IterableDataset[T_co]
    batch_size: int
    drop_last: bool
    # Note: the stored attribute is the resolved bucket size
    # (batch_size * bucket_size_mul), not the multiplier itself.
    bucket_size: int
    sort_key: Optional[Callable]
    length: Optional[int]

    def __init__(self,
                 dataset: IterableDataset[T_co],
                 *,
                 batch_size: int,
                 drop_last: bool = False,
                 bucket_size_mul: int = 100,
                 sort_key: Optional[Callable] = None,
                 ) -> None:
        assert batch_size > 0, "Batch size is required to be larger than 0!"
        super(BucketBatchIterableDataset, self).__init__()
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.bucket_size = batch_size * bucket_size_mul
        self.sort_key = sort_key
        # Lambdas cannot be pickled, which breaks DataLoader workers.
        # Use getattr: callables such as functools.partial instances or
        # callable objects have no `__name__` attribute.
        if sort_key is not None and getattr(sort_key, '__name__', '') == '<lambda>':
            warnings.warn("Lambda function is not supported for pickle, "
                          "please use regular python function instead.")
        # Pre-batch the source into buckets; never drop here so that no
        # element is lost before sorting/batching decides its fate.
        self.bucket_ds = BatchIterableDataset(dataset, batch_size=self.bucket_size, drop_last=False)
        self.length = None

    def __iter__(self) -> Iterator[List[T_co]]:
        # Bucketing without sorting keeps the source order, so plain
        # batching is equivalent — delegate to BatchIterableDataset.
        if self.sort_key is None:
            yield from BatchIterableDataset(self.dataset, batch_size=self.batch_size, drop_last=self.drop_last)
        else:
            bucket: List[T_co]
            for bucket in self.bucket_ds:
                # In-place sort within bucket
                bucket.sort(key=self.sort_key)
                for start in range(0, len(bucket), self.batch_size):
                    batch = bucket[start: start + self.batch_size]
                    # Every bucket size is a multiple of batch_size, so only
                    # the trailing batch of the final bucket can be short.
                    if len(batch) == self.batch_size or not self.drop_last:
                        yield batch

    def __len__(self) -> int:
        """Number of batches; raises NotImplementedError if the source has no length."""
        if self.length is not None:
            return self.length
        if isinstance(self.dataset, Sized) and len(self.dataset) >= 0:
            if self.drop_last:
                self.length = len(self.dataset) // self.batch_size
            else:
                # Ceiling division keeps the trailing partial batch.
                self.length = (len(self.dataset) + self.batch_size - 1) // self.batch_size
            return self.length
        raise NotImplementedError

0 commit comments

Comments
 (0)