@@ -401,7 +401,7 @@ def safe_sqr(X, copy=True):
401401 return X
402402
403403
404- def gen_batches (n , batch_size ):
404+ def gen_batches (n , batch_size , min_batch_size = 0 ):
405405 """Generator to create slices containing batch_size elements, from 0 to n.
406406
407407 The last slice may contain less than batch_size elements, when batch_size
@@ -412,6 +412,8 @@ def gen_batches(n, batch_size):
412412 n : int
413413 batch_size : int
414414 Number of element in each batch
415+ min_batch_size : int, default=0
416+ Minimum batch size to produce.
415417
416418 Yields
417419 ------
@@ -426,10 +428,16 @@ def gen_batches(n, batch_size):
426428 [slice(0, 3, None), slice(3, 6, None)]
427429 >>> list(gen_batches(2, 3))
428430 [slice(0, 2, None)]
431+ >>> list(gen_batches(7, 3, min_batch_size=0))
432+ [slice(0, 3, None), slice(3, 6, None), slice(6, 7, None)]
433+ >>> list(gen_batches(7, 3, min_batch_size=2))
434+ [slice(0, 3, None), slice(3, 7, None)]
429435 """
430436 start = 0
431437 for _ in range (int (n // batch_size )):
432438 end = start + batch_size
439+ if end + min_batch_size > n :
440+ continue
433441 yield slice (start , end )
434442 start = end
435443 if start < n :
0 commit comments