11import bisect
2- import warnings
2+ import itertools
33import math
4+ import warnings
45from typing import (
6+ cast ,
7+ Dict ,
58 Generic ,
69 Iterable ,
710 List ,
1013 Tuple ,
1114 TypeVar ,
1215 Union ,
13- Dict
1416)
1517
1618# No 'default_generator' in torch/__init__.pyi
1719from torch import default_generator , randperm
18- from torch ._utils import _accumulate
1920
2021from ... import Generator , Tensor
2122
3031 "random_split" ,
3132]
3233
# Covariant element type returned by Dataset.__getitem__ (read-only position,
# so covariance is safe).
T_co = TypeVar("T_co", covariant=True)
# Invariant element type used where values are both consumed and produced.
T = TypeVar("T")
# The two sample-container shapes a stacked dataset can yield.
T_dict = Dict[str, T_co]
T_tuple = Tuple[T_co, ...]
# Value-constrained TypeVar: a stacked sample is either a tuple or a dict.
T_stack = TypeVar("T_stack", T_tuple, T_dict)
3839
3940
4041class Dataset (Generic [T_co ]):
@@ -63,7 +64,7 @@ def __getitem__(self, index) -> T_co:
6364 # Not implemented to prevent false-positives in fetcher check in
6465 # torch.utils.data._utils.fetch._MapDatasetFetcher
6566
def __add__(self, other: "Dataset[T_co]") -> "ConcatDataset[T_co]":
    """Support ``dataset_a + dataset_b`` by returning a ConcatDataset of the two."""
    return ConcatDataset([self, other])
6869
6970 # No `def __len__(self)` default?
@@ -199,7 +200,9 @@ class TensorDataset(Dataset[Tuple[Tensor, ...]]):
199200 tensors : Tuple [Tensor , ...]
200201
def __init__(self, *tensors: Tensor) -> None:
    """Wrap *tensors*; every tensor must share the same size along dim 0.

    Each sample ``i`` is the tuple ``tuple(t[i] for t in tensors)``.
    """
    # Compare each tensor's first dimension against the first tensor's.
    # The generator is lazy, so an empty *tensors* never indexes tensors[0].
    first_dims_match = all(
        tensors[0].size(0) == t.size(0) for t in tensors
    )
    assert first_dims_match, "Size mismatch between tensors"
    self.tensors = tensors
204207
205208 def __getitem__ (self , index ):
@@ -233,8 +236,10 @@ class StackDataset(Dataset[T_stack]):
233236 def __init__ (self , * args : Dataset [T_co ], ** kwargs : Dataset [T_co ]) -> None :
234237 if args :
235238 if kwargs :
236- raise ValueError ("Supported either ``tuple``- (via ``args``) or"
237- "``dict``- (via ``kwargs``) like input/output, but both types are given." )
239+ raise ValueError (
240+ "Supported either ``tuple``- (via ``args``) or"
241+ "``dict``- (via ``kwargs``) like input/output, but both types are given."
242+ )
238243 self ._length = len (args [0 ]) # type: ignore[arg-type]
239244 if any (self ._length != len (dataset ) for dataset in args ): # type: ignore[arg-type]
240245 raise ValueError ("Size mismatch between datasets" )
@@ -261,8 +266,10 @@ def __getitems__(self, indices: list):
261266 if callable (getattr (dataset , "__getitems__" , None )):
262267 items = dataset .__getitems__ (indices ) # type: ignore[attr-defined]
263268 if len (items ) != len (indices ):
264- raise ValueError ("Nested dataset's output size mismatch."
265- f" Expected { len (indices )} , got { len (items )} " )
269+ raise ValueError (
270+ "Nested dataset's output size mismatch."
271+ f" Expected { len (indices )} , got { len (items )} "
272+ )
266273 for data , d_sample in zip (items , dict_batch ):
267274 d_sample [k ] = data
268275 else :
@@ -276,8 +283,10 @@ def __getitems__(self, indices: list):
276283 if callable (getattr (dataset , "__getitems__" , None )):
277284 items = dataset .__getitems__ (indices ) # type: ignore[attr-defined]
278285 if len (items ) != len (indices ):
279- raise ValueError ("Nested dataset's output size mismatch."
280- f" Expected { len (indices )} , got { len (items )} " )
286+ raise ValueError (
287+ "Nested dataset's output size mismatch."
288+ f" Expected { len (indices )} , got { len (items )} "
289+ )
281290 for data , t_sample in zip (items , list_batch ):
282291 t_sample .append (data )
283292 else :
@@ -314,9 +323,11 @@ def cumsum(sequence):
def __init__(self, datasets: Iterable[Dataset]) -> None:
    """Materialize *datasets* into a list and precompute cumulative sizes.

    Only map-style datasets are allowed; IterableDataset has no
    ``__getitem__``/``__len__`` and therefore cannot be concatenated here.
    """
    super().__init__()
    self.datasets = list(datasets)
    assert len(self.datasets) > 0, "datasets should not be an empty iterable"  # type: ignore[arg-type]
    # Reject iterable-style datasets one by one so the assertion points at
    # the first offender.  Kept as `assert` (not `raise`) so behavior under
    # `python -O` matches the original.
    for dataset in self.datasets:
        assert not isinstance(
            dataset, IterableDataset
        ), "ConcatDataset does not support IterableDataset"
    self.cumulative_sizes = self.cumsum(self.datasets)
