Skip to content

Commit 1d5a9a1

Browse files
Skylion007 authored and pytorchmergebot committed
[Easy][BE]: remove itertools.accumulate Python 2 shim and apply UFMT (#116192)
Removes an unnecessarily duplicated utility function and has the code rely on itertools instead. Since the file is low traffic, I also added the modified files to UFMT'd files and formatted them. Pull Request resolved: #116192 Approved by: https://github.com/malfet
1 parent 602abf6 commit 1d5a9a1

3 files changed

Lines changed: 59 additions & 45 deletions

File tree

.lintrunner.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2544,7 +2544,6 @@ exclude_patterns = [
25442544
'torch/utils/data/datapipes/utils/common.py',
25452545
'torch/utils/data/datapipes/utils/decoder.py',
25462546
'torch/utils/data/datapipes/utils/snapshot.py',
2547-
'torch/utils/data/dataset.py',
25482547
'torch/utils/data/distributed.py',
25492548
'torch/utils/data/graph.py',
25502549
'torch/utils/data/graph_settings.py',

torch/_utils.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -493,22 +493,6 @@ def _import_dotted_name(name):
493493
return obj
494494

495495

496-
# Taken from python 3.5 docs
497-
def _accumulate(iterable, fn=lambda x, y: x + y):
498-
"Return running totals"
499-
# _accumulate([1,2,3,4,5]) --> 1 3 6 10 15
500-
# _accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120
501-
it = iter(iterable)
502-
try:
503-
total = next(it)
504-
except StopIteration:
505-
return
506-
yield total
507-
for element in it:
508-
total = fn(total, element)
509-
yield total
510-
511-
512496
def _flatten_dense_tensors(tensors):
513497
"""Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
514498
same dense type.

torch/utils/data/dataset.py

Lines changed: 59 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import bisect
2-
import warnings
2+
import itertools
33
import math
4+
import warnings
45
from typing import (
6+
cast,
7+
Dict,
58
Generic,
69
Iterable,
710
List,
@@ -10,12 +13,10 @@
1013
Tuple,
1114
TypeVar,
1215
Union,
13-
Dict
1416
)
1517

1618
# No 'default_generator' in torch/__init__.pyi
1719
from torch import default_generator, randperm
18-
from torch._utils import _accumulate
1920

2021
from ... import Generator, Tensor
2122

@@ -30,11 +31,11 @@
3031
"random_split",
3132
]
3233

33-
T_co = TypeVar('T_co', covariant=True)
34-
T = TypeVar('T')
34+
T_co = TypeVar("T_co", covariant=True)
35+
T = TypeVar("T")
3536
T_dict = Dict[str, T_co]
3637
T_tuple = Tuple[T_co, ...]
37-
T_stack = TypeVar('T_stack', T_tuple, T_dict)
38+
T_stack = TypeVar("T_stack", T_tuple, T_dict)
3839

3940

4041
class Dataset(Generic[T_co]):
@@ -63,7 +64,7 @@ def __getitem__(self, index) -> T_co:
6364
# Not implemented to prevent false-positives in fetcher check in
6465
# torch.utils.data._utils.fetch._MapDatasetFetcher
6566

66-
def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
67+
def __add__(self, other: "Dataset[T_co]") -> "ConcatDataset[T_co]":
6768
return ConcatDataset([self, other])
6869

6970
# No `def __len__(self)` default?
@@ -199,7 +200,9 @@ class TensorDataset(Dataset[Tuple[Tensor, ...]]):
199200
tensors: Tuple[Tensor, ...]
200201

201202
def __init__(self, *tensors: Tensor) -> None:
202-
assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors), "Size mismatch between tensors"
203+
assert all(
204+
tensors[0].size(0) == tensor.size(0) for tensor in tensors
205+
), "Size mismatch between tensors"
203206
self.tensors = tensors
204207

205208
def __getitem__(self, index):
@@ -233,8 +236,10 @@ class StackDataset(Dataset[T_stack]):
233236
def __init__(self, *args: Dataset[T_co], **kwargs: Dataset[T_co]) -> None:
234237
if args:
235238
if kwargs:
236-
raise ValueError("Supported either ``tuple``- (via ``args``) or"
237-
"``dict``- (via ``kwargs``) like input/output, but both types are given.")
239+
raise ValueError(
240+
"Supported either ``tuple``- (via ``args``) or"
241+
"``dict``- (via ``kwargs``) like input/output, but both types are given."
242+
)
238243
self._length = len(args[0]) # type: ignore[arg-type]
239244
if any(self._length != len(dataset) for dataset in args): # type: ignore[arg-type]
240245
raise ValueError("Size mismatch between datasets")
@@ -261,8 +266,10 @@ def __getitems__(self, indices: list):
261266
if callable(getattr(dataset, "__getitems__", None)):
262267
items = dataset.__getitems__(indices) # type: ignore[attr-defined]
263268
if len(items) != len(indices):
264-
raise ValueError("Nested dataset's output size mismatch."
265-
f" Expected {len(indices)}, got {len(items)}")
269+
raise ValueError(
270+
"Nested dataset's output size mismatch."
271+
f" Expected {len(indices)}, got {len(items)}"
272+
)
266273
for data, d_sample in zip(items, dict_batch):
267274
d_sample[k] = data
268275
else:
@@ -276,8 +283,10 @@ def __getitems__(self, indices: list):
276283
if callable(getattr(dataset, "__getitems__", None)):
277284
items = dataset.__getitems__(indices) # type: ignore[attr-defined]
278285
if len(items) != len(indices):
279-
raise ValueError("Nested dataset's output size mismatch."
280-
f" Expected {len(indices)}, got {len(items)}")
286+
raise ValueError(
287+
"Nested dataset's output size mismatch."
288+
f" Expected {len(indices)}, got {len(items)}"
289+
)
281290
for data, t_sample in zip(items, list_batch):
282291
t_sample.append(data)
283292
else:
@@ -314,9 +323,11 @@ def cumsum(sequence):
314323
def __init__(self, datasets: Iterable[Dataset]) -> None:
315324
super().__init__()
316325
self.datasets = list(datasets)
317-
assert len(self.datasets) > 0, 'datasets should not be an empty iterable' # type: ignore[arg-type]
326+
assert len(self.datasets) > 0, "datasets should not be an empty iterable" # type: ignore[arg-type]
318327
for d in self.datasets:
319-
assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
328+
assert not isinstance(
329+
d, IterableDataset
330+
), "ConcatDataset does not support IterableDataset"
320331
self.cumulative_sizes = self.cumsum(self.datasets)
321332

322333
def __len__(self):
@@ -325,7 +336,9 @@ def __len__(self):
325336
def __getitem__(self, idx):
326337
if idx < 0:
327338
if -idx > len(self):
328-
raise ValueError("absolute value of index should not exceed dataset length")
339+
raise ValueError(
340+
"absolute value of index should not exceed dataset length"
341+
)
329342
idx = len(self) + idx
330343
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
331344
if dataset_idx == 0:
@@ -336,8 +349,11 @@ def __getitem__(self, idx):
336349

337350
@property
338351
def cummulative_sizes(self):
339-
warnings.warn("cummulative_sizes attribute is renamed to "
340-
"cumulative_sizes", DeprecationWarning, stacklevel=2)
352+
warnings.warn(
353+
"cummulative_sizes attribute is renamed to " "cumulative_sizes",
354+
DeprecationWarning,
355+
stacklevel=2,
356+
)
341357
return self.cumulative_sizes
342358

343359

@@ -358,13 +374,17 @@ def __init__(self, datasets: Iterable[Dataset]) -> None:
358374

359375
def __iter__(self):
360376
for d in self.datasets:
361-
assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset"
377+
assert isinstance(
378+
d, IterableDataset
379+
), "ChainDataset only supports IterableDataset"
362380
yield from d
363381

364382
def __len__(self):
365383
total = 0
366384
for d in self.datasets:
367-
assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset"
385+
assert isinstance(
386+
d, IterableDataset
387+
), "ChainDataset only supports IterableDataset"
368388
total += len(d) # type: ignore[arg-type]
369389
return total
370390

@@ -402,8 +422,11 @@ def __len__(self):
402422
return len(self.indices)
403423

404424

405-
def random_split(dataset: Dataset[T], lengths: Sequence[Union[int, float]],
406-
generator: Optional[Generator] = default_generator) -> List[Subset[T]]:
425+
def random_split(
426+
dataset: Dataset[T],
427+
lengths: Sequence[Union[int, float]],
428+
generator: Optional[Generator] = default_generator,
429+
) -> List[Subset[T]]:
407430
r"""
408431
Randomly split a dataset into non-overlapping new datasets of given lengths.
409432
@@ -446,12 +469,20 @@ def random_split(dataset: Dataset[T], lengths: Sequence[Union[int, float]],
446469
lengths = subset_lengths
447470
for i, length in enumerate(lengths):
448471
if length == 0:
449-
warnings.warn(f"Length of split at index {i} is 0. "
450-
f"This might result in an empty dataset.")
472+
warnings.warn(
473+
f"Length of split at index {i} is 0. "
474+
f"This might result in an empty dataset."
475+
)
451476

452477
# Cannot verify that dataset is Sized
453-
if sum(lengths) != len(dataset): # type: ignore[arg-type]
454-
raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
478+
if sum(lengths) != len(dataset): # type: ignore[arg-type]
479+
raise ValueError(
480+
"Sum of input lengths does not equal the length of the input dataset!"
481+
)
455482

456483
indices = randperm(sum(lengths), generator=generator).tolist() # type: ignore[arg-type, call-overload]
457-
return [Subset(dataset, indices[offset - length : offset]) for offset, length in zip(_accumulate(lengths), lengths)]
484+
lengths = cast(Sequence[int], lengths)
485+
return [
486+
Subset(dataset, indices[offset - length : offset])
487+
for offset, length in zip(itertools.accumulate(lengths), lengths)
488+
]

0 commit comments

Comments (0)