Skip to content

Commit 0a7b58b

Browse files
authored
Implement ma.*_like functions (#9378)
Implement `ones_like`, `zeros_like` and `empty_like` within `dask.array.ma`.
1 parent b894f72 commit 0a7b58b

3 files changed

Lines changed: 40 additions & 0 deletions

File tree

dask/array/ma.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,3 +190,21 @@ def count(a, axis=None, keepdims=False, split_every=None):
190190
split_every=split_every,
191191
out=None,
192192
)
193+
194+
195+
@derived_from(np.ma.core)
196+
def ones_like(a, **kwargs):
197+
a = asanyarray(a)
198+
return a.map_blocks(np.ma.core.ones_like, **kwargs)
199+
200+
201+
@derived_from(np.ma.core)
202+
def zeros_like(a, **kwargs):
203+
a = asanyarray(a)
204+
return a.map_blocks(np.ma.core.zeros_like, **kwargs)
205+
206+
207+
@derived_from(np.ma.core)
208+
def empty_like(a, **kwargs):
209+
a = asanyarray(a)
210+
return a.map_blocks(np.ma.core.empty_like, **kwargs)

dask/array/tests/test_masked.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,3 +427,22 @@ def test_count():
427427
res = da.ma.count(dx, axis=axis)
428428
sol = np.ma.count(x, axis=axis)
429429
assert_eq(res, sol, check_dtype=sys.platform != "win32")
430+
431+
432+
@pytest.mark.parametrize("funcname", ["ones_like", "zeros_like", "empty_like"])
433+
def test_like_funcs(funcname):
434+
mask = np.array([[True, False], [True, True], [False, True]])
435+
data = np.arange(6).reshape((3, 2))
436+
a = np.ma.array(data, mask=mask)
437+
d_a = da.ma.masked_array(data=data, mask=mask, chunks=2)
438+
439+
da_func = getattr(da.ma, funcname)
440+
np_func = getattr(np.ma.core, funcname)
441+
442+
res = da_func(d_a)
443+
sol = np_func(a)
444+
445+
if "empty" in funcname:
446+
assert_eq(da.ma.getmaskarray(res), np.ma.getmaskarray(sol))
447+
else:
448+
assert_eq(res, sol)

docs/source/array-api.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,7 @@ Masked Arrays
356356
:toctree: generated/
357357

358358
ma.average
359+
ma.empty_like
359360
ma.filled
360361
ma.fix_invalid
361362
ma.getdata
@@ -372,7 +373,9 @@ Masked Arrays
372373
ma.masked_outside
373374
ma.masked_values
374375
ma.masked_where
376+
ma.ones_like
375377
ma.set_fill_value
378+
ma.zeros_like
376379

377380
Random
378381
~~~~~~

0 commit comments

Comments
 (0)