Skip to content

Commit 2c4fc0a

Browse files
committed
feat: add back rule
1 parent 4550dbd commit 2c4fc0a

File tree

4 files changed

+291
-4
lines changed

4 files changed

+291
-4
lines changed

datafusion/core/tests/dataframe/mod.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -226,11 +226,14 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> {
226226
.collect()
227227
.await?;
228228

229+
let left = pretty_format_batches(&sql_results)?.to_string();
230+
let right = pretty_format_batches(&df_results)?.to_string();
231+
232+
eprintln!("left: {}", left);
233+
eprintln!("right: {}", right);
234+
229235
//make sure sql plan same with df plan
230-
assert_eq!(
231-
pretty_format_batches(&sql_results)?.to_string(),
232-
pretty_format_batches(&df_results)?.to_string()
233-
);
236+
assert_eq!(left, right);
234237

235238
Ok(())
236239
}
Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use crate::analyzer::AnalyzerRule;
19+
20+
use crate::utils::NamePreserver;
21+
use datafusion_common::config::ConfigOptions;
22+
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
23+
use datafusion_common::Result;
24+
use datafusion_expr::expr::{
25+
AggregateFunction, AggregateFunctionDefinition, WindowFunction,
26+
};
27+
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
28+
use datafusion_expr::{lit, Expr, LogicalPlan, WindowFunctionDefinition};
29+
30+
/// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`.
31+
///
32+
/// Resolves issue: <https://github.com/apache/datafusion/issues/5473>
33+
#[derive(Default)]
34+
pub struct CountWildcardRule {}
35+
36+
impl CountWildcardRule {
37+
pub fn new() -> Self {
38+
CountWildcardRule {}
39+
}
40+
}
41+
42+
impl AnalyzerRule for CountWildcardRule {
43+
fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result<LogicalPlan> {
44+
plan.transform_down_with_subqueries(analyze_internal).data()
45+
}
46+
47+
fn name(&self) -> &str {
48+
"count_wildcard_rule"
49+
}
50+
}
51+
52+
fn is_wildcard(expr: &Expr) -> bool {
53+
matches!(expr, Expr::Wildcard { qualifier: None })
54+
}
55+
56+
fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool {
57+
matches!(aggregate_function,
58+
AggregateFunction {
59+
func_def: AggregateFunctionDefinition::UDF(udf),
60+
args,
61+
..
62+
} if udf.name() == "count" && args.len() == 1 && is_wildcard(&args[0]))
63+
}
64+
65+
fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool {
66+
let args = &window_function.args;
67+
matches!(window_function.fun,
68+
WindowFunctionDefinition::AggregateUDF(ref udaf)
69+
if udaf.name() == "count" && args.len() == 1 && is_wildcard(&args[0]))
70+
}
71+
72+
fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
73+
let name_preserver = NamePreserver::new(&plan);
74+
plan.map_expressions(|expr| {
75+
let original_name = name_preserver.save(&expr)?;
76+
let transformed_expr = expr.transform_up(|expr| match expr {
77+
Expr::WindowFunction(mut window_function)
78+
if is_count_star_window_aggregate(&window_function) =>
79+
{
80+
window_function.args = vec![lit(COUNT_STAR_EXPANSION)];
81+
Ok(Transformed::yes(Expr::WindowFunction(window_function)))
82+
}
83+
Expr::AggregateFunction(mut aggregate_function)
84+
if is_count_star_aggregate(&aggregate_function) =>
85+
{
86+
aggregate_function.args = vec![lit(COUNT_STAR_EXPANSION)];
87+
Ok(Transformed::yes(Expr::AggregateFunction(
88+
aggregate_function,
89+
)))
90+
}
91+
_ => Ok(Transformed::no(expr)),
92+
})?;
93+
transformed_expr.map_data(|data| original_name.restore(data))
94+
})
95+
}
96+
97+
#[cfg(test)]
98+
mod tests {
99+
use super::*;
100+
use crate::test::*;
101+
use arrow::datatypes::DataType;
102+
use datafusion_common::ScalarValue;
103+
use datafusion_expr::expr::Sort;
104+
use datafusion_expr::{
105+
col, exists, expr, in_subquery, logical_plan::LogicalPlanBuilder, max,
106+
out_ref_col, scalar_subquery, wildcard, WindowFrame, WindowFrameBound,
107+
WindowFrameUnits,
108+
};
109+
use datafusion_functions_aggregate::count::count_udaf;
110+
use std::sync::Arc;
111+
112+
use datafusion_functions_aggregate::expr_fn::{count, sum};
113+
114+
fn assert_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> {
115+
assert_analyzed_plan_eq_display_indent(
116+
Arc::new(CountWildcardRule::new()),
117+
plan,
118+
expected,
119+
)
120+
}
121+
122+
#[test]
123+
fn test_count_wildcard_on_sort() -> Result<()> {
124+
let table_scan = test_table_scan()?;
125+
let plan = LogicalPlanBuilder::from(table_scan)
126+
.aggregate(vec![col("b")], vec![count(wildcard())])?
127+
.project(vec![count(wildcard())])?
128+
.sort(vec![count(wildcard()).sort(true, false)])?
129+
.build()?;
130+
let expected = "Sort: count(*) ASC NULLS LAST [count(*):Int64]\
131+
\n Projection: count(*) [count(*):Int64]\
132+
\n Aggregate: groupBy=[[test.b]], aggr=[[count(Int64(1)) AS count(*)]] [b:UInt32, count(*):Int64]\
133+
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
134+
assert_plan_eq(plan, expected)
135+
}
136+
137+
#[test]
138+
fn test_count_wildcard_on_where_in() -> Result<()> {
139+
let table_scan_t1 = test_table_scan_with_name("t1")?;
140+
let table_scan_t2 = test_table_scan_with_name("t2")?;
141+
142+
let plan = LogicalPlanBuilder::from(table_scan_t1)
143+
.filter(in_subquery(
144+
col("a"),
145+
Arc::new(
146+
LogicalPlanBuilder::from(table_scan_t2)
147+
.aggregate(Vec::<Expr>::new(), vec![count(wildcard())])?
148+
.project(vec![count(wildcard())])?
149+
.build()?,
150+
),
151+
))?
152+
.build()?;
153+
154+
let expected = "Filter: t1.a IN (<subquery>) [a:UInt32, b:UInt32, c:UInt32]\
155+
\n Subquery: [count(*):Int64]\
156+
\n Projection: count(*) [count(*):Int64]\
157+
\n Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] [count(*):Int64]\
158+
\n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\
159+
\n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]";
160+
assert_plan_eq(plan, expected)
161+
}
162+
163+
#[test]
164+
fn test_count_wildcard_on_where_exists() -> Result<()> {
165+
let table_scan_t1 = test_table_scan_with_name("t1")?;
166+
let table_scan_t2 = test_table_scan_with_name("t2")?;
167+
168+
let plan = LogicalPlanBuilder::from(table_scan_t1)
169+
.filter(exists(Arc::new(
170+
LogicalPlanBuilder::from(table_scan_t2)
171+
.aggregate(Vec::<Expr>::new(), vec![count(wildcard())])?
172+
.project(vec![count(wildcard())])?
173+
.build()?,
174+
)))?
175+
.build()?;
176+
177+
let expected = "Filter: EXISTS (<subquery>) [a:UInt32, b:UInt32, c:UInt32]\
178+
\n Subquery: [count(*):Int64]\
179+
\n Projection: count(*) [count(*):Int64]\
180+
\n Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] [count(*):Int64]\
181+
\n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\
182+
\n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]";
183+
assert_plan_eq(plan, expected)
184+
}
185+
186+
#[test]
187+
fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> {
188+
let table_scan_t1 = test_table_scan_with_name("t1")?;
189+
let table_scan_t2 = test_table_scan_with_name("t2")?;
190+
191+
let plan = LogicalPlanBuilder::from(table_scan_t1)
192+
.filter(
193+
scalar_subquery(Arc::new(
194+
LogicalPlanBuilder::from(table_scan_t2)
195+
.filter(out_ref_col(DataType::UInt32, "t1.a").eq(col("t2.a")))?
196+
.aggregate(
197+
Vec::<Expr>::new(),
198+
vec![count(lit(COUNT_STAR_EXPANSION))],
199+
)?
200+
.project(vec![count(lit(COUNT_STAR_EXPANSION))])?
201+
.build()?,
202+
))
203+
.gt(lit(ScalarValue::UInt8(Some(0)))),
204+
)?
205+
.project(vec![col("t1.a"), col("t1.b")])?
206+
.build()?;
207+
208+
let expected = "Projection: t1.a, t1.b [a:UInt32, b:UInt32]\
209+
\n Filter: (<subquery>) > UInt8(0) [a:UInt32, b:UInt32, c:UInt32]\
210+
\n Subquery: [count(Int64(1)):Int64]\
211+
\n Projection: count(Int64(1)) [count(Int64(1)):Int64]\
212+
\n Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] [count(Int64(1)):Int64]\
213+
\n Filter: outer_ref(t1.a) = t2.a [a:UInt32, b:UInt32, c:UInt32]\
214+
\n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\
215+
\n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]";
216+
assert_plan_eq(plan, expected)
217+
}
218+
#[test]
219+
fn test_count_wildcard_on_window() -> Result<()> {
220+
let table_scan = test_table_scan()?;
221+
222+
let plan = LogicalPlanBuilder::from(table_scan)
223+
.window(vec![Expr::WindowFunction(expr::WindowFunction::new(
224+
WindowFunctionDefinition::AggregateUDF(count_udaf()),
225+
vec![wildcard()],
226+
vec![],
227+
vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))],
228+
WindowFrame::new_bounds(
229+
WindowFrameUnits::Range,
230+
WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))),
231+
WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
232+
),
233+
None,
234+
))])?
235+
.project(vec![count(wildcard())])?
236+
.build()?;
237+
238+
let expected = "Projection: count(Int64(1)) AS count(*) [count(*):Int64]\
239+
\n WindowAggr: windowExpr=[[count(Int64(1)) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING AS count(*) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]] [a:UInt32, b:UInt32, c:UInt32, count(*) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING:Int64;N]\
240+
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
241+
assert_plan_eq(plan, expected)
242+
}
243+
244+
#[test]
245+
fn test_count_wildcard_on_aggregate() -> Result<()> {
246+
let table_scan = test_table_scan()?;
247+
let plan = LogicalPlanBuilder::from(table_scan)
248+
.aggregate(Vec::<Expr>::new(), vec![count(wildcard())])?
249+
.project(vec![count(wildcard())])?
250+
.build()?;
251+
252+
let expected = "Projection: count(*) [count(*):Int64]\
253+
\n Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] [count(*):Int64]\
254+
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
255+
assert_plan_eq(plan, expected)
256+
}
257+
258+
#[test]
259+
fn test_count_wildcard_on_non_count_aggregate() -> Result<()> {
260+
let table_scan = test_table_scan()?;
261+
let res = LogicalPlanBuilder::from(table_scan)
262+
.aggregate(Vec::<Expr>::new(), vec![sum(wildcard())]);
263+
assert!(res.is_err());
264+
Ok(())
265+
}
266+
267+
#[test]
268+
fn test_count_wildcard_on_nesting() -> Result<()> {
269+
let table_scan = test_table_scan()?;
270+
let plan = LogicalPlanBuilder::from(table_scan)
271+
.aggregate(Vec::<Expr>::new(), vec![max(count(wildcard()))])?
272+
.project(vec![count(wildcard())])?
273+
.build()?;
274+
275+
let expected = "Projection: count(Int64(1)) AS count(*) [count(*):Int64]\
276+
\n Aggregate: groupBy=[[]], aggr=[[MAX(count(Int64(1))) AS MAX(count(*))]] [MAX(count(*)):Int64;N]\
277+
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
278+
assert_plan_eq(plan, expected)
279+
}
280+
}

