Skip to content

Commit 3f7aa3f

Browse files
Extend __setitem__ to more closely match numpy (#7033)
* First implementation of extended __setitem__ * style and typo * unit test, and fixes to make it pass. Reinstated 'where' for cases when key is Array * test ValueErrors are raised for disallowed index combinations * style * Comments * Extra tests: assignment of N-d arrays and broadcasting * Allow assignment to np.ma.masked, and also to read-only arrays (e.g. da.empty) * ValueError -> NotImplementedError * numpy style docstring Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> * Doc strings; value_shape1 renamed to value_common_shape * Update message format Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> * Update message format Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> * test more non-valid assignments * Style * Style * Prevent corruption of chunks shared between objects - allows, e.g. d[:] = d[::-1] * Fix f-string Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> * Prevent hanging when value is self - allows, e.g. d[...] = d and d[...] = d[...] * Fix for integer-valued indices - allows, e.g. d[:, 0] = d[:, 2] * Restructure code, catch illegal corner cases, remove premature compute * Reorganise imports * setitem assignment docs * fixed labels, add to toctree Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com>
1 parent 058aef6 commit 3f7aa3f

File tree

5 files changed

+807
-8
lines changed

5 files changed

+807
-8
lines changed

dask/array/core.py

Lines changed: 132 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,13 @@
5959
from ..sizeof import sizeof
6060
from ..highlevelgraph import HighLevelGraph
6161
from .numpy_compat import _Recurser, _make_sliced_dtype
62-
from .slicing import slice_array, replace_ellipsis, cached_cumsum
62+
from .slicing import (
63+
slice_array,
64+
replace_ellipsis,
65+
cached_cumsum,
66+
parse_assignment_indices,
67+
setitem,
68+
)
6369
from .blockwise import blockwise
6470
from .chunk_types import is_valid_array_chunk, is_valid_chunk_type
6571

@@ -1605,12 +1611,17 @@ def __float__(self):
16051611
def __complex__(self):
16061612
return self._scalarfunc(complex)
16071613

1608-
def __setitem__(self, key, value):
1609-
from .routines import where
1614+
def __index__(self):
1615+
return self._scalarfunc(int)
16101616

1617+
def __setitem__(self, key, value):
1618+
# Use the "where" method for cases when key is an Array
16111619
if isinstance(key, Array):
16121620
if isinstance(value, Array) and value.ndim > 1:
16131621
raise ValueError("boolean index array should have 1 dimension")
1622+
1623+
from .routines import where
1624+
16141625
try:
16151626
y = where(key, value, self)
16161627
except ValueError as e:
@@ -1625,11 +1636,126 @@ def __setitem__(self, key, value):
16251636
self.name = y.name
16261637
self._chunks = y.chunks
16271638
return self
1628-
else:
1629-
raise NotImplementedError(
1630-
"Item assignment with %s not supported" % type(key)
1639+
1640+
# Still here? Then parse the indices from 'key' and apply the
1641+
# assignment via map_blocks
1642+
1643+
# Reformat input indices
1644+
indices, indices_shape, mirror = parse_assignment_indices(key, self.shape)
1645+
1646+
# Cast 'value' as a dask array
1647+
if value is np.ma.masked:
1648+
# Convert masked to a scalar masked array
1649+
value = np.ma.array(0, mask=True)
1650+
1651+
if value is self:
1652+
# Need to copy value if it is the same as self. This
1653+
# allows x[...] = x and x[...] = x[...], etc.
1654+
value = value.copy()
1655+
1656+
value = asanyarray(value)
1657+
value_shape = value.shape
1658+
1659+
if 0 in indices_shape and value.size > 1:
1660+
# Can only assign size 1 values (with any number of
1661+
# dimensions) to empty slices
1662+
raise ValueError(
1663+
f"shape mismatch: value array of shape {value_shape} "
1664+
"could not be broadcast to indexing result "
1665+
f"of shape {tuple(indices_shape)}"
16311666
)
16321667

1668+
# Define:
1669+
#
1670+
# offset: The difference in the relative positions of a
1671+
# dimension in 'value' and the corresponding
1672+
# dimension in self. A positive value means the
1673+
# dimension position is further to the right in self
1674+
# than 'value'.
1675+
#
1676+
# self_common_shape: The shape of those dimensions of self
1677+
# which correspond to dimensions of
1678+
# 'value'.
1679+
#
1680+
# value_common_shape: The shape of those dimensions of
1681+
# 'value' which correspond to dimensions
1682+
# of self.
1683+
#
1684+
# base_value_indices: The indices used for initialising the
1685+
# selection from 'value'. slice(None)
1686+
# elements are unchanged, but an element
1687+
# of None will, inside a call to setitem,
1688+
# be replaced by an appropriate slice.
1689+
#
1690+
# Note that self_common_shape and value_common_shape may be
1691+
# different if any size 1 dimensions are being
1692+
# broadcast.
1693+
offset = len(indices_shape) - value.ndim
1694+
if offset >= 0:
1695+
# self has the same number or more dimensions than 'value'
1696+
self_common_shape = indices_shape[offset:]
1697+
value_common_shape = value_shape
1698+
1699+
# Modify the mirror dimensions with the offset
1700+
mirror = [i - offset for i in mirror if i >= offset]
1701+
else:
1702+
# 'value' has more dimensions than self
1703+
value_offset = -offset
1704+
if value_shape[:value_offset] != [1] * value_offset:
1705+
# Can only allow 'value' to have more dimensions than
1706+
# self if the extra leading dimensions all have size
1707+
# 1.
1708+
raise ValueError(
1709+
"could not broadcast input array from shape"
1710+
f"{value_shape} into shape {tuple(indices_shape)}"
1711+
)
1712+
1713+
offset = 0
1714+
self_common_shape = indices_shape
1715+
value_common_shape = value_shape[value_offset:]
1716+
1717+
# Find out which of the dimensions of 'value' are to be
1718+
# broadcast across self.
1719+
#
1720+
# Note that, as in numpy, it is not allowed for a dimension in
1721+
# 'value' to be larger than a size 1 dimension in self
1722+
base_value_indices = []
1723+
non_broadcast_dimensions = []
1724+
for i, (a, b) in enumerate(zip(self_common_shape, value_common_shape)):
1725+
if b == 1:
1726+
base_value_indices.append(slice(None))
1727+
elif a == b and b != 1:
1728+
base_value_indices.append(None)
1729+
non_broadcast_dimensions.append(i)
1730+
elif a is None and b != 1:
1731+
base_value_indices.append(None)
1732+
non_broadcast_dimensions.append(i)
1733+
elif a is not None:
1734+
# Can't check ...
1735+
raise ValueError(
1736+
f"Can't broadcast data with shape {value_common_shape} "
1737+
f"across shape {tuple(indices_shape)}"
1738+
)
1739+
1740+
# Map the setitem function across all blocks
1741+
y = self.map_blocks(
1742+
partial(setitem, value=value),
1743+
dtype=self.dtype,
1744+
indices=indices,
1745+
non_broadcast_dimensions=non_broadcast_dimensions,
1746+
offset=offset,
1747+
base_value_indices=base_value_indices,
1748+
mirror=mirror,
1749+
value_common_shape=value_common_shape,
1750+
)
1751+
1752+
self._meta = y._meta
1753+
self.dask = y.dask
1754+
self.name = y.name
1755+
self._chunks = y.chunks
1756+
1757+
return self
1758+
16331759
def __getitem__(self, index):
16341760
# Field access, e.g. x['a'] or x[['a', 'b']]
16351761
if isinstance(index, str) or (

0 commit comments

Comments
 (0)