Skip to content

Commit 5ceacb6

Browse files
authored
Convert dedup pushdown to composite + top_hits (#4844)
* Enable dedup pushdown Signed-off-by: Lantao Jin <ltjin@amazon.com> * fix doctest Signed-off-by: Lantao Jin <ltjin@amazon.com> * refactor Signed-off-by: Lantao Jin <ltjin@amazon.com> * Disable dedup expr Signed-off-by: Lantao Jin <ltjin@amazon.com> * fix IT Signed-off-by: Lantao Jin <ltjin@amazon.com> * fix yaml test Signed-off-by: Lantao Jin <ltjin@amazon.com> * add more comments in code Signed-off-by: Lantao Jin <ltjin@amazon.com> * fix conflicts Signed-off-by: Lantao Jin <ltjin@amazon.com> * Address comments Signed-off-by: Lantao Jin <ltjin@amazon.com> --------- Signed-off-by: Lantao Jin <ltjin@amazon.com>
1 parent afc98dd commit 5ceacb6

74 files changed

Lines changed: 1207 additions & 477 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java

Lines changed: 24 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import static org.opensearch.sql.ast.tree.Sort.SortOrder.ASC;
1616
import static org.opensearch.sql.ast.tree.Sort.SortOrder.DESC;
1717
import static org.opensearch.sql.calcite.utils.PlanUtils.ROW_NUMBER_COLUMN_FOR_DEDUP;
18+
import static org.opensearch.sql.calcite.utils.PlanUtils.ROW_NUMBER_COLUMN_FOR_JOIN_MAX_DEDUP;
1819
import static org.opensearch.sql.calcite.utils.PlanUtils.ROW_NUMBER_COLUMN_FOR_MAIN;
1920
import static org.opensearch.sql.calcite.utils.PlanUtils.ROW_NUMBER_COLUMN_FOR_RARE_TOP;
2021
import static org.opensearch.sql.calcite.utils.PlanUtils.ROW_NUMBER_COLUMN_FOR_STREAMSTATS;
@@ -48,9 +49,6 @@
4849
import org.apache.calcite.rel.RelNode;
4950
import org.apache.calcite.rel.core.Aggregate;
5051
import org.apache.calcite.rel.core.JoinRelType;
51-
import org.apache.calcite.rel.hint.HintStrategyTable;
52-
import org.apache.calcite.rel.hint.RelHint;
53-
import org.apache.calcite.rel.logical.LogicalAggregate;
5452
import org.apache.calcite.rel.logical.LogicalValues;
5553
import org.apache.calcite.rel.type.RelDataType;
5654
import org.apache.calcite.rel.type.RelDataTypeFamily;
@@ -1054,7 +1052,7 @@ private Pair<List<RexNode>, List<AggCall>> aggregateWithTrimming(
10541052
List<String> intendedGroupKeyAliases = getGroupKeyNamesAfterAggregation(reResolved.getLeft());
10551053
context.relBuilder.aggregate(
10561054
context.relBuilder.groupKey(reResolved.getLeft()), reResolved.getRight());
1057-
if (hintBucketNonNull) addIgnoreNullBucketHintToAggregate(context);
1055+
if (hintBucketNonNull) PlanUtils.addIgnoreNullBucketHintToAggregate(context.relBuilder);
10581056
// During aggregation, Calcite projects both input dependencies and output group-by fields.
10591057
// When names conflict, Calcite adds numeric suffixes (e.g., "value0").
10601058
// Apply explicit renaming to restore the intended aliases.
@@ -1316,7 +1314,7 @@ public RelNode visitJoin(Join node, CalcitePlanContext context) {
13161314
: duplicatedFieldNames.stream()
13171315
.map(a -> (RexNode) context.relBuilder.field(a))
13181316
.toList();
1319-
buildDedupNotNull(context, dedupeFields, allowedDuplication);
1317+
buildDedupNotNull(context, dedupeFields, allowedDuplication, true);
13201318
}
13211319
context.relBuilder.join(
13221320
JoinAndLookupUtils.translateJoinType(node.getJoinType()), joinCondition);
@@ -1372,7 +1370,7 @@ public RelNode visitJoin(Join node, CalcitePlanContext context) {
13721370
List<RexNode> dedupeFields =
13731371
getRightColumnsInJoinCriteria(context.relBuilder, joinCondition);
13741372

1375-
buildDedupNotNull(context, dedupeFields, allowedDuplication);
1373+
buildDedupNotNull(context, dedupeFields, allowedDuplication, true);
13761374
}
13771375
context.relBuilder.join(
13781376
JoinAndLookupUtils.translateJoinType(node.getJoinType()), joinCondition);
@@ -1537,24 +1535,20 @@ public RelNode visitDedupe(Dedupe node, CalcitePlanContext context) {
15371535
if (keepEmpty) {
15381536
buildDedupOrNull(context, dedupeFields, allowedDuplication);
15391537
} else {
1540-
buildDedupNotNull(context, dedupeFields, allowedDuplication);
1538+
buildDedupNotNull(context, dedupeFields, allowedDuplication, false);
15411539
}
15421540
return context.relBuilder.peek();
15431541
}
15441542

15451543
private static void buildDedupOrNull(
15461544
CalcitePlanContext context, List<RexNode> dedupeFields, Integer allowedDuplication) {
15471545
/*
1548-
* | dedup 2 a, b keepempty=false
1549-
* DropColumns('_row_number_dedup_)
1550-
* +- Filter ('_row_number_dedup_ <= n OR isnull('a) OR isnull('b))
1551-
* +- Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST, specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_dedup_], ['a, 'b], ['a ASC NULLS FIRST, 'b ASC NULLS FIRST]
1546+
* | dedup 2 a, b keepempty=true
1547+
* LogicalProject(...)
1548+
* +- LogicalFilter(condition=[OR(IS NULL(a), IS NULL(b), <=(_row_number_dedup_, n))])
1549+
* +- LogicalProject(..., _row_number_dedup_=[ROW_NUMBER() OVER (PARTITION BY a, b ORDER BY a, b)])
15521550
* +- ...
15531551
*/
1554-
// Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST,
1555-
// specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_dedup_], ['a, 'b], ['a
1556-
// ASC
1557-
// NULLS FIRST, 'b ASC NULLS FIRST]
15581552
RexNode rowNumber =
15591553
context
15601554
.relBuilder
@@ -1577,16 +1571,21 @@ private static void buildDedupOrNull(
15771571
}
15781572

15791573
private static void buildDedupNotNull(
1580-
CalcitePlanContext context, List<RexNode> dedupeFields, Integer allowedDuplication) {
1574+
CalcitePlanContext context,
1575+
List<RexNode> dedupeFields,
1576+
Integer allowedDuplication,
1577+
boolean fromJoinMaxOption) {
15811578
/*
15821579
* | dedup 2 a, b keepempty=false
1583-
* DropColumns('_row_number_dedup_)
1584-
* +- Filter ('_row_number_dedup_ <= n)
1585-
* +- Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST, specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_dedup_], ['a, 'b], ['a ASC NULLS FIRST, 'b ASC NULLS FIRST]
1586-
* +- Filter (isnotnull('a) AND isnotnull('b))
1587-
* +- ...
1580+
* LogicalProject(...)
1581+
* +- LogicalFilter(condition=[<=(_row_number_dedup_, n)])
1582+
* +- LogicalProject(..., _row_number_dedup_=[ROW_NUMBER() OVER (PARTITION BY a, b ORDER BY a, b)])
1583+
* +- LogicalFilter(condition=[AND(IS NOT NULL(a), IS NOT NULL(b))])
1584+
* +- ...
15881585
*/
15891586
// Filter (isnotnull('a) AND isnotnull('b))
1587+
String rowNumberAlias =
1588+
fromJoinMaxOption ? ROW_NUMBER_COLUMN_FOR_JOIN_MAX_DEDUP : ROW_NUMBER_COLUMN_FOR_DEDUP;
15901589
context.relBuilder.filter(
15911590
context.relBuilder.and(dedupeFields.stream().map(context.relBuilder::isNotNull).toList()));
15921591
// Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST,
@@ -1600,15 +1599,15 @@ private static void buildDedupNotNull(
16001599
.partitionBy(dedupeFields)
16011600
.orderBy(dedupeFields)
16021601
.rowsTo(RexWindowBounds.CURRENT_ROW)
1603-
.as(ROW_NUMBER_COLUMN_FOR_DEDUP);
1602+
.as(rowNumberAlias);
16041603
context.relBuilder.projectPlus(rowNumber);
1605-
RexNode _row_number_dedup_ = context.relBuilder.field(ROW_NUMBER_COLUMN_FOR_DEDUP);
1604+
RexNode rowNumberField = context.relBuilder.field(rowNumberAlias);
16061605
// Filter ('_row_number_dedup_ <= n)
16071606
context.relBuilder.filter(
16081607
context.relBuilder.lessThanOrEqual(
1609-
_row_number_dedup_, context.relBuilder.literal(allowedDuplication)));
1608+
rowNumberField, context.relBuilder.literal(allowedDuplication)));
16101609
// DropColumns('_row_number_dedup_)
1611-
context.relBuilder.projectExcept(_row_number_dedup_);
1610+
context.relBuilder.projectExcept(rowNumberField);
16121611
}
16131612

16141613
@Override
@@ -2395,25 +2394,6 @@ public RelNode visitRareTopN(RareTopN node, CalcitePlanContext context) {
23952394
return context.relBuilder.peek();
23962395
}
23972396

2398-
private static void addIgnoreNullBucketHintToAggregate(CalcitePlanContext context) {
2399-
final RelHint statHits =
2400-
RelHint.builder("stats_args").hintOption(Argument.BUCKET_NULLABLE, "false").build();
2401-
assert context.relBuilder.peek() instanceof LogicalAggregate
2402-
: "Stats hits should be added to LogicalAggregate";
2403-
context.relBuilder.hints(statHits);
2404-
context
2405-
.relBuilder
2406-
.getCluster()
2407-
.setHintStrategies(
2408-
HintStrategyTable.builder()
2409-
.hintStrategy(
2410-
"stats_args",
2411-
(hint, rel) -> {
2412-
return rel instanceof LogicalAggregate;
2413-
})
2414-
.build());
2415-
}
2416-
24172397
@Override
24182398
public RelNode visitTableFunction(TableFunction node, CalcitePlanContext context) {
24192399
throw new CalciteUnsupportedException("Table function is unsupported in Calcite");

core/src/main/java/org/opensearch/sql/calcite/utils/PlanUtils.java

Lines changed: 72 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,21 @@
2525
import org.apache.calcite.rel.RelHomogeneousShuttle;
2626
import org.apache.calcite.rel.RelNode;
2727
import org.apache.calcite.rel.RelShuttle;
28+
import org.apache.calcite.rel.core.AggregateCall;
2829
import org.apache.calcite.rel.core.Project;
2930
import org.apache.calcite.rel.core.Sort;
3031
import org.apache.calcite.rel.core.TableScan;
32+
import org.apache.calcite.rel.hint.HintStrategyTable;
33+
import org.apache.calcite.rel.hint.RelHint;
34+
import org.apache.calcite.rel.logical.LogicalAggregate;
35+
import org.apache.calcite.rel.logical.LogicalFilter;
3136
import org.apache.calcite.rel.logical.LogicalProject;
3237
import org.apache.calcite.rel.logical.LogicalSort;
3338
import org.apache.calcite.rel.type.RelDataType;
3439
import org.apache.calcite.rex.RexCall;
3540
import org.apache.calcite.rex.RexCorrelVariable;
3641
import org.apache.calcite.rex.RexInputRef;
42+
import org.apache.calcite.rex.RexLiteral;
3743
import org.apache.calcite.rex.RexNode;
3844
import org.apache.calcite.rex.RexOver;
3945
import org.apache.calcite.rex.RexVisitorImpl;
@@ -45,8 +51,11 @@
4551
import org.apache.calcite.tools.RelBuilder;
4652
import org.apache.calcite.util.Pair;
4753
import org.apache.calcite.util.Util;
54+
import org.apache.calcite.util.mapping.Mapping;
55+
import org.apache.calcite.util.mapping.Mappings;
4856
import org.opensearch.sql.ast.AbstractNodeVisitor;
4957
import org.opensearch.sql.ast.Node;
58+
import org.opensearch.sql.ast.expression.Argument;
5059
import org.opensearch.sql.ast.expression.IntervalUnit;
5160
import org.opensearch.sql.ast.expression.SpanUnit;
5261
import org.opensearch.sql.ast.expression.WindowBound;
@@ -62,6 +71,7 @@ public interface PlanUtils {
6271
/** this is only for dedup command, do not reuse it in other command */
6372
String ROW_NUMBER_COLUMN_FOR_DEDUP = "_row_number_dedup_";
6473

74+
String ROW_NUMBER_COLUMN_FOR_JOIN_MAX_DEDUP = "_row_number_join_max_dedup_";
6575
String ROW_NUMBER_COLUMN_FOR_RARE_TOP = "_row_number_rare_top_";
6676
String ROW_NUMBER_COLUMN_FOR_MAIN = "_row_number_main_";
6777
String ROW_NUMBER_COLUMN_FOR_SUBSEARCH = "_row_number_subsearch_";
@@ -449,18 +459,15 @@ static RexNode derefMapCall(RexNode rexNode) {
449459
return rexNode;
450460
}
451461

452-
/** Check if contains RexOver introduced by dedup */
453-
static boolean containsRowNumberDedup(LogicalProject project) {
454-
return project.getProjects().stream()
455-
.anyMatch(p -> p instanceof RexOver && p.getKind() == SqlKind.ROW_NUMBER)
456-
&& project.getRowType().getFieldNames().contains(ROW_NUMBER_COLUMN_FOR_DEDUP);
462+
/** Check if contains dedup */
463+
static boolean containsRowNumberDedup(RelNode node) {
464+
return node.getRowType().getFieldNames().stream().anyMatch(ROW_NUMBER_COLUMN_FOR_DEDUP::equals);
457465
}
458466

459-
/** Check if contains RexOver introduced by dedup top/rare */
460-
static boolean containsRowNumberRareTop(LogicalProject project) {
461-
return project.getProjects().stream()
462-
.anyMatch(p -> p instanceof RexOver && p.getKind() == SqlKind.ROW_NUMBER)
463-
&& project.getRowType().getFieldNames().contains(ROW_NUMBER_COLUMN_FOR_RARE_TOP);
467+
/** Check if contains dedup for top/rare */
468+
static boolean containsRowNumberRareTop(RelNode node) {
469+
return node.getRowType().getFieldNames().stream()
470+
.anyMatch(ROW_NUMBER_COLUMN_FOR_RARE_TOP::equals);
464471
}
465472

466473
/** Get all RexWindow list from LogicalProject */
@@ -508,10 +515,6 @@ static boolean distinctProjectList(LogicalProject project) {
508515
return project.getNamedProjects().stream().allMatch(rexSet::add);
509516
}
510517

511-
static boolean containsRexOver(LogicalProject project) {
512-
return project.getProjects().stream().anyMatch(RexOver::containsOver);
513-
}
514-
515518
/**
516519
* The LogicalSort is a LIMIT that should be pushed down when its fetch field is not null and its
517520
* collation is empty. For example: <code>sort name | head 5</code> should not be pushed down
@@ -524,7 +527,7 @@ static boolean isLogicalSortLimit(LogicalSort sort) {
524527
return sort.fetch != null;
525528
}
526529

527-
static boolean projectContainsExpr(Project project) {
530+
static boolean containsRexCall(Project project) {
528531
return project.getProjects().stream().anyMatch(p -> p instanceof RexCall);
529532
}
530533

@@ -595,4 +598,58 @@ static void replaceTop(RelBuilder relBuilder, RelNode relNode) {
595598
throw new IllegalStateException("Unable to invoke RelBuilder.replaceTop", e);
596599
}
597600
}
601+
602+
static void addIgnoreNullBucketHintToAggregate(RelBuilder relBuilder) {
603+
final RelHint statHits =
604+
RelHint.builder("stats_args").hintOption(Argument.BUCKET_NULLABLE, "false").build();
605+
assert relBuilder.peek() instanceof LogicalAggregate
606+
: "Stats hits should be added to LogicalAggregate";
607+
relBuilder.hints(statHits);
608+
relBuilder
609+
.getCluster()
610+
.setHintStrategies(
611+
HintStrategyTable.builder()
612+
.hintStrategy(
613+
"stats_args",
614+
(hint, rel) -> {
615+
return rel instanceof LogicalAggregate;
616+
})
617+
.build());
618+
}
619+
620+
/** Extract the RexLiteral from the aggregate call if the aggregate call is a LITERAL_AGG. */
621+
static @Nullable RexLiteral getObjectFromLiteralAgg(AggregateCall aggCall) {
622+
if (aggCall.getAggregation().kind == SqlKind.LITERAL_AGG) {
623+
return (RexLiteral)
624+
aggCall.rexList.stream().filter(rex -> rex instanceof RexLiteral).findAny().orElse(null);
625+
} else {
626+
return null;
627+
}
628+
}
629+
630+
/**
631+
* This is a helper method to create a target mapping easily for replacing calling {@link
632+
* Mappings#target(List, int)}
633+
*
634+
* @param rexNodes the rex list in schema
635+
* @param schema the schema which contains the rex list
636+
* @return the target mapping
637+
*/
638+
static Mapping mapping(List<RexNode> rexNodes, RelDataType schema) {
639+
return Mappings.target(getSelectColumns(rexNodes), schema.getFieldCount());
640+
}
641+
642+
static boolean mayBeFilterFromBucketNonNull(LogicalFilter filter) {
643+
RexNode condition = filter.getCondition();
644+
return isNotNullOnRef(condition)
645+
|| (condition instanceof RexCall rexCall
646+
&& rexCall.getOperator().equals(SqlStdOperatorTable.AND)
647+
&& rexCall.getOperands().stream().allMatch(PlanUtils::isNotNullOnRef));
648+
}
649+
650+
private static boolean isNotNullOnRef(RexNode rex) {
651+
return rex instanceof RexCall rexCall
652+
&& rexCall.isA(SqlKind.IS_NOT_NULL)
653+
&& rexCall.getOperands().get(0) instanceof RexInputRef;
654+
}
598655
}

core/src/main/java/org/opensearch/sql/data/type/ExprType.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,18 @@ default Optional<String> getOriginalPath() {
6262
}
6363

6464
/**
65-
* Get the original path. Types like alias type should be derived from the type of the original
66-
* field.
65+
* Get the original expr type. Types like alias type should be derived from the type of the
66+
* original field.
6767
*/
6868
default ExprType getOriginalExprType() {
6969
return this;
7070
}
71+
72+
/**
73+
* Get the original data type. Types like alias type should be derived from the type of the
74+
* original field.
75+
*/
76+
default ExprType getOriginalType() {
77+
return this;
78+
}
7179
}

0 commit comments

Comments
 (0)