Skip to content

Commit 8b95f98

Browse files
author
Ian Rose
authored
Sparse array reductions (#9342)
1 parent 280ac97 commit 8b95f98

5 files changed

Lines changed: 160 additions & 113 deletions

File tree

dask/array/backends.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
1+
import math
2+
13
import numpy as np
24

5+
from dask.array import chunk
36
from dask.array.dispatch import (
47
concatenate_lookup,
58
divide_lookup,
69
einsum_lookup,
710
empty_lookup,
11+
nannumel_lookup,
12+
numel_lookup,
813
percentile_lookup,
914
tensordot_lookup,
1015
)
@@ -112,6 +117,8 @@ def _tensordot(a, b, axes=2):
112117

113118
@tensordot_lookup.register_lazy("cupy")
114119
@concatenate_lookup.register_lazy("cupy")
120+
@nannumel_lookup.register_lazy("cupy")
121+
@numel_lookup.register_lazy("cupy")
115122
def register_cupy():
116123
import cupy
117124

@@ -120,6 +127,8 @@ def register_cupy():
120127
concatenate_lookup.register(cupy.ndarray, cupy.concatenate)
121128
tensordot_lookup.register(cupy.ndarray, cupy.tensordot)
122129
percentile_lookup.register(cupy.ndarray, percentile)
130+
numel_lookup.register(cupy.ndarray, _numel_arraylike)
131+
nannumel_lookup.register(cupy.ndarray, _nannumel)
123132

124133
@einsum_lookup.register(cupy.ndarray)
125134
def _cupy_einsum(*args, **kwargs):
@@ -160,11 +169,18 @@ def _concat_cupy_sparse(L, axis=0):
160169

161170
@tensordot_lookup.register_lazy("sparse")
162171
@concatenate_lookup.register_lazy("sparse")
172+
@nannumel_lookup.register_lazy("sparse")
173+
@numel_lookup.register_lazy("sparse")
163174
def register_sparse():
164175
import sparse
165176

166177
concatenate_lookup.register(sparse.COO, sparse.concatenate)
167178
tensordot_lookup.register(sparse.COO, sparse.tensordot)
179+
# Enforce dense ndarray for the numel result, since the sparse
180+
# array will wind up being dense with an unpredictable fill_value.
181+
# https://github.com/dask/dask/issues/7169
182+
numel_lookup.register(sparse.COO, _numel_ndarray)
183+
nannumel_lookup.register(sparse.COO, _nannumel_sparse)
168184

169185

170186
@tensordot_lookup.register_lazy("scipy")
@@ -203,3 +219,80 @@ def _tensordot_scipy_sparse(a, b, axes):
203219
return a * b
204220
elif a_axis == 1 and b_axis == 1:
205221
return a * b.T
222+
223+
224+
@numel_lookup.register(np.ma.masked_array)
def _numel_masked(x, **kwargs):
    """Count elements of a masked array by summing an all-ones array of its shape."""
    ones = np.ones_like(x)
    return chunk.sum(ones, **kwargs)
228+
229+
230+
@numel_lookup.register((object, np.ndarray))
def _numel_ndarray(x, **kwargs):
    """Count elements, always materializing the result as a plain ``numpy.ndarray``."""
    return _numel(x, coerce_np_ndarray=True, **kwargs)
234+
235+
236+
def _numel_arraylike(x, **kwargs):
    """Count elements, keeping the result in the same array family as ``x``."""
    return _numel(x, coerce_np_ndarray=False, **kwargs)
239+
240+
241+
def _numel(x, coerce_np_ndarray: bool, **kwargs):
242+
"""
243+
A reduction to count the number of elements.
244+
245+
This has an additional kwarg in coerce_np_ndarray, which determines
246+
whether to ensure that the resulting array is a numpy.ndarray, or whether
247+
we allow it to be other array types via `np.full_like`.
248+
"""
249+
shape = x.shape
250+
keepdims = kwargs.get("keepdims", False)
251+
axis = kwargs.get("axis", None)
252+
dtype = kwargs.get("dtype", np.float64)
253+
254+
if axis is None:
255+
prod = np.prod(shape, dtype=dtype)
256+
if keepdims is False:
257+
return prod
258+
259+
if coerce_np_ndarray:
260+
return np.full(shape=(1,) * len(shape), fill_value=prod, dtype=dtype)
261+
else:
262+
return np.full_like(x, prod, shape=(1,) * len(shape), dtype=dtype)
263+
264+
if not isinstance(axis, (tuple, list)):
265+
axis = [axis]
266+
267+
prod = math.prod(shape[dim] for dim in axis)
268+
if keepdims is True:
269+
new_shape = tuple(
270+
shape[dim] if dim not in axis else 1 for dim in range(len(shape))
271+
)
272+
else:
273+
new_shape = tuple(shape[dim] for dim in range(len(shape)) if dim not in axis)
274+
275+
if coerce_np_ndarray:
276+
return np.broadcast_to(np.array(prod, dtype=dtype), new_shape)
277+
else:
278+
return np.full_like(x, prod, shape=new_shape, dtype=dtype)
279+
280+
281+
@nannumel_lookup.register((object, np.ndarray))
def _nannumel(x, **kwargs):
    """Count elements excluding NaNs, by summing the boolean non-NaN mask."""
    not_nan = ~np.isnan(x)
    return chunk.sum(not_nan, **kwargs)
285+
286+
287+
def _nannumel_sparse(x, **kwargs):
    """
    Count the non-NaN elements of a sparse array.

    The counting mask is in general dense with an unpredictable fill value,
    so the result is explicitly converted to a dense array.

    https://github.com/dask/dask/issues/7169
    """
    counted = _nannumel(x, **kwargs)
    # A fully-contracted reduction yields a plain scalar with no ``todense``.
    if hasattr(counted, "todense"):
        return counted.todense()
    return counted

dask/array/dispatch.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,5 @@
1212
empty_lookup = Dispatch("empty")
1313
divide_lookup = Dispatch("divide")
1414
percentile_lookup = Dispatch("percentile")
15+
numel_lookup = Dispatch("numel")
16+
nannumel_lookup = Dispatch("nannumel")

dask/array/reductions.py

Lines changed: 9 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,7 @@
2323
unknown_chunk_message,
2424
)
2525
from dask.array.creation import arange, diagonal
26-
27-
# Keep empty_lookup here for backwards compatibility
28-
from dask.array.dispatch import divide_lookup, empty_lookup # noqa: F401
26+
from dask.array.dispatch import divide_lookup, nannumel_lookup, numel_lookup
2927
from dask.array.utils import (
3028
array_safe,
3129
asarray_safe,
@@ -54,6 +52,14 @@ def divide(a, b, dtype=None):
5452
return f(a, b, dtype=dtype)
5553

5654

55+
def numel(x, **kwargs):
    """A reduction counting elements, dispatched on the array type of ``x``."""
    return numel_lookup(x, **kwargs)
57+
58+
59+
def nannumel(x, **kwargs):
    """A reduction counting non-NaN elements, dispatched on the array type of ``x``."""
    return nannumel_lookup(x, **kwargs)
61+
62+
5763
def reduction(
5864
x,
5965
chunk,
@@ -638,43 +644,6 @@ def _nanmax_skip(x_chunk, axis, keepdims):
638644
)
639645

640646

641-
def numel(x, **kwargs):
642-
"""A reduction to count the number of elements"""
643-
644-
if hasattr(x, "mask"):
645-
return chunk.sum(np.ones_like(x), **kwargs)
646-
647-
shape = x.shape
648-
keepdims = kwargs.get("keepdims", False)
649-
axis = kwargs.get("axis", None)
650-
dtype = kwargs.get("dtype", np.float64)
651-
652-
if axis is None:
653-
prod = np.prod(shape, dtype=dtype)
654-
return (
655-
np.full_like(x, prod, shape=(1,) * len(shape), dtype=dtype)
656-
if keepdims is True
657-
else prod
658-
)
659-
660-
if not isinstance(axis, tuple or list):
661-
axis = [axis]
662-
663-
prod = math.prod(shape[dim] for dim in axis)
664-
if keepdims is True:
665-
new_shape = tuple(
666-
shape[dim] if dim not in axis else 1 for dim in range(len(shape))
667-
)
668-
else:
669-
new_shape = tuple(shape[dim] for dim in range(len(shape)) if dim not in axis)
670-
return np.full_like(x, prod, shape=new_shape, dtype=dtype)
671-
672-
673-
def nannumel(x, **kwargs):
674-
"""A reduction to count the number of elements"""
675-
return chunk.sum(~(np.isnan(x)), **kwargs)
676-
677-
678647
def mean_chunk(
679648
x, sum=chunk.sum, numel=numel, dtype="f8", computing_meta=False, **kwargs
680649
):

dask/array/tests/test_reductions.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,31 +17,44 @@
1717

1818
@pytest.mark.parametrize("dtype", ["f4", "i4"])
1919
@pytest.mark.parametrize("keepdims", [True, False])
20-
def test_numel(dtype, keepdims):
20+
@pytest.mark.parametrize("nan", [True, False])
21+
def test_numel(dtype, keepdims, nan):
2122
x = np.ones((2, 3, 4))
23+
if nan:
24+
y = np.random.uniform(-1, 1, size=(2, 3, 4))
25+
x[y < 0] = np.nan
26+
numel = da.reductions.nannumel
27+
28+
def _sum(arr, **kwargs):
29+
n = np.sum(np.ma.masked_where(np.isnan(arr), arr), **kwargs)
30+
return n.filled(0) if isinstance(n, np.ma.MaskedArray) else n
31+
32+
else:
33+
numel = da.reductions.numel
34+
_sum = np.sum
2235

2336
assert_eq(
24-
da.reductions.numel(x, axis=(), keepdims=keepdims, dtype=dtype),
25-
np.sum(x, axis=(), keepdims=keepdims, dtype=dtype),
37+
numel(x, axis=(), keepdims=keepdims, dtype=dtype),
38+
_sum(x, axis=(), keepdims=keepdims, dtype=dtype),
2639
)
2740
assert_eq(
28-
da.reductions.numel(x, axis=0, keepdims=keepdims, dtype=dtype),
29-
np.sum(x, axis=0, keepdims=keepdims, dtype=dtype),
41+
numel(x, axis=0, keepdims=keepdims, dtype=dtype),
42+
_sum(x, axis=0, keepdims=keepdims, dtype=dtype),
3043
)
3144

3245
for length in range(x.ndim):
3346
for sub in itertools.combinations([d for d in range(x.ndim)], length):
3447
assert_eq(
35-
da.reductions.numel(x, axis=sub, keepdims=keepdims, dtype=dtype),
36-
np.sum(x, axis=sub, keepdims=keepdims, dtype=dtype),
48+
numel(x, axis=sub, keepdims=keepdims, dtype=dtype),
49+
_sum(x, axis=sub, keepdims=keepdims, dtype=dtype),
3750
)
3851

3952
for length in range(x.ndim):
4053
for sub in itertools.combinations([d for d in range(x.ndim)], length):
4154
ssub = np.random.shuffle(list(sub))
4255
assert_eq(
43-
da.reductions.numel(x, axis=ssub, keepdims=keepdims, dtype=dtype),
44-
np.sum(x, axis=ssub, keepdims=keepdims, dtype=dtype),
56+
numel(x, axis=ssub, keepdims=keepdims, dtype=dtype),
57+
_sum(x, axis=ssub, keepdims=keepdims, dtype=dtype),
4558
)
4659

4760

dask/array/tests/test_sparse.py

Lines changed: 34 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
import random
2-
31
import numpy as np
42
import pytest
53
from packaging.version import parse as parse_version
64

75
import dask
86
import dask.array as da
7+
from dask.array.reductions import nannumel, numel
98
from dask.array.utils import assert_eq
109

1110
sparse = pytest.importorskip("sparse")
@@ -30,20 +29,26 @@
3029
lambda x: x[:1, None, 1:3],
3130
lambda x: x.T,
3231
lambda x: da.transpose(x, (1, 2, 0)),
32+
lambda x: da.nanmean(x),
33+
lambda x: da.nanmean(x, axis=1),
34+
lambda x: da.nanmax(x),
35+
lambda x: da.nanmin(x),
36+
lambda x: da.nanprod(x),
37+
lambda x: da.nanstd(x),
38+
lambda x: da.nanvar(x),
39+
lambda x: da.nansum(x),
40+
# These nan* variants are not implemented by sparse.COO
41+
# lambda x: da.median(x, axis=0),
42+
# lambda x: da.nanargmax(x),
43+
# lambda x: da.nanargmin(x),
44+
# lambda x: da.nancumprod(x, axis=0),
45+
# lambda x: da.nancumsum(x, axis=0),
3346
lambda x: x.sum(),
3447
lambda x: x.moment(order=0),
35-
pytest.param(
36-
lambda x: x.mean(),
37-
marks=pytest.mark.xfail(reason="https://github.com/dask/dask/issues/7169"),
38-
),
39-
pytest.param(
40-
lambda x: x.std(),
41-
marks=pytest.mark.xfail(reason="https://github.com/dask/dask/issues/7169"),
42-
),
43-
pytest.param(
44-
lambda x: x.var(),
45-
marks=pytest.mark.xfail(reason="https://github.com/dask/dask/issues/7169"),
46-
),
48+
lambda x: x.mean(),
49+
lambda x: x.mean(axis=1),
50+
lambda x: x.std(),
51+
lambda x: x.var(),
4752
lambda x: x.dot(np.arange(x.shape[-1])),
4853
lambda x: x.dot(np.eye(x.shape[-1])),
4954
lambda x: da.tensordot(x, np.ones(x.shape[:2]), axes=[(0, 1), (0, 1)]),
@@ -125,56 +130,6 @@ def test_tensordot():
125130
)
126131

