Tune dask.array.optimize #1133

@mrocklin

In some pathological (yet also real) cases, graph optimization can cost about as much as the computation itself.

I've been benchmarking against this computation:

import dask.array as da
n = 12
x = da.random.random((2**n, 2**n), chunks=(2**n, 1))   # 4096 x 4096 array, 4096 x 1 chunks
for i in range(1, n):
    # rechunk step by step: 2048 x 4, 1024 x 8, ..., 2 x 4096
    x = x.rechunk((2**(n - i), 2**(i + 1)))

y = x.sum()
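
To make the shape of this benchmark concrete, here is a small sketch in plain Python (no dask required) that lists the chunk shape after each rechunk step and counts the blocks at that step. The helper `n_blocks` and the variable names are illustrative only and are not part of dask.

```python
# Chunk shapes produced by the benchmark's rechunk loop (pure Python, no dask needed).
n = 12
side = 2**n  # the array is side x side


def n_blocks(chunk):
    """Number of blocks when a side x side array is cut into `chunk`-shaped pieces."""
    rows, cols = chunk
    return (side // rows) * (side // cols)


# Initial chunking, followed by one entry per rechunk step in the loop.
shapes = [(2**n, 1)] + [(2**(n - i), 2**(i + 1)) for i in range(1, n)]
blocks = [n_blocks(s) for s in shapes]

for shape, count in zip(shapes, blocks):
    print(shape, count)
```

The initial array has 4096 blocks. Every rechunk step after that has 2048 blocks, while the block shape moves from tall and narrow to short and wide across 11 chained steps.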

>>> %time y.compute()
CPU times: user 39 s, sys: 2.45 s, total: 41.5 s
Wall time: 39.1 s

>>> %prun  y._optimize(y.dask, y._keys())

         29816630 function calls (28569410 primitive calls) in 22.570 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   604153    3.733    0.000    8.800    0.000 core.py:189(get_dependencies)
  6240188    2.504    0.000    2.843    0.000 core.py:27(istask)
  1312754    2.414    0.000    3.091    0.000 core.py:154(_deps)
        3    2.146    0.715    2.146    0.715 core.py:284(<listcomp>)
   299004    1.456    0.000    3.326    0.000 rewrite.py:375(_match)
909302/77823    1.331    0.000    2.017    0.000 optimize.py:211(functions_of)
  2009082    0.625    0.000    0.625    0.000 {built-in method isinstance}
        3    0.618    0.206    2.821    0.940 core.py:275(reverse_dict)
299004/77823    0.553    0.000    5.650    0.000 rewrite.py:363(_bottom_up)
   299004    0.480    0.000    4.206    0.000 rewrite.py:283(iter_matches)
   348156    0.426    0.000    0.426    0.000 rewrite.py:50(__init__)
   874487    0.406    0.000    1.014    0.000 rewrite.py:8(head)
  2303972    0.403    0.000    0.403    0.000 {method 'pop' of 'list' objects}
219136/51200    0.400    0.000    0.990    0.000 core.py:291(subs)
  2250724    0.391    0.000    0.391    0.000 {method 'extend' of 'list' objects}
        2    0.376    0.188    4.363    2.182 optimize.py:117(inline)
    49152    0.341    0.000    0.673    0.000 core.py:307(<listcomp>)
  4671438    0.339    0.000    0.339    0.000 {built-in method callable}
   299004    0.270    0.000    4.477    0.000 rewrite.py:304(_rewrite)
   794616    0.257    0.000    1.211    0.000 rewrite.py:81(current)
        1    0.204    0.204    1.261    1.261 optimize.py:43(fuse)
   745465    0.186    0.000    0.186    0.000 {method 'get' of 'dict' objects}
  1484786    0.168    0.000    0.168    0.000 {method 'append' of 'list' objects}
        2    0.133    0.066    0.752    0.376 core.py:321(_toposort)
        1    0.127    0.127    1.249    1.249 optimize.py:16(cull)
   348156    0.126    0.000    0.126    0.000 {method 'pop' of 'collections.deque' objects}
    79871    0.125    0.000    0.185    0.000 optimization.py:32(is_full_slice)
   745464    0.122    0.000    0.122    0.000 rewrite.py:119(edges)
        1    0.122    0.122    2.202    2.202 optimize.py:200(<listcomp>)
   595958    0.113    0.000    0.113    0.000 {method 'add' of 'set' objects}
        1    0.106    0.106    5.805    5.805 {method 'update' of 'dict' objects}
225278/221182    0.106    0.000    3.578    0.000 rewrite.py:365(<genexpr>)
        1    0.098    0.098    6.363    6.363 optimize.py:237(dealias)
   176126    0.093    0.000    0.175    0.000 rewrite.py:19(args)
    81919    0.086    0.000    0.086    0.000 {built-in method hasattr}
    49152    0.070    0.000    0.135    0.000 rewrite.py:70(next)
    79871    0.069    0.000    0.069    0.000 {built-in method hash}
47103/24575    0.067    0.000    1.377    0.000 rewrite.py:367(<listcomp>)
    79872    0.066    0.000    1.486    0.000 optimize.py:287(<genexpr>)
    79872    0.065    0.000    1.344    0.000 optimize.py:278(<genexpr>)
    79871    0.063    0.000    0.132    0.000 core.py:9(ishashable)
    77824    0.061    0.000    1.311    0.000 optimize.py:197(<genexpr>)
    49152    0.058    0.000    0.084    0.000 rewrite.py:62(copy)
        1    0.058    0.058    6.676    6.676 optimization.py:49(remove_full_slices)
        1    0.054    0.054   22.570   22.570 <string>:1(<module>)
    77823    0.049    0.000    5.699    0.000 rewrite.py:315(rewrite)
    79872    0.044    0.000    0.044    0.000 optimize.py:40(<genexpr>)
    28675    0.042    0.000    0.042    0.000 {method 'values' of 'dict' objects}
        1    0.039    0.039    7.491    7.491 optimize.py:168(inline_functions)
    98304    0.038    0.000    0.038    0.000 optimization.py:46(<genexpr>)
        1    0.037    0.037    0.169    0.169 optimize.py:281(<genexpr>)
        1    0.036    0.036    0.221    0.221 optimization.py:69(<genexpr>)
    79872    0.033    0.000    0.033    0.000 optimization.py:71(<genexpr>)
        1    0.033    0.033   22.515   22.515 optimization.py:14(optimize)
    49152    0.032    0.000    0.780    0.000 core.py:314(<listcomp>)
    79871    0.031    0.000    0.117    0.000 optimize.py:231(unwrap_partial)
    49154    0.026    0.000    0.026    0.000 optimize.py:145(<genexpr>)
    49160    0.018    0.000    0.028    0.000 core.py:246(flatten)
        2    0.011    0.006    0.011    0.006 optimize.py:282(<genexpr>)
    49152    0.011    0.000    0.011    0.000 {method 'extend' of 'collections.deque' objects}
        1    0.009    0.009    0.009    0.009 {method 'copy' of 'dict' objects}
    49152    0.009    0.000    0.037    0.000 {built-in method all}
    77823    0.008    0.000    0.008    0.000 {method 'issubset' of 'set' objects}
    79872    0.008    0.000    0.008    0.000 {built-in method len}

I also recommend looking at this profile with snakeviz:

%load_ext snakeviz
%snakeviz y._optimize(y.dask, y._keys())
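
Outside IPython, a table like the one above can be produced with the standard library's `cProfile` and `pstats`. The sketch below is only an illustration: `optimize_stand_in` is a hypothetical toy workload standing in for `y._optimize(y.dask, y._keys())`, so the sketch runs without dask installed.

```python
import cProfile
import io
import pstats


def optimize_stand_in(n_tasks=20_000):
    """Hypothetical stand-in for the optimization call: collect dependencies of a toy graph."""
    # Toy linear graph: key 0 is a literal, every other key is a task on the previous key.
    graph = {i: ("inc", i - 1) for i in range(1, n_tasks)}
    graph[0] = 1
    deps = {
        key: {arg for arg in (task[1:] if isinstance(task, tuple) else ()) if arg in graph}
        for key, task in graph.items()
    }
    return deps


profiler = cProfile.Profile()
profiler.enable()
deps = optimize_stand_in()
profiler.disable()

# Sort by internal time (tottime), matching the ordering of the %prun table.
buffer = io.StringIO()
stats = pstats.Stats(profiler, stream=buffer)
stats.sort_stats("tottime").print_stats(10)
report = buffer.getvalue()
print(report)

# Optionally write the raw profile to disk for a viewer such as snakeviz:
# profiler.dump_stats("optimize.prof")
```

If the `dump_stats` line is enabled, the resulting file can be opened from a shell with `snakeviz optimize.prof`.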
