Skip to content

Commit 77633ef

Browse files
authored
Support nested aggregation when calcite enabled (#4979)
* refactor: throw exception if pushdown cannot be applied Signed-off-by: Lantao Jin <ltjin@amazon.com> * fix tests Signed-off-by: Lantao Jin <ltjin@amazon.com> * fix IT Signed-off-by: Lantao Jin <ltjin@amazon.com> * Support top/dedup/aggregate by nested sub-fields Signed-off-by: Lantao Jin <ltjin@amazon.com> * fix typo Signed-off-by: Lantao Jin <ltjin@amazon.com> * address comments Signed-off-by: Lantao Jin <ltjin@amazon.com> * minor fixing Signed-off-by: Lantao Jin <ltjin@amazon.com> --------- Signed-off-by: Lantao Jin <ltjin@amazon.com>
1 parent 08be6f9 commit 77633ef

43 files changed

Lines changed: 1181 additions & 215 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@
151151
import org.opensearch.sql.calcite.plan.OpenSearchConstants;
152152
import org.opensearch.sql.calcite.utils.BinUtils;
153153
import org.opensearch.sql.calcite.utils.JoinAndLookupUtils;
154+
import org.opensearch.sql.calcite.utils.PPLHintUtils;
154155
import org.opensearch.sql.calcite.utils.PlanUtils;
155156
import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils;
156157
import org.opensearch.sql.calcite.utils.WildcardUtils;
@@ -949,13 +950,14 @@ private boolean isCountField(RexCall call) {
949950
* @param groupExprList group by expression list
950951
* @param aggExprList aggregate expression list
951952
* @param context CalcitePlanContext
953+
* @param hintIgnoreNullBucket true if bucket_nullable=false
952954
* @return Pair of (group-by list, field list, aggregate list)
953955
*/
954956
private Pair<List<RexNode>, List<AggCall>> aggregateWithTrimming(
955957
List<UnresolvedExpression> groupExprList,
956958
List<UnresolvedExpression> aggExprList,
957959
CalcitePlanContext context,
958-
boolean hintBucketNonNull) {
960+
boolean hintIgnoreNullBucket) {
959961
Pair<List<RexNode>, List<AggCall>> resolved =
960962
resolveAttributesForAggregation(groupExprList, aggExprList, context);
961963
List<RexNode> resolvedGroupByList = resolved.getLeft();
@@ -1047,7 +1049,9 @@ private Pair<List<RexNode>, List<AggCall>> aggregateWithTrimming(
10471049
// \- Scan t
10481050
List<RexInputRef> trimmedRefs = new ArrayList<>();
10491051
trimmedRefs.addAll(PlanUtils.getInputRefs(resolvedGroupByList)); // group-by keys first
1050-
trimmedRefs.addAll(PlanUtils.getInputRefsFromAggCall(resolvedAggCallList));
1052+
List<RexInputRef> aggCallRefs = PlanUtils.getInputRefsFromAggCall(resolvedAggCallList);
1053+
boolean hintNestedAgg = containsNestedAggregator(context.relBuilder, aggCallRefs);
1054+
trimmedRefs.addAll(aggCallRefs);
10511055
context.relBuilder.project(trimmedRefs);
10521056

10531057
// Re-resolve all attributes based on adding trimmed Project.
@@ -1059,7 +1063,8 @@ private Pair<List<RexNode>, List<AggCall>> aggregateWithTrimming(
10591063
List<String> intendedGroupKeyAliases = getGroupKeyNamesAfterAggregation(reResolved.getLeft());
10601064
context.relBuilder.aggregate(
10611065
context.relBuilder.groupKey(reResolved.getLeft()), reResolved.getRight());
1062-
if (hintBucketNonNull) PlanUtils.addIgnoreNullBucketHintToAggregate(context.relBuilder);
1066+
if (hintIgnoreNullBucket) PPLHintUtils.addIgnoreNullBucketHintToAggregate(context.relBuilder);
1067+
if (hintNestedAgg) PPLHintUtils.addNestedAggCallHintToAggregate(context.relBuilder);
10631068
// During aggregation, Calcite projects both input dependencies and output group-by fields.
10641069
// When names conflict, Calcite adds numeric suffixes (e.g., "value0").
10651070
// Apply explicit renaming to restore the intended aliases.
@@ -1068,6 +1073,17 @@ private Pair<List<RexNode>, List<AggCall>> aggregateWithTrimming(
10681073
return Pair.of(reResolved.getLeft(), reResolved.getRight());
10691074
}
10701075

1076+
/**
1077+
* Return true if the aggCalls contains a nested field. For example: aggCalls: [count(),
1078+
* count(a.b)] returns true.
1079+
*/
1080+
private boolean containsNestedAggregator(RelBuilder relBuilder, List<RexInputRef> aggCallRefs) {
1081+
return aggCallRefs.stream()
1082+
.map(r -> relBuilder.peek().getRowType().getFieldNames().get(r.getIndex()))
1083+
.map(name -> org.apache.commons.lang3.StringUtils.substringBefore(name, "."))
1084+
.anyMatch(root -> relBuilder.field(root).getType().getSqlTypeName() == SqlTypeName.ARRAY);
1085+
}
1086+
10711087
/**
10721088
* Imitates {@code Registrar.registerExpression} of {@link RelBuilder} to derive the output order
10731089
* of group-by keys after aggregation.
@@ -1173,8 +1189,8 @@ private void visitAggregation(
11731189
}
11741190
groupExprList.addAll(node.getGroupExprList());
11751191

1176-
// Add stats hint to LogicalAggregation.
1177-
boolean toAddHintsOnAggregate =
1192+
// Add a hint to LogicalAggregation when bucket_nullable=false.
1193+
boolean hintIgnoreNullBucket =
11781194
!groupExprList.isEmpty()
11791195
// This checks if all group-bys should be nonnull
11801196
&& nonNullGroupMask.nextClearBit(0) >= groupExprList.size();
@@ -1194,14 +1210,16 @@ private void visitAggregation(
11941210
.filter(nonNullGroupMask::get)
11951211
.mapToObj(nonNullCandidates::get)
11961212
.toList();
1197-
context.relBuilder.filter(
1198-
PlanUtils.getSelectColumns(nonNullFields).stream()
1199-
.map(context.relBuilder::field)
1200-
.map(context.relBuilder::isNotNull)
1201-
.toList());
1213+
if (!nonNullFields.isEmpty()) {
1214+
context.relBuilder.filter(
1215+
PlanUtils.getSelectColumns(nonNullFields).stream()
1216+
.map(context.relBuilder::field)
1217+
.map(context.relBuilder::isNotNull)
1218+
.toList());
1219+
}
12021220

12031221
Pair<List<RexNode>, List<AggCall>> aggregationAttributes =
1204-
aggregateWithTrimming(groupExprList, aggExprList, context, toAddHintsOnAggregate);
1222+
aggregateWithTrimming(groupExprList, aggExprList, context, hintIgnoreNullBucket);
12051223

12061224
// schema reordering
12071225
List<RexNode> outputFields = context.relBuilder.fields();
@@ -2329,9 +2347,9 @@ public RelNode visitRareTopN(RareTopN node, CalcitePlanContext context) {
23292347

23302348
// if usenull=false, add a isNotNull before Aggregate and the hint to this Aggregate
23312349
Boolean bucketNullable = (Boolean) argumentMap.get(RareTopN.Option.useNull.name()).getValue();
2332-
boolean toAddHintsOnAggregate = false;
2350+
boolean hintIgnoreNullBucket = false;
23332351
if (!bucketNullable && !groupExprList.isEmpty()) {
2334-
toAddHintsOnAggregate = true;
2352+
hintIgnoreNullBucket = true;
23352353
// add isNotNull filter before aggregation to filter out null bucket
23362354
List<RexNode> groupByList =
23372355
groupExprList.stream().map(expr -> rexVisitor.analyze(expr, context)).toList();
@@ -2341,7 +2359,7 @@ public RelNode visitRareTopN(RareTopN node, CalcitePlanContext context) {
23412359
.map(context.relBuilder::isNotNull)
23422360
.toList());
23432361
}
2344-
aggregateWithTrimming(groupExprList, aggExprList, context, toAddHintsOnAggregate);
2362+
aggregateWithTrimming(groupExprList, aggExprList, context, hintIgnoreNullBucket);
23452363

23462364
// 2. add count() column with sort direction
23472365
List<RexNode> partitionKeys = rexVisitor.analyze(node.getGroupExprList(), context);

core/src/main/java/org/opensearch/sql/calcite/utils/CalciteToolsHelper.java

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@
5252
import org.apache.calcite.jdbc.CalcitePrepare;
5353
import org.apache.calcite.jdbc.CalciteSchema;
5454
import org.apache.calcite.jdbc.Driver;
55-
import org.apache.calcite.linq4j.function.Function0;
5655
import org.apache.calcite.plan.Context;
5756
import org.apache.calcite.plan.Contexts;
5857
import org.apache.calcite.plan.Convention;
@@ -175,8 +174,11 @@ public Connection connect(
175174
}
176175

177176
@Override
178-
protected Function0<CalcitePrepare> createPrepareFactory() {
179-
return OpenSearchPrepareImpl::new;
177+
public CalcitePrepare createPrepare() {
178+
if (prepareFactory != null) {
179+
return prepareFactory.get();
180+
}
181+
return new OpenSearchPrepareImpl();
180182
}
181183
}
182184

@@ -298,10 +300,10 @@ public OpenSearchCalcitePreparingStmt(
298300

299301
@Override
300302
protected PreparedResult implement(RelRoot root) {
301-
Hook.PLAN_BEFORE_IMPLEMENTATION.run(root);
302-
RelDataType resultType = root.rel.getRowType();
303-
boolean isDml = root.kind.belongsTo(SqlKind.DML);
304303
if (root.rel instanceof Scannable scannable) {
304+
Hook.PLAN_BEFORE_IMPLEMENTATION.run(root);
305+
RelDataType resultType = root.rel.getRowType();
306+
boolean isDml = root.kind.belongsTo(SqlKind.DML);
305307
final Bindable bindable = dataContext -> scannable.scan();
306308

307309
return new PreparedResultImpl(

core/src/main/java/org/opensearch/sql/calcite/utils/PPLHintStrategyTable.java

Lines changed: 0 additions & 33 deletions
This file was deleted.
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.calcite.utils;
7+
8+
import com.google.common.base.Suppliers;
9+
import java.util.function.Supplier;
10+
import lombok.experimental.UtilityClass;
11+
import org.apache.calcite.rel.core.Aggregate;
12+
import org.apache.calcite.rel.hint.HintStrategyTable;
13+
import org.apache.calcite.rel.hint.RelHint;
14+
import org.apache.calcite.rel.logical.LogicalAggregate;
15+
import org.apache.calcite.tools.RelBuilder;
16+
17+
@UtilityClass
18+
public class PPLHintUtils {
19+
private static final String HINT_AGG_ARGUMENTS = "AGG_ARGS";
20+
private static final String KEY_IGNORE_NULL_BUCKET = "ignoreNullBucket";
21+
private static final String KEY_HAS_NESTED_AGG_CALL = "hasNestedAggCall";
22+
23+
private static final Supplier<HintStrategyTable> HINT_STRATEGY_TABLE =
24+
Suppliers.memoize(
25+
() ->
26+
HintStrategyTable.builder()
27+
.hintStrategy(
28+
HINT_AGG_ARGUMENTS,
29+
(hint, rel) -> {
30+
return rel instanceof LogicalAggregate;
31+
})
32+
// add more here
33+
.build());
34+
35+
/**
36+
* Add hint to aggregate to indicate that the aggregate will ignore null value bucket. Notice, the
37+
* current peek of relBuilder is expected to be LogicalAggregate.
38+
*/
39+
public static void addIgnoreNullBucketHintToAggregate(RelBuilder relBuilder) {
40+
assert relBuilder.peek() instanceof LogicalAggregate
41+
: "Hint HINT_AGG_ARGUMENTS can be added to LogicalAggregate only";
42+
final RelHint statHint =
43+
RelHint.builder(HINT_AGG_ARGUMENTS).hintOption(KEY_IGNORE_NULL_BUCKET, "true").build();
44+
relBuilder.hints(statHint);
45+
if (relBuilder.getCluster().getHintStrategies() == HintStrategyTable.EMPTY) {
46+
relBuilder.getCluster().setHintStrategies(HINT_STRATEGY_TABLE.get());
47+
}
48+
}
49+
50+
/**
51+
* Add hint to aggregate to indicate that the aggregate has nested agg call. Notice, the current
52+
* peek of relBuilder is expected to be LogicalAggregate.
53+
*/
54+
public static void addNestedAggCallHintToAggregate(RelBuilder relBuilder) {
55+
assert relBuilder.peek() instanceof LogicalAggregate
56+
: "Hint HINT_AGG_ARGUMENTS can be added to LogicalAggregate only";
57+
final RelHint statHint =
58+
RelHint.builder(HINT_AGG_ARGUMENTS).hintOption(KEY_HAS_NESTED_AGG_CALL, "true").build();
59+
relBuilder.hints(statHint);
60+
if (relBuilder.getCluster().getHintStrategies() == HintStrategyTable.EMPTY) {
61+
relBuilder.getCluster().setHintStrategies(HINT_STRATEGY_TABLE.get());
62+
}
63+
}
64+
65+
/** Return true if the aggregate will ignore null value bucket. */
66+
public static boolean ignoreNullBucket(Aggregate aggregate) {
67+
return aggregate.getHints().stream()
68+
.anyMatch(
69+
hint ->
70+
hint.hintName.equals(PPLHintUtils.HINT_AGG_ARGUMENTS)
71+
&& hint.kvOptions.getOrDefault(KEY_IGNORE_NULL_BUCKET, "false").equals("true"));
72+
}
73+
74+
/** Return true if the aggregate has any nested agg call. */
75+
public static boolean hasNestedAggCall(Aggregate aggregate) {
76+
return aggregate.getHints().stream()
77+
.anyMatch(
78+
hint ->
79+
hint.hintName.equals(PPLHintUtils.HINT_AGG_ARGUMENTS)
80+
&& hint.kvOptions
81+
.getOrDefault(KEY_HAS_NESTED_AGG_CALL, "false")
82+
.equals("true"));
83+
}
84+
}

core/src/main/java/org/opensearch/sql/calcite/utils/PlanUtils.java

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@
3636
import org.apache.calcite.rel.core.Project;
3737
import org.apache.calcite.rel.core.Sort;
3838
import org.apache.calcite.rel.core.TableScan;
39-
import org.apache.calcite.rel.hint.RelHint;
40-
import org.apache.calcite.rel.logical.LogicalAggregate;
4139
import org.apache.calcite.rel.logical.LogicalFilter;
4240
import org.apache.calcite.rel.logical.LogicalProject;
4341
import org.apache.calcite.rel.logical.LogicalSort;
@@ -62,7 +60,6 @@
6260
import org.apache.calcite.util.mapping.Mappings;
6361
import org.opensearch.sql.ast.AbstractNodeVisitor;
6462
import org.opensearch.sql.ast.Node;
65-
import org.opensearch.sql.ast.expression.Argument;
6663
import org.opensearch.sql.ast.expression.IntervalUnit;
6764
import org.opensearch.sql.ast.expression.SpanUnit;
6865
import org.opensearch.sql.ast.expression.WindowBound;
@@ -610,15 +607,6 @@ static void replaceTop(RelBuilder relBuilder, RelNode relNode) {
610607
}
611608
}
612609

613-
static void addIgnoreNullBucketHintToAggregate(RelBuilder relBuilder) {
614-
final RelHint statHits =
615-
RelHint.builder("stats_args").hintOption(Argument.BUCKET_NULLABLE, "false").build();
616-
assert relBuilder.peek() instanceof LogicalAggregate
617-
: "Stats hits should be added to LogicalAggregate";
618-
relBuilder.hints(statHits);
619-
relBuilder.getCluster().setHintStrategies(PPLHintStrategyTable.getHintStrategyTable());
620-
}
621-
622610
/** Extract the RexLiteral from the aggregate call if the aggregate call is a LITERAL_AGG. */
623611
static @Nullable RexLiteral getObjectFromLiteralAgg(AggregateCall aggCall) {
624612
if (aggCall.getAggregation().kind == SqlKind.LITERAL_AGG) {
@@ -655,13 +643,7 @@ private static boolean isNotNullOnRef(RexNode rex) {
655643
&& rexCall.getOperands().get(0) instanceof RexInputRef;
656644
}
657645

658-
Predicate<Aggregate> aggIgnoreNullBucket =
659-
agg ->
660-
agg.getHints().stream()
661-
.anyMatch(
662-
hint ->
663-
hint.hintName.equals("stats_args")
664-
&& hint.kvOptions.get(Argument.BUCKET_NULLABLE).equals("false"));
646+
Predicate<Aggregate> aggIgnoreNullBucket = PPLHintUtils::ignoreNullBucket;
665647

666648
Predicate<Aggregate> maybeTimeSpanAgg =
667649
agg ->
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.utils;
7+
8+
import java.util.Iterator;
9+
import java.util.LinkedList;
10+
import java.util.List;
11+
import java.util.Map;
12+
import javax.annotation.Nullable;
13+
import org.apache.commons.lang3.StringUtils;
14+
import org.apache.commons.lang3.tuple.Pair;
15+
import org.opensearch.sql.data.type.ExprCoreType;
16+
import org.opensearch.sql.data.type.ExprType;
17+
18+
public interface Utils {
19+
static <I> List<Pair<I, Integer>> zipWithIndex(List<I> input) {
20+
LinkedList<Pair<I, Integer>> result = new LinkedList<>();
21+
Iterator<I> iter = input.iterator();
22+
int index = 0;
23+
while (iter.hasNext()) {
24+
result.add(Pair.of(iter.next(), index++));
25+
}
26+
return result;
27+
}
28+
29+
/**
30+
* Resolve the nested path from the field name.
31+
*
32+
* @param path the field name
33+
* @param fieldTypes the field types
34+
* @return the nested path if exists, otherwise null
35+
*/
36+
static @Nullable String resolveNestedPath(String path, Map<String, ExprType> fieldTypes) {
37+
if (path == null || fieldTypes == null || fieldTypes.isEmpty()) {
38+
return null;
39+
}
40+
boolean found = false;
41+
String current = path;
42+
String parent = StringUtils.substringBeforeLast(current, ".");
43+
while (parent != null && !parent.equals(current)) {
44+
ExprType pathType = fieldTypes.get(parent);
45+
// Nested is mapped to ExprCoreType.ARRAY
46+
if (pathType == ExprCoreType.ARRAY) {
47+
found = true;
48+
break;
49+
}
50+
current = parent;
51+
parent = StringUtils.substringBeforeLast(current, ".");
52+
}
53+
if (found) {
54+
return parent;
55+
} else {
56+
return null;
57+
}
58+
}
59+
}

0 commit comments

Comments
 (0)