321332
322333 def __len__ (self ):
@@ -325,7 +336,9 @@ def __len__(self):
325336 def __getitem__ (self , idx ):
326337 if idx < 0 :
327338 if - idx > len (self ):
328- raise ValueError ("absolute value of index should not exceed dataset length" )
339+ raise ValueError (
340+ "absolute value of index should not exceed dataset length"
341+ )
329342 idx = len (self ) + idx
330343 dataset_idx = bisect .bisect_right (self .cumulative_sizes , idx )
331344 if dataset_idx == 0 :
@@ -336,8 +349,11 @@ def __getitem__(self, idx):
336349
@property
def cummulative_sizes(self):
    """Deprecated misspelled alias for :attr:`cumulative_sizes`."""
    # Single literal; the original spelled this as two adjacent string
    # literals, which concatenate to exactly this message at runtime.
    warnings.warn(
        "cummulative_sizes attribute is renamed to cumulative_sizes",
        DeprecationWarning,
        stacklevel=2,
    )
    return self.cumulative_sizes
342358
343359
@@ -358,13 +374,17 @@ def __init__(self, datasets: Iterable[Dataset]) -> None:
358374
def __iter__(self):
    """Yield samples from each chained dataset, in order.

    Validates each entry lazily, right before it is consumed, so an invalid
    dataset late in the chain only fails once iteration reaches it.
    """
    for dataset in self.datasets:
        assert isinstance(
            dataset, IterableDataset
        ), "ChainDataset only supports IterableDataset"
        yield from dataset
363381
def __len__(self):
    """Return the sum of the lengths of all chained datasets."""
    length = 0
    # Check then measure each dataset in turn (same interleaving as the
    # original: a type failure on dataset k surfaces before len(dataset k)).
    for dataset in self.datasets:
        assert isinstance(
            dataset, IterableDataset
        ), "ChainDataset only supports IterableDataset"
        length += len(dataset)  # type: ignore[arg-type]
    return length
370390
@@ -402,8 +422,11 @@ def __len__(self):
402422 return len (self .indices )
403423
404424
405- def random_split (dataset : Dataset [T ], lengths : Sequence [Union [int , float ]],
406- generator : Optional [Generator ] = default_generator ) -> List [Subset [T ]]:
425+ def random_split (
426+ dataset : Dataset [T ],
427+ lengths : Sequence [Union [int , float ]],
428+ generator : Optional [Generator ] = default_generator ,
429+ ) -> List [Subset [T ]]:
407430 r"""
408431 Randomly split a dataset into non-overlapping new datasets of given lengths.
409432
@@ -446,12 +469,20 @@ def random_split(dataset: Dataset[T], lengths: Sequence[Union[int, float]],
446469 lengths = subset_lengths
447470 for i , length in enumerate (lengths ):
448471 if length == 0 :
449- warnings .warn (f"Length of split at index { i } is 0. "
450- f"This might result in an empty dataset." )
472+ warnings .warn (
473+ f"Length of split at index { i } is 0. "
474+ f"This might result in an empty dataset."
475+ )
451476
452477 # Cannot verify that dataset is Sized
453- if sum (lengths ) != len (dataset ): # type: ignore[arg-type]
454- raise ValueError ("Sum of input lengths does not equal the length of the input dataset!" )
478+ if sum (lengths ) != len (dataset ): # type: ignore[arg-type]
479+ raise ValueError (
480+ "Sum of input lengths does not equal the length of the input dataset!"
481+ )
455482
456483 indices = randperm (sum (lengths ), generator = generator ).tolist () # type: ignore[arg-type, call-overload]
457- return [Subset (dataset , indices [offset - length : offset ]) for offset , length in zip (_accumulate (lengths ), lengths )]
484+ lengths = cast (Sequence [int ], lengths )
485+ return [
486+ Subset (dataset , indices [offset - length : offset ])
487+ for offset , length in zip (itertools .accumulate (lengths ), lengths )
488+ ]
0 commit comments