@@ -19,14 +19,29 @@ package org.apache.spark.sql.execution.aggregate
1919
2020import org .apache .spark .sql .catalyst .expressions ._
2121import org .apache .spark .sql .catalyst .expressions .aggregate ._
22- import org .apache .spark .sql .catalyst .optimizer .NormalizeFloatingNumbers
2322import org .apache .spark .sql .execution .SparkPlan
2423import org .apache .spark .sql .execution .streaming .{StateStoreRestoreExec , StateStoreSaveExec }
2524
2625/**
2726 * Utility functions used by the query planner to convert our plan to new aggregation code path.
2827 */
2928object AggUtils {
29+
30+ private def mayRemoveAggFilters (exprs : Seq [AggregateExpression ]): Seq [AggregateExpression ] = {
31+ exprs.map { ae =>
32+ if (ae.filter.isDefined) {
33+ ae.mode match {
34+ // Aggregate filters are applicable only in partial/complete modes;
35+ // this method filters out them, otherwise.
36+ case Partial | Complete => ae
37+ case _ => ae.copy(filter = None )
38+ }
39+ } else {
40+ ae
41+ }
42+ }
43+ }
44+
3045 private def createAggregate (
3146 requiredChildDistributionExpressions : Option [Seq [Expression ]] = None ,
3247 groupingExpressions : Seq [NamedExpression ] = Nil ,
@@ -41,7 +56,7 @@ object AggUtils {
4156 HashAggregateExec (
4257 requiredChildDistributionExpressions = requiredChildDistributionExpressions,
4358 groupingExpressions = groupingExpressions,
44- aggregateExpressions = aggregateExpressions,
59+ aggregateExpressions = mayRemoveAggFilters( aggregateExpressions) ,
4560 aggregateAttributes = aggregateAttributes,
4661 initialInputBufferOffset = initialInputBufferOffset,
4762 resultExpressions = resultExpressions,
@@ -54,7 +69,7 @@ object AggUtils {
5469 ObjectHashAggregateExec (
5570 requiredChildDistributionExpressions = requiredChildDistributionExpressions,
5671 groupingExpressions = groupingExpressions,
57- aggregateExpressions = aggregateExpressions,
72+ aggregateExpressions = mayRemoveAggFilters( aggregateExpressions) ,
5873 aggregateAttributes = aggregateAttributes,
5974 initialInputBufferOffset = initialInputBufferOffset,
6075 resultExpressions = resultExpressions,
@@ -63,7 +78,7 @@ object AggUtils {
6378 SortAggregateExec (
6479 requiredChildDistributionExpressions = requiredChildDistributionExpressions,
6580 groupingExpressions = groupingExpressions,
66- aggregateExpressions = aggregateExpressions,
81+ aggregateExpressions = mayRemoveAggFilters( aggregateExpressions) ,
6782 aggregateAttributes = aggregateAttributes,
6883 initialInputBufferOffset = initialInputBufferOffset,
6984 resultExpressions = resultExpressions,
0 commit comments