Skip to content

Commit 5416341

Browse files
authored
JOIN conditions are order dependent (#778)
* allow either order joins * refactor to individual condition level * change join signature to 'join_keys' tuple
1 parent f036f18 commit 5416341

File tree

8 files changed

+108
-48
lines changed

8 files changed

+108
-48
lines changed

ballista/rust/core/src/serde/logical_plan/from_proto.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,8 +272,7 @@ impl TryInto<LogicalPlan> for &protobuf::LogicalPlanNode {
272272
JoinConstraint::On => builder.join(
273273
&convert_box_required!(join.right)?,
274274
join_type.into(),
275-
left_keys,
276-
right_keys,
275+
(left_keys, right_keys),
277276
)?,
278277
JoinConstraint::Using => builder.join_using(
279278
&convert_box_required!(join.right)?,

ballista/rust/core/src/serde/logical_plan/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -701,7 +701,7 @@ mod roundtrip_tests {
701701
CsvReadOptions::new().schema(&schema).has_header(true),
702702
Some(vec![0, 3, 4]),
703703
)
704-
.and_then(|plan| plan.join(&scan_plan, JoinType::Inner, vec!["id"], vec!["id"]))
704+
.and_then(|plan| plan.join(&scan_plan, JoinType::Inner, (vec!["id"], vec!["id"])))
705705
.and_then(|plan| plan.build())
706706
.map_err(BallistaError::DataFusionError)?;
707707

datafusion/src/execution/dataframe_impl.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,7 @@ impl DataFrame for DataFrameImpl {
117117
.join(
118118
&right.to_logical_plan(),
119119
join_type,
120-
left_cols.to_vec(),
121-
right_cols.to_vec(),
120+
(left_cols.to_vec(), right_cols.to_vec()),
122121
)?
123122
.build()?;
124123
Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))

datafusion/src/logical_plan/builder.rs

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -273,23 +273,37 @@ impl LogicalPlanBuilder {
273273
&self,
274274
right: &LogicalPlan,
275275
join_type: JoinType,
276-
left_keys: Vec<impl Into<Column>>,
277-
right_keys: Vec<impl Into<Column>>,
276+
join_keys: (Vec<impl Into<Column>>, Vec<impl Into<Column>>),
278277
) -> Result<Self> {
279-
if left_keys.len() != right_keys.len() {
278+
if join_keys.0.len() != join_keys.1.len() {
280279
return Err(DataFusionError::Plan(
281280
"left_keys and right_keys were not the same length".to_string(),
282281
));
283282
}
284283

285-
let left_keys: Vec<Column> = left_keys
286-
.into_iter()
287-
.map(|c| c.into().normalize(&self.plan))
288-
.collect::<Result<_>>()?;
289-
let right_keys: Vec<Column> = right_keys
290-
.into_iter()
291-
.map(|c| c.into().normalize(right))
292-
.collect::<Result<_>>()?;
284+
let (left_keys, right_keys): (Vec<Result<Column>>, Vec<Result<Column>>) =
285+
join_keys
286+
.0
287+
.into_iter()
288+
.zip(join_keys.1.into_iter())
289+
.map(|(l, r)| {
290+
let mut swap = false;
291+
let l = l.into();
292+
let left_key = l.clone().normalize(&self.plan).or_else(|_| {
293+
swap = true;
294+
l.normalize(right)
295+
});
296+
if swap {
297+
(r.into().normalize(&self.plan), left_key)
298+
} else {
299+
(left_key, r.into().normalize(right))
300+
}
301+
})
302+
.unzip();
303+
304+
let left_keys = left_keys.into_iter().collect::<Result<Vec<Column>>>()?;
305+
let right_keys = right_keys.into_iter().collect::<Result<Vec<Column>>>()?;
306+
293307
let on: Vec<(_, _)> = left_keys.into_iter().zip(right_keys.into_iter()).collect();
294308
let join_schema =
295309
build_join_schema(self.plan.schema(), right.schema(), &join_type)?;

datafusion/src/optimizer/filter_push_down.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -973,8 +973,7 @@ mod tests {
973973
.join(
974974
&right,
975975
JoinType::Inner,
976-
vec![Column::from_name("a")],
977-
vec![Column::from_name("a")],
976+
(vec![Column::from_name("a")], vec![Column::from_name("a")]),
978977
)?
979978
.filter(col("a").lt_eq(lit(1i64)))?
980979
.build()?;
@@ -1058,8 +1057,7 @@ mod tests {
10581057
.join(
10591058
&right,
10601059
JoinType::Inner,
1061-
vec![Column::from_name("a")],
1062-
vec![Column::from_name("a")],
1060+
(vec![Column::from_name("a")], vec![Column::from_name("a")]),
10631061
)?
10641062
// "b" and "c" are not shared by either side: they are only available together after the join
10651063
.filter(col("c").lt_eq(col("b")))?
@@ -1099,8 +1097,7 @@ mod tests {
10991097
.join(
11001098
&right,
11011099
JoinType::Inner,
1102-
vec![Column::from_name("a")],
1103-
vec![Column::from_name("a")],
1100+
(vec![Column::from_name("a")], vec![Column::from_name("a")]),
11041101
)?
11051102
.filter(col("b").lt_eq(lit(1i64)))?
11061103
.build()?;

datafusion/src/optimizer/projection_push_down.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,7 @@ mod tests {
555555
LogicalPlanBuilder::scan_empty(Some("test2"), &schema, None)?.build()?;
556556

557557
let plan = LogicalPlanBuilder::from(table_scan)
558-
.join(&table2_scan, JoinType::Left, vec!["a"], vec!["c1"])?
558+
.join(&table2_scan, JoinType::Left, (vec!["a"], vec!["c1"]))?
559559
.project(vec![col("a"), col("b"), col("c1")])?
560560
.build()?;
561561

@@ -594,7 +594,7 @@ mod tests {
594594
LogicalPlanBuilder::scan_empty(Some("test2"), &schema, None)?.build()?;
595595

596596
let plan = LogicalPlanBuilder::from(table_scan)
597-
.join(&table2_scan, JoinType::Left, vec!["a"], vec!["c1"])?
597+
.join(&table2_scan, JoinType::Left, (vec!["a"], vec!["c1"]))?
598598
// projecting joined column `a` should push the right side column `c1` projection as
599599
// well into test2 table even though `c1` is not referenced in projection.
600600
.project(vec![col("a"), col("b")])?

datafusion/src/sql/planner.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -375,8 +375,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
375375
let (left_keys, right_keys): (Vec<Column>, Vec<Column>) =
376376
keys.into_iter().unzip();
377377
// return the logical plan representing the join
378-
let join = LogicalPlanBuilder::from(left)
379-
.join(right, join_type, left_keys, right_keys)?;
378+
let join = LogicalPlanBuilder::from(left).join(
379+
right,
380+
join_type,
381+
(left_keys, right_keys),
382+
)?;
380383

381384
if filter.is_empty() {
382385
join.build()
@@ -548,7 +551,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
548551
join_keys.iter().map(|(_, r)| r.clone()).collect();
549552
let builder = LogicalPlanBuilder::from(left);
550553
left = builder
551-
.join(right, JoinType::Inner, left_keys, right_keys)?
554+
.join(right, JoinType::Inner, (left_keys, right_keys))?
552555
.build()?;
553556
}
554557

datafusion/tests/sql.rs

Lines changed: 69 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1717,15 +1717,40 @@ fn create_case_context() -> Result<ExecutionContext> {
17171717
#[tokio::test]
17181718
async fn equijoin() -> Result<()> {
17191719
let mut ctx = create_join_context("t1_id", "t2_id")?;
1720-
let sql =
1721-
"SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id ORDER BY t1_id";
1722-
let actual = execute(&mut ctx, sql).await;
1720+
let equivalent_sql = [
1721+
"SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id ORDER BY t1_id",
1722+
"SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t2_id = t1_id ORDER BY t1_id",
1723+
];
17231724
let expected = vec![
17241725
vec!["11", "a", "z"],
17251726
vec!["22", "b", "y"],
17261727
vec!["44", "d", "x"],
17271728
];
1728-
assert_eq!(expected, actual);
1729+
for sql in equivalent_sql.iter() {
1730+
let actual = execute(&mut ctx, sql).await;
1731+
assert_eq!(expected, actual);
1732+
}
1733+
Ok(())
1734+
}
1735+
1736+
#[tokio::test]
1737+
async fn equijoin_multiple_condition_ordering() -> Result<()> {
1738+
let mut ctx = create_join_context("t1_id", "t2_id")?;
1739+
let equivalent_sql = [
1740+
"SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id AND t1_name <> t2_name ORDER BY t1_id",
1741+
"SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id AND t2_name <> t1_name ORDER BY t1_id",
1742+
"SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t2_id = t1_id AND t1_name <> t2_name ORDER BY t1_id",
1743+
"SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t2_id = t1_id AND t2_name <> t1_name ORDER BY t1_id",
1744+
];
1745+
let expected = vec![
1746+
vec!["11", "a", "z"],
1747+
vec!["22", "b", "y"],
1748+
vec!["44", "d", "x"],
1749+
];
1750+
for sql in equivalent_sql.iter() {
1751+
let actual = execute(&mut ctx, sql).await;
1752+
assert_eq!(expected, actual);
1753+
}
17291754
Ok(())
17301755
}
17311756

@@ -1754,51 +1779,70 @@ async fn equijoin_and_unsupported_condition() -> Result<()> {
17541779
#[tokio::test]
17551780
async fn left_join() -> Result<()> {
17561781
let mut ctx = create_join_context("t1_id", "t2_id")?;
1757-
let sql = "SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id ORDER BY t1_id";
1758-
let actual = execute(&mut ctx, sql).await;
1782+
let equivalent_sql = [
1783+
"SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id ORDER BY t1_id",
1784+
"SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t2_id = t1_id ORDER BY t1_id",
1785+
];
17591786
let expected = vec![
17601787
vec!["11", "a", "z"],
17611788
vec!["22", "b", "y"],
17621789
vec!["33", "c", "NULL"],
17631790
vec!["44", "d", "x"],
17641791
];
1765-
assert_eq!(expected, actual);
1792+
for sql in equivalent_sql.iter() {
1793+
let actual = execute(&mut ctx, sql).await;
1794+
assert_eq!(expected, actual);
1795+
}
17661796
Ok(())
17671797
}
17681798

17691799
#[tokio::test]
17701800
async fn right_join() -> Result<()> {
17711801
let mut ctx = create_join_context("t1_id", "t2_id")?;
1772-
let sql =
1773-
"SELECT t1_id, t1_name, t2_name FROM t1 RIGHT JOIN t2 ON t1_id = t2_id ORDER BY t1_id";
1774-
let actual = execute(&mut ctx, sql).await;
1802+
let equivalent_sql = [
1803+
"SELECT t1_id, t1_name, t2_name FROM t1 RIGHT JOIN t2 ON t1_id = t2_id ORDER BY t1_id",
1804+
"SELECT t1_id, t1_name, t2_name FROM t1 RIGHT JOIN t2 ON t2_id = t1_id ORDER BY t1_id"
1805+
];
17751806
let expected = vec![
17761807
vec!["NULL", "NULL", "w"],
17771808
vec!["11", "a", "z"],
17781809
vec!["22", "b", "y"],
17791810
vec!["44", "d", "x"],
17801811
];
1781-
assert_eq!(expected, actual);
1812+
for sql in equivalent_sql.iter() {
1813+
let actual = execute(&mut ctx, sql).await;
1814+
assert_eq!(expected, actual);
1815+
}
17821816
Ok(())
17831817
}
17841818

17851819
#[tokio::test]
17861820
async fn full_join() -> Result<()> {
17871821
let mut ctx = create_join_context("t1_id", "t2_id")?;
1788-
let sql = "SELECT t1_id, t1_name, t2_name FROM t1 FULL JOIN t2 ON t1_id = t2_id ORDER BY t1_id";
1789-
let actual = execute(&mut ctx, sql).await;
1822+
let equivalent_sql = [
1823+
"SELECT t1_id, t1_name, t2_name FROM t1 FULL JOIN t2 ON t1_id = t2_id ORDER BY t1_id",
1824+
"SELECT t1_id, t1_name, t2_name FROM t1 FULL JOIN t2 ON t2_id = t1_id ORDER BY t1_id",
1825+
];
17901826
let expected = vec![
17911827
vec!["NULL", "NULL", "w"],
17921828
vec!["11", "a", "z"],
17931829
vec!["22", "b", "y"],
17941830
vec!["33", "c", "NULL"],
17951831
vec!["44", "d", "x"],
17961832
];
1797-
assert_eq!(expected, actual);
1833+
for sql in equivalent_sql.iter() {
1834+
let actual = execute(&mut ctx, sql).await;
1835+
assert_eq!(expected, actual);
1836+
}
17981837

1799-
let sql = "SELECT t1_id, t1_name, t2_name FROM t1 FULL OUTER JOIN t2 ON t1_id = t2_id ORDER BY t1_id";
1800-
let actual = execute(&mut ctx, sql).await;
1801-
assert_eq!(expected, actual);
1838+
let equivalent_sql = [
1839+
"SELECT t1_id, t1_name, t2_name FROM t1 FULL OUTER JOIN t2 ON t1_id = t2_id ORDER BY t1_id",
1840+
"SELECT t1_id, t1_name, t2_name FROM t1 FULL OUTER JOIN t2 ON t2_id = t1_id ORDER BY t1_id",
1841+
];
1842+
for sql in equivalent_sql.iter() {
1843+
let actual = execute(&mut ctx, sql).await;
1844+
assert_eq!(expected, actual);
1845+
}
18021846

18031847
Ok(())
18041848
}
@@ -1821,15 +1865,19 @@ async fn left_join_using() -> Result<()> {
18211865
#[tokio::test]
18221866
async fn equijoin_implicit_syntax() -> Result<()> {
18231867
let mut ctx = create_join_context("t1_id", "t2_id")?;
1824-
let sql =
1825-
"SELECT t1_id, t1_name, t2_name FROM t1, t2 WHERE t1_id = t2_id ORDER BY t1_id";
1826-
let actual = execute(&mut ctx, sql).await;
1868+
let equivalent_sql = [
1869+
"SELECT t1_id, t1_name, t2_name FROM t1, t2 WHERE t1_id = t2_id ORDER BY t1_id",
1870+
"SELECT t1_id, t1_name, t2_name FROM t1, t2 WHERE t2_id = t1_id ORDER BY t1_id",
1871+
];
18271872
let expected = vec![
18281873
vec!["11", "a", "z"],
18291874
vec!["22", "b", "y"],
18301875
vec!["44", "d", "x"],
18311876
];
1832-
assert_eq!(expected, actual);
1877+
for sql in equivalent_sql.iter() {
1878+
let actual = execute(&mut ctx, sql).await;
1879+
assert_eq!(expected, actual);
1880+
}
18331881
Ok(())
18341882
}
18351883

0 commit comments

Comments
 (0)