Skip to content

Commit 19a5147

Browse files
authored
Enable automatic column projection for groupby aggregations (#9442)
1 parent 222ea06 commit 19a5147

2 files changed

Lines changed: 24 additions & 3 deletions

File tree

dask/dataframe/groupby.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1654,6 +1654,7 @@ def get_group(self, key):
16541654

16551655
@_aggregate_docstring()
16561656
def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):
1657+
column_projection = None
16571658
if isinstance(self.obj, DataFrame):
16581659
if isinstance(self.by, tuple) or np.isscalar(self.by):
16591660
group_columns = {self.by}
@@ -1681,6 +1682,10 @@ def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):
16811682

16821683
spec = _normalize_spec(arg, non_group_columns)
16831684

1685+
# Check if the aggregation involves implicit column projection
1686+
if isinstance(arg, dict):
1687+
column_projection = group_columns | arg.keys()
1688+
16841689
elif isinstance(self.obj, Series):
16851690
if isinstance(arg, (list, tuple, dict)):
16861691
# implementation detail: if self.obj is a series, a pseudo column
@@ -1709,11 +1714,17 @@ def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):
17091714
else:
17101715
levels = 0
17111716

1717+
# Add an explicit `getitem` operation if the groupby
1718+
# aggregation involves implicit column projection.
1719+
# This makes it possible for the column-projection
1720+
# to be pushed into the IO layer
1721+
_obj = self.obj[list(column_projection)] if column_projection else self.obj
1722+
17121723
if not isinstance(self.by, list):
1713-
chunk_args = [self.obj, self.by]
1724+
chunk_args = [_obj, self.by]
17141725

17151726
else:
1716-
chunk_args = [self.obj] + self.by
1727+
chunk_args = [_obj] + self.by
17171728

17181729
if not PANDAS_GT_110 and self.dropna:
17191730
raise NotImplementedError(

dask/dataframe/tests/test_groupby.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from dask.dataframe.backends import grouper_dispatch
1515
from dask.dataframe.utils import assert_dask_graph, assert_eq, assert_max_deps
1616
from dask.utils import M
17+
from dask.utils_test import hlg_layer
1718

1819
CHECK_FREQ = {}
1920
if dd._compat.PANDAS_GT_110:
@@ -1080,6 +1081,11 @@ def test_aggregate_dask():
10801081
assert_max_deps(agg_dask1, 2)
10811082
assert_max_deps(agg_dask2, 2)
10821083

1084+
# Make sure dict-based aggregation specs result in an
1085+
# explicit `getitem` layer to improve column projection
1086+
if isinstance(spec, dict):
1087+
assert hlg_layer(result1.dask, "getitem")
1088+
10831089
# check for deterministic key names and values.
10841090
# Require pickle since "partial" concat functions
10851091
# used in tree-reduction cannot be compared
@@ -1090,7 +1096,11 @@ def test_aggregate_dask():
10901096
# Note: List-based aggregation specs may result in
10911097
# an extra delayed layer. This is because a "long" list
10921098
# arg will be detected in `dask.array.core.normalize_arg`.
1093-
if isinstance(spec, list) == isinstance(other_spec, list):
1099+
# Also, dict-based aggregation specs will result in
1100+
# an extra `getitem` layer (to improve column projection)
1101+
if (isinstance(spec, list) == isinstance(other_spec, list)) and (
1102+
isinstance(spec, dict) == isinstance(other_spec, dict)
1103+
):
10941104
other = ddf.groupby(["a", "b"]).agg(other_spec, split_every=2)
10951105
assert len(other.dask) == len(result1.dask)
10961106
assert len(other.dask) == len(result2.dask)

0 commit comments

Comments (0)