|
1 | 1 | import pickle |
| 2 | +import random |
2 | 3 | import tempfile |
3 | 4 | import warnings |
4 | 5 |
|
5 | 6 | import torch |
6 | 7 | from torch.testing._internal.common_utils import (TestCase, run_tests) |
7 | 8 | from torch.utils.data import IterableDataset, RandomSampler |
8 | 9 | from torch.utils.data.datasets import \ |
9 | | - (CallableIterableDataset, CollateIterableDataset, BatchIterableDataset, |
| 10 | + (CallableIterableDataset, CollateIterableDataset, BatchIterableDataset, BucketBatchIterableDataset, |
10 | 11 | ListDirFilesIterableDataset, LoadFilesFromDiskIterableDataset, SamplerIterableDataset) |
11 | 12 | from typing import List, Tuple, Dict, Any, Type |
12 | 13 |
|
@@ -146,46 +147,81 @@ def _collate_fn(batch): |
146 | 147 | self.assertEqual(x, torch.tensor(y)) |
147 | 148 |
|
148 | 149 | def test_batch_dataset(self): |
149 | | - arrs = range(10) |
| 150 | + arrs = list(range(10)) |
150 | 151 | ds = IterDatasetWithLen(arrs) |
151 | 152 | with self.assertRaises(AssertionError): |
152 | 153 | batch_ds0 = BatchIterableDataset(ds, batch_size=0) |
153 | 154 |
|
154 | 155 | # Default not drop the last batch |
155 | | - batch_ds1 = BatchIterableDataset(ds, batch_size=3) |
| 156 | + bs = 3 |
| 157 | + batch_ds1 = BatchIterableDataset(ds, batch_size=bs) |
156 | 158 | self.assertEqual(len(batch_ds1), 4) |
157 | | - batch_iter = iter(batch_ds1) |
158 | | - value = 0 |
159 | | - for i in range(len(batch_ds1)): |
160 | | - batch = next(batch_iter) |
161 | | - if i == 3: |
162 | | - self.assertEqual(len(batch), 1) |
163 | | - self.assertEqual(batch, [9]) |
164 | | - else: |
165 | | - self.assertEqual(len(batch), 3) |
166 | | - for x in batch: |
167 | | - self.assertEqual(x, value) |
168 | | - value += 1 |
| 159 | + for i, batch in enumerate(batch_ds1): |
| 160 | + self.assertEqual(len(batch), 1 if i == 3 else bs) |
| 161 | + self.assertEqual(batch, arrs[i * bs: i * bs + len(batch)]) |
169 | 162 |
|
170 | 163 | # Drop the last batch |
171 | | - batch_ds2 = BatchIterableDataset(ds, batch_size=3, drop_last=True) |
172 | | - self.assertEqual(len(batch_ds2), 3) |
173 | | - value = 0 |
174 | | - for batch in batch_ds2: |
175 | | - self.assertEqual(len(batch), 3) |
176 | | - for x in batch: |
177 | | - self.assertEqual(x, value) |
178 | | - value += 1 |
179 | | - |
180 | | - batch_ds3 = BatchIterableDataset(ds, batch_size=2) |
181 | | - self.assertEqual(len(batch_ds3), 5) |
182 | | - batch_ds4 = BatchIterableDataset(ds, batch_size=2, drop_last=True) |
183 | | - self.assertEqual(len(batch_ds4), 5) |
| 164 | + bs = 4 |
| 165 | + batch_ds2 = BatchIterableDataset(ds, batch_size=bs, drop_last=True) |
| 166 | + self.assertEqual(len(batch_ds2), 2) |
| 167 | + for i, batch in enumerate(batch_ds2): |
| 168 | + self.assertEqual(len(batch), bs) |
| 169 | + self.assertEqual(batch, arrs[i * bs: i * bs + len(batch)]) |
| 170 | + |
| 171 | + ds_nl = IterDatasetWithoutLen(range(10)) |
| 172 | + batch_ds_nl = BatchIterableDataset(ds_nl, batch_size=2) |
| 173 | + with self.assertRaises(NotImplementedError): |
| 174 | + len(batch_ds_nl) |
184 | 175 |
|
185 | | - ds_nolen = IterDatasetWithoutLen(arrs) |
186 | | - batch_ds_nolen = BatchIterableDataset(ds_nolen, batch_size=5) |
| 176 | + def test_bucket_batch_dataset(self): |
| 177 | + ds = IterDatasetWithLen(range(20)) |
| 178 | + with self.assertRaises(AssertionError): |
| 179 | + BucketBatchIterableDataset(ds, batch_size=0) |
| 180 | + |
| 181 | + ds_nl = IterDatasetWithoutLen(range(20)) |
| 182 | + bucket_ds_nl = BucketBatchIterableDataset(ds_nl, batch_size=7) |
187 | 183 | with self.assertRaises(NotImplementedError): |
188 | | - len(batch_ds_nolen) |
| 184 | + len(bucket_ds_nl) |
| 185 | + |
| 186 | + # Test Bucket Batch without sort_key |
| 187 | + def _helper(**kwargs): |
| 188 | + arrs = list(range(100)) |
| 189 | + random.shuffle(arrs) |
| 190 | + ds = IterDatasetWithLen(arrs) |
| 191 | + bucket_ds = BucketBatchIterableDataset(ds, **kwargs) |
| 192 | + if kwargs["sort_key"] is None: |
| 193 | + # BatchDataset as reference |
| 194 | + ref_ds = BatchIterableDataset(ds, batch_size=kwargs['batch_size'], drop_last=kwargs['drop_last']) |
| 195 | + for batch, rbatch in zip(bucket_ds, ref_ds): |
| 196 | + self.assertEqual(batch, rbatch) |
| 197 | + else: |
| 198 | + bucket_size = bucket_ds.bucket_size |
| 199 | + bucket_num = (len(ds) - 1) // bucket_size + 1 |
| 200 | + it = iter(bucket_ds) |
| 201 | + for i in range(bucket_num): |
| 202 | + ref = sorted(arrs[i * bucket_size: (i + 1) * bucket_size]) |
| 203 | + bucket: List = [] |
| 204 | + while len(bucket) < len(ref): |
| 205 | + try: |
| 206 | + batch = next(it) |
| 207 | + bucket += batch |
| 208 | + # If drop last, stop in advance |
| 209 | + except StopIteration: |
| 210 | + break |
| 211 | + if len(bucket) != len(ref): |
| 212 | + ref = ref[:len(bucket)] |
| 213 | + # Sorted bucket |
| 214 | + self.assertEqual(bucket, ref) |
| 215 | + |
| 216 | + _helper(batch_size=7, drop_last=False, sort_key=None) |
| 217 | + _helper(batch_size=7, drop_last=True, bucket_size_mul=5, sort_key=None) |
| 218 | + |
| 219 | + # Test Bucket Batch with sort_key |
| 220 | + def _sort_fn(data): |
| 221 | + return data |
| 222 | + |
| 223 | + _helper(batch_size=7, drop_last=False, bucket_size_mul=5, sort_key=_sort_fn) |
| 224 | + _helper(batch_size=7, drop_last=True, bucket_size_mul=5, sort_key=_sort_fn) |
189 | 225 |
|
190 | 226 | def test_sampler_dataset(self): |
191 | 227 | arrs = range(10) |
|
0 commit comments