Skip to content

Commit edfa91e

Browse files
authored
Return dask array if all axes are squeezed (#9250)
* return 0d array if all axes are squeezed * update if logic
1 parent 84533ea commit edfa91e

2 files changed

Lines changed: 18 additions & 0 deletions

File tree

dask/array/routines.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1934,6 +1934,13 @@ def squeeze(a, axis=None):
19341934

19351935
sl = tuple(0 if i in axis else slice(None) for i, s in enumerate(a.shape))
19361936

1937+
# Return 0d Dask Array if all axes are squeezed,
1938+
# to be consistent with NumPy. Ref: https://github.com/dask/dask/issues/9183#issuecomment-1155626619
1939+
if all(s == 0 for s in sl) and all(s == 1 for s in a.shape):
1940+
return a.map_blocks(
1941+
np.squeeze, meta=a._meta, drop_axis=tuple(range(len(a.shape)))
1942+
)
1943+
19371944
a = a[sl]
19381945

19391946
return a

dask/array/tests/test_routines.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1489,6 +1489,17 @@ def test_squeeze(is_func, axis):
14891489
assert d_s.chunks == exp_d_s_chunks
14901490

14911491

1492+
@pytest.mark.parametrize("shape", [(1,), (1, 1)])
1493+
def test_squeeze_1d_array(shape):
1494+
a = np.full(shape=shape, fill_value=2)
1495+
a_s = np.squeeze(a)
1496+
d = da.from_array(a, chunks=(1))
1497+
d_s = da.squeeze(d)
1498+
assert isinstance(d_s, da.Array)
1499+
assert isinstance(d_s.compute(), np.ndarray)
1500+
assert_eq(d_s, a_s)
1501+
1502+
14921503
def test_vstack():
14931504
x = np.arange(5)
14941505
y = np.ones(5)

0 commit comments

Comments
 (0)