77import weakref
88from contextlib import suppress
99from timeit import default_timer
10+ from typing import Callable
1011
1112from tlz import valmap
1213from tornado .ioloop import IOLoop
@@ -243,7 +244,9 @@ def __init__(
243244 self ,
244245 keys ,
245246 scheduler = None ,
246- func = key_split ,
247+ * ,
248+ func = None ,
249+ group_by = "prefix" ,
247250 interval = "100ms" ,
248251 complete = False ,
249252 ** kwargs ,
@@ -256,8 +259,17 @@ def __init__(
256259 self .client = weakref .ref (key .client )
257260 break
258261
262+ if func is not None :
263+ warnings .warn (
264+ "`func` is deprecated, use `group_by` instead" ,
265+ category = DeprecationWarning ,
266+ )
267+ group_by = func
268+ elif group_by in (None , "prefix" ):
269+ group_by = key_split
270+
259271 self .keys = {k .key if hasattr (k , "key" ) else k for k in keys }
260- self .func = func
272+ self .group_by = group_by
261273 self .interval = interval
262274 self .complete = complete
263275 self ._start_time = default_timer ()
@@ -269,10 +281,15 @@ def elapsed(self):
269281 async def listen (self ):
270282 complete = self .complete
271283 keys = self .keys
272- func = self .func
284+ group_by = self .group_by
273285
274286 async def setup (scheduler ):
275- p = MultiProgress (keys , scheduler , complete = complete , func = func )
287+ p = MultiProgress (
288+ keys ,
289+ scheduler ,
290+ complete = complete ,
291+ group_by = group_by ,
292+ )
276293 await p .setup ()
277294 return p
278295
@@ -339,29 +356,31 @@ def __init__(
339356 keys ,
340357 scheduler = None ,
341358 minimum = 0 ,
342- interval = 0.1 ,
343- func = key_split ,
344- complete = False ,
345359 ** kwargs ,
346360 ):
347- super ().__init__ (keys , scheduler , func , interval , complete )
361+ super ().__init__ (keys , scheduler , ** kwargs )
348362 from ipywidgets import VBox
349363
350364 self .widget = VBox ([])
351365
352366 def make_widget (self , all ):
353367 from ipywidgets import HTML , FloatProgress , HBox , VBox
354368
369+ def make_label (key ):
370+ if isinstance (key , tuple ):
371+ # tuple of (group_name, group_id)
372+ key = key [0 ]
373+ key = key .decode () if isinstance (key , bytes ) else key
374+ return html .escape (key )
375+
355376 self .elapsed_time = HTML ("" )
356377 self .bars = {key : FloatProgress (min = 0 , max = 1 , description = "" ) for key in all }
357378 self .bar_texts = {key : HTML ("" ) for key in all }
358379 self .bar_labels = {
359380 key : HTML (
360381 '<div style="padding: 0px 10px 0px 10px;'
361382 " text-align:left; word-wrap: "
362- 'break-word;">'
363- + html .escape (key .decode () if isinstance (key , bytes ) else key )
364- + "</div>"
383+ 'break-word;">' + make_label (key ) + "</div>"
365384 )
366385 for key in all
367386 }
@@ -429,7 +448,9 @@ def _draw_bar(self, remaining, all, status, **kwargs):
429448 )
430449
431450
432- def progress (* futures , notebook = None , multi = True , complete = True , ** kwargs ):
451+ def progress (
452+ * futures , notebook = None , multi = True , complete = True , group_by = "prefix" , ** kwargs
453+ ):
433454 """Track progress of futures
434455
435456 This operates differently in the notebook and the console
@@ -448,6 +469,9 @@ def progress(*futures, notebook=None, multi=True, complete=True, **kwargs):
448469 complete : bool (optional)
449470 Track all keys (True) or only keys that have not yet run (False)
450471 (defaults to True)
472+ group_by : Callable | Literal["spans"] | Literal["prefix"]
473+ Use spans instead of task key names for grouping tasks
474+ (defaults to "prefix")
451475
452476 Notes
453477 -----
@@ -465,9 +489,18 @@ def progress(*futures, notebook=None, multi=True, complete=True, **kwargs):
465489 futures = [futures ]
466490 if notebook is None :
467491 notebook = is_kernel () # often but not always correct assumption
492+ if kwargs .get ("func" , None ) is not None :
493+ warnings .warn (
494+ "`func` is deprecated, use `group_by` instead" , category = DeprecationWarning
495+ )
496+ group_by = kwargs .pop ("func" )
497+ if group_by not in ("spans" , "prefix" ) and not isinstance (group_by , Callable ):
498+ raise ValueError ("`group_by` should be 'spans', 'prefix', or a Callable" )
468499 if notebook :
469500 if multi :
470- bar = MultiProgressWidget (futures , complete = complete , ** kwargs )
501+ bar = MultiProgressWidget (
502+ futures , complete = complete , group_by = group_by , ** kwargs
503+ )
471504 else :
472505 bar = ProgressWidget (futures , complete = complete , ** kwargs )
473506 return bar
0 commit comments