Dask.order causing high memory pressure for multi array compute calls (commonly used in xarray) #10384
Description
There is a simple xarray example that shows how memory pressure is building up due to improper dask ordering. See also pangeo-data/distributed-array-examples#2
import xarray as xr
import dask.array as da
ds = xr.Dataset(
    dict(
        anom_u=(["time", "face", "j", "i"], da.random.random((5000, 1, 987, 1920), chunks=(10, 1, -1, -1))),
        anom_v=(["time", "face", "j", "i"], da.random.random((5000, 1, 987, 1920), chunks=(10, 1, -1, -1))),
    )
)
quad = ds**2
quad["uv"] = ds.anom_u * ds.anom_v
mean = quad.mean("time")

This example generates about 140 GiB of random data, but the mean at the end reduces this to a few tens of MiB (three arrays of shape (1, 987, 1920)). This is done by using an ordinary tree reduction. That's easy and cheap and could almost be done on a Raspberry Pi. However, this thing blows up and it is almost impossible to run this on real data.
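As a quick back-of-the-envelope check of the sizes involved, here is a small stdlib-only sketch. It assumes float64 elements (8 bytes each, the default dtype of da.random.random); the variable names are mine and purely illustrative.

```python
from math import prod

ITEMSIZE = 8  # bytes per element, assuming float64

shape = (5000, 1, 987, 1920)   # (time, face, j, i)
chunks = (10, 1, 987, 1920)    # chunks=(10, 1, -1, -1); -1 means "whole axis"

array_bytes = prod(shape) * ITEMSIZE
chunk_bytes = prod(chunks) * ITEMSIZE

# Two generated input arrays: anom_u and anom_v.
input_bytes = 2 * array_bytes
n_chunks = 2 * (shape[0] // chunks[0])

# After mean("time"), each of the three variables (anom_u, anom_v, uv)
# has shape (face, j, i).
output_bytes = 3 * prod(shape[1:]) * ITEMSIZE

print(f"input:  {input_bytes / 2**30:.1f} GiB in {n_chunks} chunks")
print(f"chunk:  {chunk_bytes / 2**20:.1f} MiB")
print(f"output: {output_bytes / 2**20:.1f} MiB")
```

The point is the ratio: the input is thousands of times larger than the output, and each input chunk is well over 100 MiB, so every extra chunk that has to wait in memory is expensive.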
(Disclaimer: I'm not an xarray expert, so if one drives by, please forgive my ignorance if I'm saying something that is wrong.)
The mean object above is an xarray.Dataset, which is essentially a collection that holds multiple dask arrays (in this example). When calling compute, all of those arrays are computed simultaneously. Effectively, this is similar to dask.compute(array1, array2, array3).
In fact, computing just a single array (e.g. mean['uv'].compute()) runs wonderfully: the result is reduced immediately and no data generator task is held in memory for any significant amount of time.
However, when executing all of the above, the ordering breaks down and the results of the random data generation tasks are held in memory for a very long time.
Running dask.order.diagnostics on this shows a pressure (i.e. how long certain tasks are held in memory) of 267 on average with a maximum of 511, compared to 7 and 14 respectively for the single-array reduction.
The entire graph is too large to render, but a scaled-down version of it (showing order-age) shows this effect as well, albeit much less pronounced since the graph is smaller.
(In the single-array tree reduction, the age of the data generators is somewhere between one and three.)
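To make "age" and "pressure" concrete, here is a simplified, stdlib-only simulation. It is a sketch of the idea, not dask's implementation: the function simulate and the toy graph are hypothetical, tasks run one at a time, and a result is released as soon as its last dependent has run.

```python
def simulate(deps, order):
    """Run tasks one at a time in `order` and track memory.

    deps maps each task to the list of tasks it depends on.
    Returns (age, pressure):
      age[t]      = number of steps the result of t was held before release
      pressure[i] = number of results in memory right after step i
    """
    dependents = {t: set() for t in deps}
    for task, inputs in deps.items():
        for inp in inputs:
            dependents[inp].add(task)

    waiting = {t: len(dependents[t]) for t in deps}  # dependents not yet run
    born, age, pressure = {}, {}, []
    in_memory = 0
    for step, task in enumerate(order):
        born[task] = step
        in_memory += 1
        for inp in deps[task]:
            waiting[inp] -= 1
            if waiting[inp] == 0:  # last consumer ran, release the input
                age[inp] = step - born[inp]
                in_memory -= 1
        pressure.append(in_memory)
    for task in deps:  # final outputs are never released; count their age as 0
        age.setdefault(task, 0)
    return age, pressure


# Two independent reductions, each over four generated chunks.
deps = {f"{v}{i}": [] for v in "xy" for i in range(4)}
deps["sx"] = [f"x{i}" for i in range(4)]
deps["sy"] = [f"y{i}" for i in range(4)]

one_at_a_time = ["x0", "x1", "x2", "x3", "sx", "y0", "y1", "y2", "y3", "sy"]
interleaved = ["x0", "y0", "x1", "y1", "x2", "y2", "x3", "y3", "sx", "sy"]

age_a, pressure_a = simulate(deps, one_at_a_time)
age_b, pressure_b = simulate(deps, interleaved)

print("one at a time: max age", max(age_a.values()), "max pressure", max(pressure_a))
print("interleaved:   max age", max(age_b.values()), "max pressure", max(pressure_b))
```

The graph is identical in both runs; only the order differs. Finishing one reduction before starting the next keeps the generated chunks short-lived, while interleaving the two roughly doubles how long each chunk sits in memory.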
I believe I was able to reduce this to a minimal dask.order example:
from dask.base import visualize

a, b, c, d, e = list("abcde")


def f(*args):
    ...


dsk = {}
for ix in range(3):
    part = {
        # Part1
        (a, 0, ix): (f,),
        (a, 1, ix): (f,),
        (b, 0, ix): (f, (a, 0, ix)),
        (b, 1, ix): (f, (a, 0, ix), (a, 1, ix)),
        (b, 2, ix): (f, (a, 1, ix)),
        (c, 0, ix): (f, (b, 0, ix)),
        (c, 1, ix): (f, (b, 1, ix)),
        (c, 2, ix): (f, (b, 2, ix)),
    }
    dsk.update(part)
# Each reducer combines the c task with the same index from all three parts
for ix in range(3):
    dsk.update({
        (d, ix): (f, (c, ix, 0), (c, ix, 1), (c, ix, 2)),
    })

One workaround is to compute the result via mean.to_dask_dataframe(), which generates a slightly different version of the graph that dask.order handles better. The DataFrame version of this also reduces quite well.
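To see what makes the minimal graph awkward, it helps to list which root tasks each output depends on. The sketch below rebuilds the same graph with the standard library only; the helpers direct_deps and roots are hypothetical names of mine, and the interpretation in the note afterwards is my reading of the structure, not a statement about dask.order internals.

```python
def f(*args):
    ...


a, b, c, d = "abcd"

# Same graph as in the minimal example above.
dsk = {}
for ix in range(3):
    dsk.update({
        (a, 0, ix): (f,),
        (a, 1, ix): (f,),
        (b, 0, ix): (f, (a, 0, ix)),
        (b, 1, ix): (f, (a, 0, ix), (a, 1, ix)),
        (b, 2, ix): (f, (a, 1, ix)),
        (c, 0, ix): (f, (b, 0, ix)),
        (c, 1, ix): (f, (b, 1, ix)),
        (c, 2, ix): (f, (b, 2, ix)),
    })
for ix in range(3):
    dsk[(d, ix)] = (f, (c, ix, 0), (c, ix, 1), (c, ix, 2))


def direct_deps(task):
    """Arguments of a task tuple that are themselves keys of the graph."""
    return [arg for arg in task[1:] if arg in dsk]


def roots(key):
    """All dependency-free tasks that `key` transitively depends on."""
    deps = direct_deps(dsk[key])
    if not deps:
        return {key}
    return set().union(*(roots(dep) for dep in deps))


for ix in range(3):
    print((d, ix), "needs", sorted(roots((d, ix))))

shared_01 = roots((d, 0)) & roots((d, 1))
shared_12 = roots((d, 1)) & roots((d, 2))
print("roots shared by d0 and d1:", sorted(shared_01))
print("roots shared by d1 and d2:", sorted(shared_12))
```

Every root feeds two of the three outputs, and the middle output needs all six roots. An ordering that finishes one output at a time therefore cannot release the roots right away, which is the kind of sharing that shows up when several related arrays are computed together.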
cc @eriknw