127132

128-
@pytest.mark.xfail(reason="upstream change", strict=False)
129-
@pytest.mark.parametrize("func", functions)
130-
def test_mixed_concatenate(func):
131-
x = da.random.random((2, 3, 4), chunks=(1, 2, 2))
132-
133-
y = da.random.random((2, 3, 4), chunks=(1, 2, 2))
134-
y[y < 0.8] = 0
135-
yy = y.map_blocks(sparse.COO.from_numpy)
136-
137-
d = da.concatenate([x, y], axis=0)
138-
s = da.concatenate([x, yy], axis=0)
139-
140-
dd = func(d)
141-
ss = func(s)
142-
143-
assert_eq(dd, ss)
144-
145-
146-
@pytest.mark.xfail(reason="upstream change", strict=False)
147-
@pytest.mark.parametrize("func", functions)
148-
def test_mixed_random(func):
149-
d = da.random.random((4, 3, 4), chunks=(1, 2, 2))
150-
d[d < 0.7] = 0
151-
152-
fn = lambda x: sparse.COO.from_numpy(x) if random.random() < 0.5 else x
153-
s = d.map_blocks(fn)
154-
155-
dd = func(d)
156-
ss = func(s)
157-
158-
assert_eq(dd, ss)
159-
160-
161-
@pytest.mark.xfail(reason="upstream change", strict=False)
162-
def test_mixed_output_type():
163-
y = da.random.random((10, 10), chunks=(5, 5))
164-
y[y < 0.8] = 0
165-
y = y.map_blocks(sparse.COO.from_numpy)
166-
167-
x = da.zeros((10, 1), chunks=(5, 1))
168-
169-
z = da.concatenate([x, y], axis=1)
170-
171-
assert z.shape == (10, 11)
172-
173-
zz = z.compute()
174-
assert isinstance(zz, sparse.COO)
175-
assert zz.nnz == y.compute().nnz
176-
177-
178133
def test_metadata():
179134
y = da.random.random((10, 10), chunks=(5, 5))
180135
y[y < 0.8] = 0
@@ -239,3 +194,18 @@ def test_meta_from_array():
239194
x = sparse.COO.from_numpy(np.eye(1))
240195
y = da.utils.meta_from_array(x, ndim=2)
241196
assert isinstance(y, sparse.COO)
197+
198+
199+
@pytest.mark.parametrize("numel", [numel, nannumel])
@pytest.mark.parametrize("axis", [0, (0, 1), None])
@pytest.mark.parametrize("keepdims", [True, False])
def test_numel(numel, axis, keepdims):
    """numel/nannumel must agree between a dense ndarray and its sparse COO form."""
    dense = np.random.random((2, 3, 4))
    dense[dense < 0.8] = 0
    dense[dense > 0.9] = np.nan

    coo = sparse.COO.from_numpy(dense, fill_value=0.0)

    expected = numel(dense, axis=axis, keepdims=keepdims)
    actual = numel(coo, axis=axis, keepdims=keepdims)
    assert_eq(expected, actual)

0 commit comments

Comments
 (0)