|
| 1 | +// Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +// or more contributor license agreements. See the NOTICE file |
| 3 | +// distributed with this work for additional information |
| 4 | +// regarding copyright ownership. The ASF licenses this file |
| 5 | +// to you under the Apache License, Version 2.0 (the |
| 6 | +// "License"); you may not use this file except in compliance |
| 7 | +// with the License. You may obtain a copy of the License at |
| 8 | +// |
| 9 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +// |
| 11 | +// Unless required by applicable law or agreed to in writing, |
| 12 | +// software distributed under the License is distributed on an |
| 13 | +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +// KIND, either express or implied. See the License for the |
| 15 | +// specific language governing permissions and limitations |
| 16 | +// under the License. |
| 17 | + |
| 18 | +use crate::analyzer::AnalyzerRule; |
| 19 | + |
| 20 | +use crate::utils::NamePreserver; |
| 21 | +use datafusion_common::config::ConfigOptions; |
| 22 | +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; |
| 23 | +use datafusion_common::Result; |
| 24 | +use datafusion_expr::expr::{ |
| 25 | + AggregateFunction, AggregateFunctionDefinition, WindowFunction, |
| 26 | +}; |
| 27 | +use datafusion_expr::utils::COUNT_STAR_EXPANSION; |
| 28 | +use datafusion_expr::{lit, Expr, LogicalPlan, WindowFunctionDefinition}; |
| 29 | + |
| 30 | +/// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`. |
| 31 | +/// |
| 32 | +/// Resolves issue: <https://github.com/apache/datafusion/issues/5473> |
| 33 | +#[derive(Default)] |
| 34 | +pub struct CountWildcardRule {} |
| 35 | + |
| 36 | +impl CountWildcardRule { |
| 37 | + pub fn new() -> Self { |
| 38 | + CountWildcardRule {} |
| 39 | + } |
| 40 | +} |
| 41 | + |
| 42 | +impl AnalyzerRule for CountWildcardRule { |
| 43 | + fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result<LogicalPlan> { |
| 44 | + plan.transform_down_with_subqueries(analyze_internal).data() |
| 45 | + } |
| 46 | + |
| 47 | + fn name(&self) -> &str { |
| 48 | + "count_wildcard_rule" |
| 49 | + } |
| 50 | +} |
| 51 | + |
| 52 | +fn is_wildcard(expr: &Expr) -> bool { |
| 53 | + matches!(expr, Expr::Wildcard { qualifier: None }) |
| 54 | +} |
| 55 | + |
| 56 | +fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool { |
| 57 | + matches!(aggregate_function, |
| 58 | + AggregateFunction { |
| 59 | + func_def: AggregateFunctionDefinition::UDF(udf), |
| 60 | + args, |
| 61 | + .. |
| 62 | + } if udf.name() == "count" && args.len() == 1 && is_wildcard(&args[0])) |
| 63 | +} |
| 64 | + |
| 65 | +fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool { |
| 66 | + let args = &window_function.args; |
| 67 | + matches!(window_function.fun, |
| 68 | + WindowFunctionDefinition::AggregateUDF(ref udaf) |
| 69 | + if udaf.name() == "count" && args.len() == 1 && is_wildcard(&args[0])) |
| 70 | +} |
| 71 | + |
| 72 | +fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> { |
| 73 | + let name_preserver = NamePreserver::new(&plan); |
| 74 | + plan.map_expressions(|expr| { |
| 75 | + let original_name = name_preserver.save(&expr)?; |
| 76 | + let transformed_expr = expr.transform_up(|expr| match expr { |
| 77 | + Expr::WindowFunction(mut window_function) |
| 78 | + if is_count_star_window_aggregate(&window_function) => |
| 79 | + { |
| 80 | + window_function.args = vec![lit(COUNT_STAR_EXPANSION)]; |
| 81 | + Ok(Transformed::yes(Expr::WindowFunction(window_function))) |
| 82 | + } |
| 83 | + Expr::AggregateFunction(mut aggregate_function) |
| 84 | + if is_count_star_aggregate(&aggregate_function) => |
| 85 | + { |
| 86 | + aggregate_function.args = vec![lit(COUNT_STAR_EXPANSION)]; |
| 87 | + Ok(Transformed::yes(Expr::AggregateFunction( |
| 88 | + aggregate_function, |
| 89 | + ))) |
| 90 | + } |
| 91 | + _ => Ok(Transformed::no(expr)), |
| 92 | + })?; |
| 93 | + transformed_expr.map_data(|data| original_name.restore(data)) |
| 94 | + }) |
| 95 | +} |
| 96 | + |
| 97 | +#[cfg(test)] |
| 98 | +mod tests { |
| 99 | + use super::*; |
| 100 | + use crate::test::*; |
| 101 | + use arrow::datatypes::DataType; |
| 102 | + use datafusion_common::ScalarValue; |
| 103 | + use datafusion_expr::expr::Sort; |
| 104 | + use datafusion_expr::{ |
| 105 | + col, exists, expr, in_subquery, logical_plan::LogicalPlanBuilder, max, |
| 106 | + out_ref_col, scalar_subquery, wildcard, WindowFrame, WindowFrameBound, |
| 107 | + WindowFrameUnits, |
| 108 | + }; |
| 109 | + use datafusion_functions_aggregate::count::count_udaf; |
| 110 | + use std::sync::Arc; |
| 111 | + |
| 112 | + use datafusion_functions_aggregate::expr_fn::{count, sum}; |
| 113 | + |
| 114 | + fn assert_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { |
| 115 | + assert_analyzed_plan_eq_display_indent( |
| 116 | + Arc::new(CountWildcardRule::new()), |
| 117 | + plan, |
| 118 | + expected, |
| 119 | + ) |
| 120 | + } |
| 121 | + |
| 122 | + #[test] |
| 123 | + fn test_count_wildcard_on_sort() -> Result<()> { |
| 124 | + let table_scan = test_table_scan()?; |
| 125 | + let plan = LogicalPlanBuilder::from(table_scan) |
| 126 | + .aggregate(vec![col("b")], vec![count(wildcard())])? |
| 127 | + .project(vec![count(wildcard())])? |
| 128 | + .sort(vec![count(wildcard()).sort(true, false)])? |
| 129 | + .build()?; |
| 130 | + let expected = "Sort: count(*) ASC NULLS LAST [count(*):Int64]\ |
| 131 | + \n Projection: count(*) [count(*):Int64]\ |
| 132 | + \n Aggregate: groupBy=[[test.b]], aggr=[[count(Int64(1)) AS count(*)]] [b:UInt32, count(*):Int64]\ |
| 133 | + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; |
| 134 | + assert_plan_eq(plan, expected) |
| 135 | + } |
| 136 | + |
| 137 | + #[test] |
| 138 | + fn test_count_wildcard_on_where_in() -> Result<()> { |
| 139 | + let table_scan_t1 = test_table_scan_with_name("t1")?; |
| 140 | + let table_scan_t2 = test_table_scan_with_name("t2")?; |
| 141 | + |
| 142 | + let plan = LogicalPlanBuilder::from(table_scan_t1) |
| 143 | + .filter(in_subquery( |
| 144 | + col("a"), |
| 145 | + Arc::new( |
| 146 | + LogicalPlanBuilder::from(table_scan_t2) |
| 147 | + .aggregate(Vec::<Expr>::new(), vec![count(wildcard())])? |
| 148 | + .project(vec![count(wildcard())])? |
| 149 | + .build()?, |
| 150 | + ), |
| 151 | + ))? |
| 152 | + .build()?; |
| 153 | + |
| 154 | + let expected = "Filter: t1.a IN (<subquery>) [a:UInt32, b:UInt32, c:UInt32]\ |
| 155 | + \n Subquery: [count(*):Int64]\ |
| 156 | + \n Projection: count(*) [count(*):Int64]\ |
| 157 | + \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] [count(*):Int64]\ |
| 158 | + \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ |
| 159 | + \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]"; |
| 160 | + assert_plan_eq(plan, expected) |
| 161 | + } |
| 162 | + |
| 163 | + #[test] |
| 164 | + fn test_count_wildcard_on_where_exists() -> Result<()> { |
| 165 | + let table_scan_t1 = test_table_scan_with_name("t1")?; |
| 166 | + let table_scan_t2 = test_table_scan_with_name("t2")?; |
| 167 | + |
| 168 | + let plan = LogicalPlanBuilder::from(table_scan_t1) |
| 169 | + .filter(exists(Arc::new( |
| 170 | + LogicalPlanBuilder::from(table_scan_t2) |
| 171 | + .aggregate(Vec::<Expr>::new(), vec![count(wildcard())])? |
| 172 | + .project(vec![count(wildcard())])? |
| 173 | + .build()?, |
| 174 | + )))? |
| 175 | + .build()?; |
| 176 | + |
| 177 | + let expected = "Filter: EXISTS (<subquery>) [a:UInt32, b:UInt32, c:UInt32]\ |
| 178 | + \n Subquery: [count(*):Int64]\ |
| 179 | + \n Projection: count(*) [count(*):Int64]\ |
| 180 | + \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] [count(*):Int64]\ |
| 181 | + \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ |
| 182 | + \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]"; |
| 183 | + assert_plan_eq(plan, expected) |
| 184 | + } |
| 185 | + |
| 186 | + #[test] |
| 187 | + fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> { |
| 188 | + let table_scan_t1 = test_table_scan_with_name("t1")?; |
| 189 | + let table_scan_t2 = test_table_scan_with_name("t2")?; |
| 190 | + |
| 191 | + let plan = LogicalPlanBuilder::from(table_scan_t1) |
| 192 | + .filter( |
| 193 | + scalar_subquery(Arc::new( |
| 194 | + LogicalPlanBuilder::from(table_scan_t2) |
| 195 | + .filter(out_ref_col(DataType::UInt32, "t1.a").eq(col("t2.a")))? |
| 196 | + .aggregate( |
| 197 | + Vec::<Expr>::new(), |
| 198 | + vec![count(lit(COUNT_STAR_EXPANSION))], |
| 199 | + )? |
| 200 | + .project(vec![count(lit(COUNT_STAR_EXPANSION))])? |
| 201 | + .build()?, |
| 202 | + )) |
| 203 | + .gt(lit(ScalarValue::UInt8(Some(0)))), |
| 204 | + )? |
| 205 | + .project(vec![col("t1.a"), col("t1.b")])? |
| 206 | + .build()?; |
| 207 | + |
| 208 | + let expected = "Projection: t1.a, t1.b [a:UInt32, b:UInt32]\ |
| 209 | + \n Filter: (<subquery>) > UInt8(0) [a:UInt32, b:UInt32, c:UInt32]\ |
| 210 | + \n Subquery: [count(Int64(1)):Int64]\ |
| 211 | + \n Projection: count(Int64(1)) [count(Int64(1)):Int64]\ |
| 212 | + \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] [count(Int64(1)):Int64]\ |
| 213 | + \n Filter: outer_ref(t1.a) = t2.a [a:UInt32, b:UInt32, c:UInt32]\ |
| 214 | + \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ |
| 215 | + \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]"; |
| 216 | + assert_plan_eq(plan, expected) |
| 217 | + } |
| 218 | + #[test] |
| 219 | + fn test_count_wildcard_on_window() -> Result<()> { |
| 220 | + let table_scan = test_table_scan()?; |
| 221 | + |
| 222 | + let plan = LogicalPlanBuilder::from(table_scan) |
| 223 | + .window(vec![Expr::WindowFunction(expr::WindowFunction::new( |
| 224 | + WindowFunctionDefinition::AggregateUDF(count_udaf()), |
| 225 | + vec![wildcard()], |
| 226 | + vec![], |
| 227 | + vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))], |
| 228 | + WindowFrame::new_bounds( |
| 229 | + WindowFrameUnits::Range, |
| 230 | + WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), |
| 231 | + WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), |
| 232 | + ), |
| 233 | + None, |
| 234 | + ))])? |
| 235 | + .project(vec![count(wildcard())])? |
| 236 | + .build()?; |
| 237 | + |
| 238 | + let expected = "Projection: count(Int64(1)) AS count(*) [count(*):Int64]\ |
| 239 | + \n WindowAggr: windowExpr=[[count(Int64(1)) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING AS count(*) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]] [a:UInt32, b:UInt32, c:UInt32, count(*) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING:Int64;N]\ |
| 240 | + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; |
| 241 | + assert_plan_eq(plan, expected) |
| 242 | + } |
| 243 | + |
| 244 | + #[test] |
| 245 | + fn test_count_wildcard_on_aggregate() -> Result<()> { |
| 246 | + let table_scan = test_table_scan()?; |
| 247 | + let plan = LogicalPlanBuilder::from(table_scan) |
| 248 | + .aggregate(Vec::<Expr>::new(), vec![count(wildcard())])? |
| 249 | + .project(vec![count(wildcard())])? |
| 250 | + .build()?; |
| 251 | + |
| 252 | + let expected = "Projection: count(*) [count(*):Int64]\ |
| 253 | + \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] [count(*):Int64]\ |
| 254 | + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; |
| 255 | + assert_plan_eq(plan, expected) |
| 256 | + } |
| 257 | + |
| 258 | + #[test] |
| 259 | + fn test_count_wildcard_on_non_count_aggregate() -> Result<()> { |
| 260 | + let table_scan = test_table_scan()?; |
| 261 | + let res = LogicalPlanBuilder::from(table_scan) |
| 262 | + .aggregate(Vec::<Expr>::new(), vec![sum(wildcard())]); |
| 263 | + assert!(res.is_err()); |
| 264 | + Ok(()) |
| 265 | + } |
| 266 | + |
| 267 | + #[test] |
| 268 | + fn test_count_wildcard_on_nesting() -> Result<()> { |
| 269 | + let table_scan = test_table_scan()?; |
| 270 | + let plan = LogicalPlanBuilder::from(table_scan) |
| 271 | + .aggregate(Vec::<Expr>::new(), vec![max(count(wildcard()))])? |
| 272 | + .project(vec![count(wildcard())])? |
| 273 | + .build()?; |
| 274 | + |
| 275 | + let expected = "Projection: count(Int64(1)) AS count(*) [count(*):Int64]\ |
| 276 | + \n Aggregate: groupBy=[[]], aggr=[[MAX(count(Int64(1))) AS MAX(count(*))]] [MAX(count(*)):Int64;N]\ |
| 277 | + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; |
| 278 | + assert_plan_eq(plan, expected) |
| 279 | + } |
| 280 | +} |
0 commit comments