Fuse array elementwise operations at graph build time #2538

@jcrist

Description

Linear array expressions in dask array incur overhead that could be optimized away, as noted in #2034 and #2497.

  • Each operation (e.g. __add__) adds another set of tasks to the graph. This adds overhead both in time (cost of building the graph) and space (overhead of larger dictionary representing the graph).

  • The larger graph increases the overhead of our optimization passes for fusing tasks. It'd be preferable to generate a smaller graph up front rather than rely on optimization passes to detect and fuse elementwise operations.

Here we compute the same elementwise operation in 2 ways:

  • First using operators directly on dask arrays. This generates individual tasks for each operator, which are then fused together during our optimization passes.
  • Second using map_blocks and a function representing the whole elementwise operation. This generates a single task for each block (much smaller graph) and requires no optimization.
In [1]: import numpy as np, dask.array as da

In [2]: x = np.random.normal(size=int(2e8))

In [3]: y = np.random.normal(size=int(2e8))

In [4]: dx = da.from_array(x, chunks=int(1e6))

In [5]: dy = da.from_array(y, chunks=int(1e6))

In [6]: def f(x, y):
   ...:     return (0.5 - x)**2 + 0.8 * (y - x**2)**2
   ...:

In [7]: %%time
   ...: o = f(dx, dy).max()
   ...: print("%d tasks" % len(o.dask))
   ...: _ = o.compute()
   ...:
2070 tasks
CPU times: user 3.34 s, sys: 1.47 s, total: 4.81 s
Wall time: 1.76 s

In [8]: %%time
   ...: o = da.map_blocks(f, dx, dy).max()
   ...: print("%d tasks" % len(o.dask))
   ...: _ = o.compute()
   ...:
870 tasks
CPU times: user 3.06 s, sys: 1.36 s, total: 4.42 s
Wall time: 1.53 s

From the above you can see that using map_blocks with a single function results in a smaller graph and faster execution. For larger arrays/array expressions the benefits are even greater.

It'd be nice to be able to write code like the first example (using operators) and have the graph be equivalent to the second (single mapped function). In #1946 a solution using task fusion was proposed. This is more general, but wouldn't reduce the size of the generated graph until fusion time, which still incurs the cost of generating and optimizing the graph.

Instead we propose avoiding generating the large graphs in the first place by adding a simple expression system to dask.array. This would only cover the linear operations generated by da.atop (which backs a good number of functions/operators). Changing atop to return a special object that encodes the meaning of the atop operation, without generating the graph yet, would let further calls to atop on that output continue building up expressions until an incompatible method is called. The expression structure could be stored either in a custom MutableMapping (as suggested in #1763) or in a thin subclass of da.Array.

Example

x = some_dask_array()
# No graph generated yet, just expression encoding `x + 1`
x2 = x + 1
x3 = x2 + x
x4 = da.sin(x3)
# `sum` isn't a linear expression, so graph for `x4` is generated
# and (potentially) cached in `x4`. This graph maps the function
# `sin((x + 1) + x)` across all blocks
x5 = x4.sum()
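The deferral idea above can be sketched in plain Python. This is purely illustrative (it does not touch dask's actual atop machinery, and the `Lazy` class and its methods are hypothetical names): elementwise operators compose a single fused function instead of adding tasks, and only `compute` applies it, which in dask would become one map_blocks-style task per chunk.

```python
import numpy as np

class Lazy:
    """Hypothetical sketch of a deferred elementwise expression.

    Operators don't do any work; they compose a single fused function
    that is applied in one pass when compute() is called. Only scalar
    operands are handled here, to keep the sketch short.
    """
    def __init__(self, data, func=lambda b: b):
        self.data = data   # underlying ndarray (stands in for a dask array)
        self.func = func   # fused elementwise computation recorded so far

    def _chain(self, op):
        # Compose op on top of the existing fused function.
        return Lazy(self.data, lambda b, f=self.func: op(f(b)))

    def __add__(self, other):
        return self._chain(lambda v: v + other)

    def __mul__(self, other):
        return self._chain(lambda v: v * other)

    def sin(self):
        return self._chain(np.sin)

    def compute(self):
        # In dask this would map the fused function over each block.
        return self.func(self.data)

x = Lazy(np.arange(4.0))
y = ((x + 1) * 2).sin()   # no work done yet; one fused function recorded
result = y.compute()      # evaluates sin((b + 1) * 2) in a single pass
```

A real implementation would also need to track multiple array inputs and chunk alignment, but the core mechanism (record, fuse, evaluate once) is the same.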

A further benefit of the simple expression system is that it gives us enough information to optionally optimize the evaluation of tasks, either by interpreting the expressions to use out= keywords on the numpy methods (reducing memory overhead), or by calling nb.jit on the expressions. If done, this should be either explicit or configurable, but would yield further performance improvements.

In [9]: import numba as nb

In [10]: f2 = nb.jit(f, nopython=True, nogil=True)

In [11]: f2.compile('f8,f8') # precompile
Out[11]: <function __main__._Closure.f>

In [12]: %%time
    ...: o = da.map_blocks(f2, dx, dy).max()
    ...: print("%d tasks" % len(o.dask))
    ...: _ = o.compute()
    ...:
    ...:
870 tasks
CPU times: user 2.1 s, sys: 285 ms, total: 2.39 s
Wall time: 613 ms
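The out= idea can be illustrated with plain NumPy ufuncs (again just a sketch of what an expression interpreter might emit, not actual dask machinery): evaluating the same f with explicit output buffers keeps the number of temporary allocations fixed regardless of expression size.

```python
import numpy as np

x = np.random.normal(size=int(1e6))
y = np.random.normal(size=int(1e6))

# Naive evaluation of (0.5 - x)**2 + 0.8 * (y - x**2)**2 allocates a
# fresh temporary for every intermediate. Reusing two buffers via out=
# bounds the extra memory at two arrays, independent of expression size.
tmp1 = np.empty_like(x)
tmp2 = np.empty_like(x)

np.multiply(x, x, out=tmp1)         # tmp1 = x**2
np.subtract(y, tmp1, out=tmp1)      # tmp1 = y - x**2
np.multiply(tmp1, tmp1, out=tmp1)   # tmp1 = (y - x**2)**2
np.multiply(tmp1, 0.8, out=tmp1)    # tmp1 = 0.8 * (y - x**2)**2
np.subtract(0.5, x, out=tmp2)       # tmp2 = 0.5 - x
np.multiply(tmp2, tmp2, out=tmp2)   # tmp2 = (0.5 - x)**2
np.add(tmp2, tmp1, out=tmp1)        # tmp1 = full expression
```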

As a further extension, this could potentially solve #2431, as __getitem__ on the thin subclass of da.Array (da.AtopArray?) could be overridden to push slices back before the elementwise operations. This would be more robust than relying on an optimization pass.
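The slice-pushdown identity this relies on is easy to demonstrate with plain NumPy: for any elementwise f, slicing the result equals applying f to the slice, so an overridden __getitem__ could rewrite the former into the latter and compute only the chunks actually needed.

```python
import numpy as np

# For elementwise f, f(x)[idx] == f(x[idx]); the slice can be pushed
# beneath the operation so only the required elements are computed.
x = np.random.normal(size=1000)
f = lambda v: (0.5 - v) ** 2

full_then_slice = f(x)[10:20]    # applies f to all 1000 elements, then slices
slice_then_apply = f(x[10:20])   # applies f to only the 10 needed elements
```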
