Hi,
get_length_grouped_indices() in LengthGroupedSampler and DistributedLengthGroupedSampler is prohibitively slow for a large number of megabatches (in my case it takes hours for ~270k megabatches with 100 items each) due to slow list concatenation with sum(megabatches, []).
Concatenating the lists with sum() builds a new list for each addition, so the total work is quadratic in the number of items (similar to the classic string-concatenation performance issue).
A flat list comprehension, [item for sublist in megabatches for item in sublist], runs in linear time and significantly improves speed for a large number of megabatches, especially when each megabatch holds many items.
For example:
# 50,000 megabatches with 3 items each:
megabatches = [[1,2,3] for _ in range(50_000)]
%timeit [item for sublist in megabatches for item in sublist];
3.72 ms ± 75.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit sum(megabatches, []);
7.66 s ± 31.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# 100,000 megabatches with 3 items each:
megabatches = [[1,2,3] for _ in range(100_000)]
%timeit [item for sublist in megabatches for item in sublist];
8.03 ms ± 14.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit sum(megabatches, []);
29.6 s ± 36.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# 100,000 megabatches with 100 items each:
megabatches = [list(range(100)) for _ in range(100_000)]
%timeit [item for sublist in megabatches for item in sublist];
208 ms ± 44.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit -r1 -n1 sum(megabatches, []);
44min 3s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
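For reference, a minimal sketch of the proposed replacement (the helper names here are illustrative, not the actual ones in transformers):

```python
from itertools import chain

def flatten(megabatches):
    # Linear-time flattening; each item is copied exactly once,
    # unlike sum(megabatches, []), which reallocates on every addition.
    return [item for sublist in megabatches for item in sublist]

def flatten_chain(megabatches):
    # itertools.chain.from_iterable is an equivalent linear-time alternative.
    return list(chain.from_iterable(megabatches))

megabatches = [[1, 2, 3], [4, 5], [6]]
assert flatten(megabatches) == [1, 2, 3, 4, 5, 6]
assert flatten_chain(megabatches) == flatten(megabatches)
```

Either form is a drop-in replacement for sum(megabatches, []) since both preserve the order of items within and across megabatches.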
Thank you for your wonderful work and consideration of this edit. @sgugger