Skip to content

Commit e429efa

Browse files
authored
fix dask-sql bug caused by blockwise fusion (#8989)
1 parent 0057207 commit e429efa

2 files changed

Lines changed: 38 additions & 2 deletions

File tree

dask/blockwise.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,19 @@ def produces_keys(self) -> bool:
185185
return self._produces_keys
186186

187187
def __getitem__(self, idx: tuple[int, ...]) -> Any:
188-
return self.mapping[idx]
188+
try:
189+
return self.mapping[idx]
190+
except KeyError as err:
191+
# If a DataFrame collection was converted
192+
# to an Array collection, the dimension of
193+
# `idx` may not agree with the keys in
194+
# `self.mapping`. In this case, we can
195+
# use `self.numblocks` to check for a key
196+
# match in the leading elements of `idx`
197+
flat_idx = idx[: len(self.numblocks)]
198+
if flat_idx in self.mapping:
199+
return self.mapping[flat_idx]
200+
raise err
189201

190202
def __dask_distributed_pack__(
191203
self, required_indices: tuple | list[tuple[int, ...]] | None = None
@@ -204,7 +216,7 @@ def __dask_distributed_pack__(
204216
},
205217
"numblocks": self.numblocks,
206218
"produces_tasks": self.produces_tasks,
207-
"produces_keys": self._produces_keys,
219+
"produces_keys": self.produces_keys,
208220
}
209221

210222
@classmethod

dask/dataframe/io/tests/test_io.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -700,6 +700,30 @@ def test_from_delayed_optimize_fusion():
700700
assert len(optimize(ddf.dask, ddf.__dask_keys__()).layers) == 1
701701

702702

703+
def test_from_delayed_to_dask_array():
704+
# Check that `from_delayed`` can be followed
705+
# by `to_dask_array` without breaking
706+
# optimization behavior
707+
# See: https://github.com/dask-contrib/dask-sql/issues/497
708+
from dask.blockwise import optimize_blockwise
709+
710+
dfs = [delayed(pd.DataFrame)(np.ones((3, 2))) for i in range(3)]
711+
ddf = dd.from_delayed(dfs)
712+
arr = ddf.to_dask_array()
713+
714+
# If we optimize this graph without calling
715+
# `fuse_roots`, the underlying `BlockwiseDep`
716+
# `mapping` keys will be 1-D (e.g. `(4,)`),
717+
# while the collection keys will be 2-D
718+
# (e.g. `(4, 0)`)
719+
keys = [k[0] for k in arr.__dask_keys__()]
720+
dsk = optimize_blockwise(arr.dask, keys=keys)
721+
dsk.cull(keys)
722+
723+
result = arr.compute()
724+
assert result.shape == (9, 2)
725+
726+
703727
def test_from_delayed_preserves_hlgs():
704728
df = pd.DataFrame(data=np.random.normal(size=(10, 4)), columns=list("abcd"))
705729
parts = [df.iloc[:1], df.iloc[1:3], df.iloc[3:6], df.iloc[6:10]]

0 commit comments

Comments
 (0)