|
1 | | -import random |
2 | | - |
3 | 1 | import numpy as np |
4 | 2 | import pytest |
5 | 3 | from packaging.version import parse as parse_version |
6 | 4 |
|
7 | 5 | import dask |
8 | 6 | import dask.array as da |
| 7 | +from dask.array.reductions import nannumel, numel |
9 | 8 | from dask.array.utils import assert_eq |
10 | 9 |
|
11 | 10 | sparse = pytest.importorskip("sparse") |
|
30 | 29 | lambda x: x[:1, None, 1:3], |
31 | 30 | lambda x: x.T, |
32 | 31 | lambda x: da.transpose(x, (1, 2, 0)), |
| 32 | + lambda x: da.nanmean(x), |
| 33 | + lambda x: da.nanmean(x, axis=1), |
| 34 | + lambda x: da.nanmax(x), |
| 35 | + lambda x: da.nanmin(x), |
| 36 | + lambda x: da.nanprod(x), |
| 37 | + lambda x: da.nanstd(x), |
| 38 | + lambda x: da.nanvar(x), |
| 39 | + lambda x: da.nansum(x), |
| 40 | + # These nan* variants are are not implemented by sparse.COO |
| 41 | + # lambda x: da.median(x, axis=0), |
| 42 | + # lambda x: da.nanargmax(x), |
| 43 | + # lambda x: da.nanargmin(x), |
| 44 | + # lambda x: da.nancumprod(x, axis=0), |
| 45 | + # lambda x: da.nancumsum(x, axis=0), |
33 | 46 | lambda x: x.sum(), |
34 | 47 | lambda x: x.moment(order=0), |
35 | | - pytest.param( |
36 | | - lambda x: x.mean(), |
37 | | - marks=pytest.mark.xfail(reason="https://github.com/dask/dask/issues/7169"), |
38 | | - ), |
39 | | - pytest.param( |
40 | | - lambda x: x.std(), |
41 | | - marks=pytest.mark.xfail(reason="https://github.com/dask/dask/issues/7169"), |
42 | | - ), |
43 | | - pytest.param( |
44 | | - lambda x: x.var(), |
45 | | - marks=pytest.mark.xfail(reason="https://github.com/dask/dask/issues/7169"), |
46 | | - ), |
| 48 | + lambda x: x.mean(), |
| 49 | + lambda x: x.mean(axis=1), |
| 50 | + lambda x: x.std(), |
| 51 | + lambda x: x.var(), |
47 | 52 | lambda x: x.dot(np.arange(x.shape[-1])), |
48 | 53 | lambda x: x.dot(np.eye(x.shape[-1])), |
49 | 54 | lambda x: da.tensordot(x, np.ones(x.shape[:2]), axes=[(0, 1), (0, 1)]), |
@@ -125,56 +130,6 @@ def test_tensordot(): |
125 | 130 | ) |
126 | 131 |
|
127 | 132 |
|
128 | | -@pytest.mark.xfail(reason="upstream change", strict=False) |
129 | | -@pytest.mark.parametrize("func", functions) |
130 | | -def test_mixed_concatenate(func): |
131 | | - x = da.random.random((2, 3, 4), chunks=(1, 2, 2)) |
132 | | - |
133 | | - y = da.random.random((2, 3, 4), chunks=(1, 2, 2)) |
134 | | - y[y < 0.8] = 0 |
135 | | - yy = y.map_blocks(sparse.COO.from_numpy) |
136 | | - |
137 | | - d = da.concatenate([x, y], axis=0) |
138 | | - s = da.concatenate([x, yy], axis=0) |
139 | | - |
140 | | - dd = func(d) |
141 | | - ss = func(s) |
142 | | - |
143 | | - assert_eq(dd, ss) |
144 | | - |
145 | | - |
146 | | -@pytest.mark.xfail(reason="upstream change", strict=False) |
147 | | -@pytest.mark.parametrize("func", functions) |
148 | | -def test_mixed_random(func): |
149 | | - d = da.random.random((4, 3, 4), chunks=(1, 2, 2)) |
150 | | - d[d < 0.7] = 0 |
151 | | - |
152 | | - fn = lambda x: sparse.COO.from_numpy(x) if random.random() < 0.5 else x |
153 | | - s = d.map_blocks(fn) |
154 | | - |
155 | | - dd = func(d) |
156 | | - ss = func(s) |
157 | | - |
158 | | - assert_eq(dd, ss) |
159 | | - |
160 | | - |
161 | | -@pytest.mark.xfail(reason="upstream change", strict=False) |
162 | | -def test_mixed_output_type(): |
163 | | - y = da.random.random((10, 10), chunks=(5, 5)) |
164 | | - y[y < 0.8] = 0 |
165 | | - y = y.map_blocks(sparse.COO.from_numpy) |
166 | | - |
167 | | - x = da.zeros((10, 1), chunks=(5, 1)) |
168 | | - |
169 | | - z = da.concatenate([x, y], axis=1) |
170 | | - |
171 | | - assert z.shape == (10, 11) |
172 | | - |
173 | | - zz = z.compute() |
174 | | - assert isinstance(zz, sparse.COO) |
175 | | - assert zz.nnz == y.compute().nnz |
176 | | - |
177 | | - |
178 | 133 | def test_metadata(): |
179 | 134 | y = da.random.random((10, 10), chunks=(5, 5)) |
180 | 135 | y[y < 0.8] = 0 |
@@ -239,3 +194,18 @@ def test_meta_from_array(): |
239 | 194 | x = sparse.COO.from_numpy(np.eye(1)) |
240 | 195 | y = da.utils.meta_from_array(x, ndim=2) |
241 | 196 | assert isinstance(y, sparse.COO) |
| 197 | + |
| 198 | + |
| 199 | +@pytest.mark.parametrize("numel", [numel, nannumel]) |
| 200 | +@pytest.mark.parametrize("axis", [0, (0, 1), None]) |
| 201 | +@pytest.mark.parametrize("keepdims", [True, False]) |
| 202 | +def test_numel(numel, axis, keepdims): |
| 203 | + x = np.random.random((2, 3, 4)) |
| 204 | + x[x < 0.8] = 0 |
| 205 | + x[x > 0.9] = np.nan |
| 206 | + |
| 207 | + xs = sparse.COO.from_numpy(x, fill_value=0.0) |
| 208 | + |
| 209 | + assert_eq( |
| 210 | + numel(x, axis=axis, keepdims=keepdims), numel(xs, axis=axis, keepdims=keepdims) |
| 211 | + ) |
0 commit comments