datafusion/optimizer/src/analyzer/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,15 @@ use datafusion_expr::expr::InSubquery;
2929
use datafusion_expr::expr_rewriter::FunctionRewrite;
3030
use datafusion_expr::{Expr, LogicalPlan};
3131

32+
use crate::analyzer::count_wildcard_rule::CountWildcardRule;
3233
use crate::analyzer::inline_table_scan::InlineTableScan;
3334
use crate::analyzer::subquery::check_subquery_expr;
3435
use crate::analyzer::type_coercion::TypeCoercion;
3536
use crate::utils::log_plan;
3637

3738
use self::function_rewrite::ApplyFunctionRewrites;
3839

40+
pub mod count_wildcard_rule;
3941
pub mod function_rewrite;
4042
pub mod inline_table_scan;
4143
pub mod subquery;
@@ -88,6 +90,7 @@ impl Analyzer {
8890
let rules: Vec<Arc<dyn AnalyzerRule + Send + Sync>> = vec![
8991
Arc::new(InlineTableScan::new()),
9092
Arc::new(TypeCoercion::new()),
93+
Arc::new(CountWildcardRule::new()),
9194
];
9295
Self::with_rules(rules)
9396
}

datafusion/sqllogictest/test_files/explain.slt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ initial_logical_plan
182182
logical_plan after apply_function_rewrites SAME TEXT AS ABOVE
183183
logical_plan after inline_table_scan SAME TEXT AS ABOVE
184184
logical_plan after type_coercion SAME TEXT AS ABOVE
185+
logical_plan after count_wildcard_rule SAME TEXT AS ABOVE
185186
analyzed_logical_plan SAME TEXT AS ABOVE
186187
logical_plan after eliminate_nested_union SAME TEXT AS ABOVE
187188
logical_plan after simplify_expressions SAME TEXT AS ABOVE

0 commit comments

Comments
 (0)