Skip to content

Commit 3f8d092

Browse files
authored
Fix serialization bug in BroadcastJoinLayer (#9871)
1 parent c6b7052 commit 3f8d092

2 files changed

Lines changed: 27 additions & 6 deletions

File tree

dask/layers.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -872,6 +872,8 @@ def __init__(
872872
rhs_npartitions,
873873
parts_out=None,
874874
annotations=None,
875+
left_on=None,
876+
right_on=None,
875877
**merge_kwargs,
876878
):
877879
super().__init__(annotations=annotations)
@@ -882,14 +884,12 @@ def __init__(
882884
self.rhs_name = rhs_name
883885
self.rhs_npartitions = rhs_npartitions
884886
self.parts_out = parts_out or set(range(self.npartitions))
887+
self.left_on = tuple(left_on) if isinstance(left_on, list) else left_on
888+
self.right_on = tuple(right_on) if isinstance(right_on, list) else right_on
885889
self.merge_kwargs = merge_kwargs
886890
self.how = self.merge_kwargs.get("how")
887-
self.left_on = self.merge_kwargs.get("left_on")
888-
self.right_on = self.merge_kwargs.get("right_on")
889-
if isinstance(self.left_on, list):
890-
self.left_on = (list, tuple(self.left_on))
891-
if isinstance(self.right_on, list):
892-
self.right_on = (list, tuple(self.right_on))
891+
self.merge_kwargs["left_on"] = self.left_on
892+
self.merge_kwargs["right_on"] = self.right_on
893893

894894
def get_output_keys(self):
895895
return {(self.name, part) for part in self.parts_out}

dask/tests/test_distributed.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,27 @@ def test_fused_blockwise_dataframe_merge(c, fuse):
148148
)
149149

150150

151+
@pytest.mark.parametrize("on", ["a", ["a"]])
152+
@pytest.mark.parametrize("broadcast", [True, False])
153+
def test_dataframe_broadcast_merge(c, on, broadcast):
154+
# See: https://github.com/dask/dask/issues/9870
155+
pd = pytest.importorskip("pandas")
156+
dd = pytest.importorskip("dask.dataframe")
157+
158+
pdfl = pd.DataFrame({"a": [1, 2] * 2, "b_left": range(4)})
159+
pdfr = pd.DataFrame({"a": [2, 1], "b_right": range(2)})
160+
dfl = dd.from_pandas(pdfl, npartitions=4)
161+
dfr = dd.from_pandas(pdfr, npartitions=2)
162+
163+
ddfm = dd.merge(dfl, dfr, on=on, broadcast=broadcast, shuffle="tasks")
164+
dfm = ddfm.compute()
165+
dd.utils.assert_eq(
166+
dfm.sort_values("a"),
167+
pd.merge(pdfl, pdfr, on=on).sort_values("a"),
168+
check_index=False,
169+
)
170+
171+
151172
@pytest.mark.parametrize(
152173
"computation",
153174
[

0 commit comments

Comments
 (0)