def test_manual_scatter(loop):
    """Check serialization counts of scattered vs. non-scattered variables.

    Compares a ``joblib.Parallel`` run using the dask backend's ``scatter``
    option against an equivalent sequence of native ``client.submit`` calls,
    asserting both produce the same results and the same number of
    serialization steps.
    """
    # Let's check that the number of times scattered and non-scattered
    # variables are serialized is consistent between `joblib.Parallel` calls
    # and equivalent native `client.submit` calls.

    # Number of serializations can vary from one dask version to another, so
    # this test only checks that `joblib.Parallel` does not add more
    # serialization steps than a native `client.submit` call, but does not
    # check for an exact number of serialization steps.

    w, x, y, z = (CountSerialized(i) for i in range(4))

    f = delayed(add5)
    tasks = [f(x, y, z, d=4, e=5) for _ in range(10)]
    tasks += [
        f(x, z, y, d=5, e=4),
        f(y, x, z, d=x, e=5),
        f(z, z, x, d=z, e=y),
    ]
    # Reference results, computed locally without any serialization.
    expected = [func(*args, **kwargs) for func, args, kwargs in tasks]

    with cluster() as (s, _):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            with parallel_config(backend="dask", scatter=[w, x, y]):
                results_parallel = Parallel(batch_size=1)(tasks)
                assert results_parallel == expected

            # Check that an error is raised for bad arguments, as scatter must
            # take a list/tuple
            with pytest.raises(TypeError):
                with parallel_config(backend="dask", loop=loop, scatter=1):
                    pass

    # Scattered variables only serialized during scatter. Checking with an
    # extra variable as this count can vary from one dask version
    # to another.
    n_serialization_scatter_with_parallel = w.count
    assert x.count == n_serialization_scatter_with_parallel
    assert y.count == n_serialization_scatter_with_parallel
    n_serialization_with_parallel = z.count

    # Reset the cluster and the serialization count
    for var in (w, x, y, z):
        var.count = 0

    with cluster() as (s, _):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            # Scatter the same variables as the joblib run above, then submit
            # the same tasks natively, substituting scattered futures for the
            # matching arguments (matched by object identity).
            scattered = dict()
            for obj in w, x, y:
                scattered[id(obj)] = client.scatter(obj, broadcast=True)
            results_native = [
                client.submit(
                    func,
                    *(scattered.get(id(arg), arg) for arg in args),
                    **dict(
                        (key, scattered.get(id(value), value))
                        for (key, value) in kwargs.items()
                    ),
                    key=str(uuid4()),
                ).result()
                for (func, args, kwargs) in tasks
            ]
            assert results_native == expected

    # Now check that the number of serialization steps is the same for joblib
    # and native dask calls.
    n_serialization_scatter_native = w.count
    assert x.count == n_serialization_scatter_native
    assert y.count == n_serialization_scatter_native

    assert n_serialization_scatter_with_parallel == n_serialization_scatter_native

    distributed_version = tuple(int(v) for v in distributed.__version__.split("."))
    if distributed_version < (2023, 4):
        # Previous to 2023.4, the serialization was adding an extra call to
        # __reduce__ for the last job `f(z, z, x, d=z, e=y)`, because `z`
        # appears both in the args and kwargs, which is not the case when
        # running with joblib. Cope with this discrepancy.
        assert z.count == n_serialization_with_parallel + 1
    else:
        assert z.count == n_serialization_with_parallel
295343
296344# When the same IOLoop is used for multiple clients in a row, use
0 commit comments