
Dask.order causing high memory pressure for multi array compute calls (commonly used in xarray) #10384

@fjetter

Description


There is a simple xarray example that shows how memory pressure builds up due to improper dask ordering. See also pangeo-data/distributed-array-examples#2

```python
import xarray as xr
import dask.array as da

ds = xr.Dataset(
    dict(
        anom_u=(["time", "face", "j", "i"], da.random.random((5000, 1, 987, 1920), chunks=(10, 1, -1, -1))),
        anom_v=(["time", "face", "j", "i"], da.random.random((5000, 1, 987, 1920), chunks=(10, 1, -1, -1))),
    )
)

quad = ds**2
quad["uv"] = ds.anom_u * ds.anom_v
mean = quad.mean("time")
```

This example generates roughly 140 GiB of random data, but the mean at the end reduces this to a couple of KiB. This is done with an ordinary tree reduction, which is easy and cheap and could almost run on a Raspberry Pi. However, memory usage blows up, and it is almost impossible to run this on real data.

(Disclaimer: I'm not an xarray expert, so if one drives by, please excuse my ignorance if I'm saying something wrong.)

The mean object above is an xarray.Dataset, which is essentially a collection that holds multiple dask arrays (in this example). When calling compute, all of those arrays are computed simultaneously. Effectively, this is similar to a dask.compute(array1, array2, array3).
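On a toy scale, the difference is between one merged dask.compute call and separate per-array computes; a minimal sketch (the small arrays and names here are illustrative, not from the example above):

```python
import dask
import dask.array as da

x = da.random.random((20,), chunks=5)
y = da.random.random((20,), chunks=5)

# One merged call: both graphs are combined into a single graph, and
# dask.order assigns one global priority to every task in it.
mx, my = dask.compute(x.mean(), y.mean())

# A separate call: this graph is ordered (and scheduled) on its own.
mx2 = x.mean().compute()

print(mx, my, mx2)
```

The merged call is what Dataset.compute effectively does for all of its arrays at once.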

In fact, when computing just a single array (e.g. mean['uv'].compute()), this runs wonderfully: the result is reduced immediately, and no data generator task is held in memory for any significant amount of time.

However, when computing all of the above at once, the ordering breaks down and the random data generator tasks are held in memory for a very long time.

Running dask.order.diagnostics on this shows a pressure (i.e. how long certain tasks are held in memory) of 267 on average, with a maximum of 511 (compared to 7 and 14, respectively, for the single-array reduction).

The entire graph is too large to render, but a scaled-down version shows the same effect when colored by order-age (albeit much less pronounced, since the graph is smaller):

[figure: quadratic_mean — scaled-down graph visualization colored by order-age]

(In the single-array tree reduction, the age of the data generators is between one and three.)

I believe I was able to reduce this to a minimal dask.order example:

```python
def f(*args):
    ...

a, b, c, d = list("abcd")

dsk = {}
for ix in range(3):
    part = {
        # Two independent data generators per partition ...
        (a, 0, ix): (f,),
        (a, 1, ix): (f,),
        # ... a layer that mixes their outputs ...
        (b, 0, ix): (f, (a, 0, ix)),
        (b, 1, ix): (f, (a, 0, ix), (a, 1, ix)),
        (b, 2, ix): (f, (a, 1, ix)),
        # ... and a layer of single-input follow-ups
        (c, 0, ix): (f, (b, 0, ix)),
        (c, 1, ix): (f, (b, 1, ix)),
        (c, 2, ix): (f, (b, 2, ix)),
    }
    dsk.update(part)
# Final reducers: each one combines a c-task from every partition
for ix in range(3):
    dsk[(d, ix)] = (f, (c, ix, 0), (c, ix, 1), (c, ix, 2))
```

[figure: raw-graph — visualization of the minimal example graph]

A way to compute the result that generates a slightly different version of the graph, which dask.order can handle better, is to use mean.to_dask_dataframe(). The DataFrame version of this also reduces quite well.

cc @eriknw
