
Commit df4403f

fcharras and tomMoral authored
FIX update test_manual_scatter to check consistency between dask backend and native dask (#1718)
Co-authored-by: tommoral <thomas.moreau.2010@gmail.com>
1 parent: 8515638 · commit: df4403f

1 file changed

Lines changed: 72 additions & 24 deletions

File tree

joblib/test/test_dask.py
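
For context, the API this test exercises is joblib's dask backend with pre-scattered data. Below is a minimal sketch of that pattern, not part of the commit: the scheduler address, `big`, and `offset_len` are hypothetical placeholders, and it assumes a reachable dask.distributed scheduler.

    from distributed import Client
    from joblib import Parallel, delayed, parallel_config

    # Placeholder address; this Client becomes the default client that the
    # "dask" backend picks up.
    client = Client("tcp://127.0.0.1:8786")

    # Placeholder for data that is expensive to serialize.
    big = list(range(1_000_000))

    def offset_len(i, data):
        return i + len(data)

    # `scatter` uploads `big` to the workers once up front, so it is not
    # re-serialized with every task dispatched through the dask backend.
    with parallel_config(backend="dask", scatter=[big]):
        results = Parallel(batch_size=1)(
            delayed(offset_len)(i, big) for i in range(10)
        )

The commit below checks exactly this property: scattered variables should be serialized the same number of times whether the tasks go through `joblib.Parallel` or through native `client.submit` calls.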

@@ -259,38 +259,86 @@ def add5(a, b, c, d=0, e=0):
 
 
 def test_manual_scatter(loop):
-    w = CountSerialized(0)
-    x = CountSerialized(1)
-    y = CountSerialized(2)
-    z = CountSerialized(3)
+    # Let's check that the number of times scattered and non-scattered
+    # variables are serialized is consistent between `joblib.Parallel` calls
+    # and the equivalent native `client.submit` calls.
+
+    # The number of serializations can vary from one dask version to
+    # another, so this test only checks that `joblib.Parallel` does not add
+    # more serialization steps than a native `client.submit` call, without
+    # checking for an exact number of serialization steps.
+
+    w, x, y, z = (CountSerialized(i) for i in range(4))
+
+    f = delayed(add5)
+    tasks = [f(x, y, z, d=4, e=5) for _ in range(10)]
+    tasks += [
+        f(x, z, y, d=5, e=4),
+        f(y, x, z, d=x, e=5),
+        f(z, z, x, d=z, e=y),
+    ]
+    expected = [func(*args, **kwargs) for func, args, kwargs in tasks]
 
-    with cluster() as (s, [a, b]):
+    with cluster() as (s, _):
         with Client(s["address"], loop=loop) as client:  # noqa: F841
             with parallel_config(backend="dask", scatter=[w, x, y]):
-                f = delayed(add5)
-                tasks = [f(x, y, z, d=4, e=5) for _ in range(10)]
-                tasks += [
-                    f(x, z, y, d=5, e=4),
-                    f(y, x, z, d=x, e=5),
-                    f(z, z, x, d=z, e=y),
-                ]
-                results = Parallel(batch_size=1)(tasks)
-
-            # Scatter must take a list/tuple
+                results_parallel = Parallel(batch_size=1)(tasks)
+            assert results_parallel == expected
+
+            # Check that an error is raised for bad arguments, as scatter must
+            # take a list/tuple
             with pytest.raises(TypeError):
                 with parallel_config(backend="dask", loop=loop, scatter=1):
                     pass
 
-            expected = [func(*args, **kwargs) for func, args, kwargs in tasks]
-            assert results == expected
-
     # Scattered variables only serialized during scatter. Checking with an
-    # extra variable as this count can vary from one dask version to another.
-    n_serialization_scatter = w.count
-    assert x.count == n_serialization_scatter
-    assert y.count == n_serialization_scatter
-    # Should be serialized once per task
-    assert z.count == 13
+    # extra variable as this count can vary from one dask version
+    # to another.
+    n_serialization_scatter_with_parallel = w.count
+    assert x.count == n_serialization_scatter_with_parallel
+    assert y.count == n_serialization_scatter_with_parallel
+    n_serialization_with_parallel = z.count
+
+    # Reset the cluster and the serialization counts
+    for var in (w, x, y, z):
+        var.count = 0
+
+    with cluster() as (s, _):
+        with Client(s["address"], loop=loop) as client:  # noqa: F841
+            scattered = dict()
+            for obj in w, x, y:
+                scattered[id(obj)] = client.scatter(obj, broadcast=True)
+            results_native = [
+                client.submit(
+                    func,
+                    *(scattered.get(id(arg), arg) for arg in args),
+                    **dict(
+                        (key, scattered.get(id(value), value))
+                        for (key, value) in kwargs.items()
+                    ),
+                    key=str(uuid4()),
+                ).result()
+                for (func, args, kwargs) in tasks
+            ]
+            assert results_native == expected
+
+    # Now check that the number of serialization steps is the same for joblib
+    # and native dask calls.
+    n_serialization_scatter_native = w.count
+    assert x.count == n_serialization_scatter_native
+    assert y.count == n_serialization_scatter_native
+
+    assert n_serialization_scatter_with_parallel == n_serialization_scatter_native
+
+    distributed_version = tuple(int(v) for v in distributed.__version__.split("."))
+    if distributed_version < (2023, 4):
+        # Prior to 2023.4, serialization added an extra call to __reduce__
+        # for the last job `f(z, z, x, d=z, e=y)`, because `z` appears both
+        # in the args and the kwargs, which is not the case when running
+        # with joblib. Cope with this discrepancy.
+        assert z.count == n_serialization_with_parallel + 1
+    else:
+        assert z.count == n_serialization_with_parallel
 
 
 # When the same IOLoop is used for multiple clients in a row, use

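The diff relies on two helpers defined earlier in joblib/test/test_dask.py: `add5`, whose signature appears in the hunk header, and `CountSerialized`, which counts how often an object is pickled. A rough sketch of their shape, reconstructed from how the test uses them rather than copied verbatim:

    class CountSerialized:
        """Wrap a value and count how many times it goes through pickling."""

        def __init__(self, x):
            self.x = x
            self.count = 0

        def __add__(self, other):
            return self.x + getattr(other, "x", other)

        __radd__ = __add__

        def __reduce__(self):
            # Every serialization goes through __reduce__, so `count` records
            # how many times this object has been pickled.
            self.count += 1
            return (CountSerialized, (self.x,))

    def add5(a, b, c, d=0, e=0):
        # Approximate body: simply combines its five arguments.
        return a + b + c + d + e

This is what lets the test compare `w.count`, `x.count`, `y.count`, and `z.count` between the `joblib.Parallel` run and the native `client.submit` run.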