Skip to content

Commit 3d2786d

Browse files
committed
address comments
Signed-off-by: Lantao Jin <ltjin@amazon.com>
1 parent 8d31a38 commit 3d2786d

6 files changed

Lines changed: 196 additions & 9 deletions

File tree

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.utils;
7+
8+
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import org.apache.commons.lang3.tuple.Pair;
12+
13+
public interface Utils {
14+
static <I> List<Pair<I, Integer>> zipWithIndex(List<I> input) {
15+
LinkedList<Pair<I, Integer>> result = new LinkedList<>();
16+
Iterator<I> iter = input.iterator();
17+
int index = 0;
18+
while (iter.hasNext()) {
19+
result.add(Pair.of(iter.next(), index++));
20+
}
21+
return result;
22+
}
23+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
calcite:
2+
logical: |
3+
LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])
4+
LogicalProject(count_area=[$1], min_area=[$2], max_area=[$3], avg_area=[$4], avg_age=[$5], name=[$0])
5+
LogicalAggregate(group=[{0}], count_area=[COUNT($1)], min_area=[MIN($1)], max_area=[MAX($1)], avg_area=[AVG($1)], avg_age=[AVG($2)])
6+
LogicalProject(name=[$1], address.area=[$3], age=[$9])
7+
CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_nested_simple]])
8+
physical: |
9+
CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_nested_simple]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0},count_area=COUNT($1),min_area=MIN($1),max_area=MAX($1),avg_area=AVG($1),avg_age=AVG($2)), PROJECT->[count_area, min_area, max_area, avg_area, avg_age, name], LIMIT->10000], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"composite_buckets":{"composite":{"size":1000,"sources":[{"name":{"terms":{"field":"name.keyword","missing_bucket":true,"missing_order":"first","order":"asc"}}}]},"aggregations":{"nested_count_area":{"nested":{"path":"address"},"aggregations":{"count_area":{"value_count":{"field":"address.area"}}}},"nested_min_area":{"nested":{"path":"address"},"aggregations":{"min_area":{"min":{"field":"address.area"}}}},"nested_max_area":{"nested":{"path":"address"},"aggregations":{"max_area":{"max":{"field":"address.area"}}}},"nested_avg_area":{"nested":{"path":"address"},"aggregations":{"avg_area":{"avg":{"field":"address.area"}}}},"avg_age":{"avg":{"field":"age"}}}}}}, requestedTotalSize=10000, pageSize=null, startFrom=0)])
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.opensearch.planner.physical;
7+
8+
import java.util.List;
9+
import org.apache.calcite.adapter.enumerable.AggImplementor;
10+
import org.apache.calcite.adapter.enumerable.EnumerableAggregateBase;
11+
import org.apache.calcite.adapter.enumerable.EnumerableConvention;
12+
import org.apache.calcite.adapter.enumerable.EnumerableRel;
13+
import org.apache.calcite.adapter.enumerable.EnumerableRelImplementor;
14+
import org.apache.calcite.adapter.enumerable.RexImpTable;
15+
import org.apache.calcite.plan.RelOptCluster;
16+
import org.apache.calcite.plan.RelOptCost;
17+
import org.apache.calcite.plan.RelOptPlanner;
18+
import org.apache.calcite.plan.RelTraitSet;
19+
import org.apache.calcite.rel.InvalidRelException;
20+
import org.apache.calcite.rel.RelNode;
21+
import org.apache.calcite.rel.core.AggregateCall;
22+
import org.apache.calcite.rel.hint.RelHint;
23+
import org.apache.calcite.rel.metadata.RelMetadataQuery;
24+
import org.apache.calcite.util.ImmutableBitSet;
25+
import org.checkerframework.checker.nullness.qual.Nullable;
26+
27+
/**
28+
* The enumerable aggregate physical implementation for OpenSearch nested aggregation.
29+
* https://docs.opensearch.org/latest/aggregations/bucket/nested/
30+
*/
31+
public class EnumerableNestedAggregate extends EnumerableAggregateBase implements EnumerableRel {
32+
33+
public EnumerableNestedAggregate(
34+
RelOptCluster cluster,
35+
RelTraitSet traitSet,
36+
List<RelHint> hints,
37+
RelNode input,
38+
ImmutableBitSet groupSet,
39+
@Nullable List<ImmutableBitSet> groupSets,
40+
List<AggregateCall> aggCalls)
41+
throws InvalidRelException {
42+
super(cluster, traitSet, hints, input, groupSet, groupSets, aggCalls);
43+
assert getConvention() instanceof EnumerableConvention;
44+
45+
for (AggregateCall aggCall : aggCalls) {
46+
if (aggCall.isDistinct()) {
47+
throw new InvalidRelException("distinct aggregation not supported");
48+
}
49+
if (aggCall.distinctKeys != null) {
50+
throw new InvalidRelException("within-distinct aggregation not supported");
51+
}
52+
AggImplementor implementor2 = RexImpTable.INSTANCE.get(aggCall.getAggregation(), false);
53+
if (implementor2 == null) {
54+
throw new InvalidRelException("aggregation " + aggCall.getAggregation() + " not supported");
55+
}
56+
}
57+
}
58+
59+
@Override
60+
public EnumerableNestedAggregate copy(
61+
RelTraitSet traitSet,
62+
RelNode input,
63+
ImmutableBitSet groupSet,
64+
@Nullable List<ImmutableBitSet> groupSets,
65+
List<AggregateCall> aggCalls) {
66+
try {
67+
return new EnumerableNestedAggregate(
68+
getCluster(), traitSet, getHints(), input, groupSet, groupSets, aggCalls);
69+
} catch (InvalidRelException e) {
70+
// Semantic error not possible. Must be a bug. Convert to
71+
// internal error.
72+
throw new AssertionError(e);
73+
}
74+
}
75+
76+
@Override
77+
public Result implement(EnumerableRelImplementor implementor, Prefer pref) {
78+
throw new RuntimeException();
79+
}
80+
81+
@Override
82+
public @Nullable RelOptCost computeSelfCost(RelOptPlanner planner, RelMetadataQuery mq) {
83+
return super.computeSelfCost(planner, mq).multiplyBy(.9);
84+
}
85+
}
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.opensearch.planner.rules;
7+
8+
import org.apache.calcite.adapter.enumerable.EnumerableConvention;
9+
import org.apache.calcite.plan.Convention;
10+
import org.apache.calcite.plan.RelTraitSet;
11+
import org.apache.calcite.rel.InvalidRelException;
12+
import org.apache.calcite.rel.RelNode;
13+
import org.apache.calcite.rel.convert.ConverterRule;
14+
import org.apache.calcite.rel.core.Aggregate;
15+
import org.apache.calcite.rel.logical.LogicalAggregate;
16+
import org.apache.logging.log4j.LogManager;
17+
import org.apache.logging.log4j.Logger;
18+
import org.checkerframework.checker.nullness.qual.Nullable;
19+
import org.opensearch.sql.opensearch.planner.physical.EnumerableNestedAggregate;
20+
21+
/** Rule to convert {@link LogicalAggregate} to {@link EnumerableNestedAggregate}. */
22+
public class EnumerableNestedAggregateRule extends ConverterRule {
23+
private static final Logger LOG = LogManager.getLogger();
24+
25+
/** Default configuration. */
26+
public static final Config DEFAULT_CONFIG =
27+
Config.INSTANCE
28+
.withConversion(
29+
LogicalAggregate.class,
30+
Convention.NONE,
31+
EnumerableConvention.INSTANCE,
32+
"EnumerableNestedAggregateRule")
33+
.withRuleFactory(EnumerableNestedAggregateRule::new);
34+
35+
/** Called from the Config. */
36+
protected EnumerableNestedAggregateRule(Config config) {
37+
super(config);
38+
}
39+
40+
@Override
41+
public @Nullable RelNode convert(RelNode rel) {
42+
final Aggregate agg = (Aggregate) rel;
43+
if (agg.getHints().stream()
44+
.noneMatch(
45+
hint ->
46+
hint.hintName.equals("stats_args")
47+
&& hint.kvOptions.values().stream().anyMatch(v -> v.equals("true")))) {
48+
return null;
49+
}
50+
final RelTraitSet traitSet = rel.getCluster().traitSet().replace(EnumerableConvention.INSTANCE);
51+
try {
52+
return new EnumerableNestedAggregate(
53+
rel.getCluster(),
54+
traitSet,
55+
agg.getHints(),
56+
convert(agg.getInput(), traitSet),
57+
agg.getGroupSet(),
58+
agg.getGroupSets(),
59+
agg.getAggCallList());
60+
} catch (InvalidRelException e) {
61+
if (LOG.isDebugEnabled()) {
62+
LOG.debug(e.toString());
63+
}
64+
return null;
65+
}
66+
}
67+
}

opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -337,8 +337,7 @@ private static Pair<Builder, List<MetricParser>> processAggregateCalls(
337337

338338
for (int i = 0; i < aggCalls.size(); i++) {
339339
AggregateCall aggCall = aggCalls.get(i);
340-
List<org.apache.calcite.util.Pair<RexNode, String>> args =
341-
convertAggArgThroughProject(aggCall, project);
340+
List<Pair<RexNode, String>> args = convertAggArgThroughProject(aggCall, project);
342341
String aggFieldName = aggFieldNames.get(i);
343342

344343
Pair<AggregationBuilder, MetricParser> builderAndParser =
@@ -359,20 +358,24 @@ private static Pair<Builder, List<MetricParser>> processAggregateCalls(
359358
* @param project the project
360359
* @return the converted Pair<RexNode, String> list
361360
*/
362-
private static List<org.apache.calcite.util.Pair<RexNode, String>> convertAggArgThroughProject(
361+
private static List<Pair<RexNode, String>> convertAggArgThroughProject(
363362
AggregateCall aggCall, Project project) {
364363
return project == null
365364
? List.of()
366365
: PlanUtils.getObjectFromLiteralAgg(aggCall) != null
367366
? project.getNamedProjects().stream()
368367
.filter(rex -> !rex.getKey().isA(SqlKind.ROW_NUMBER))
368+
.map(p -> Pair.of(p.getKey(), p.getValue()))
369369
.toList()
370-
: aggCall.getArgList().stream().map(project.getNamedProjects()::get).toList();
370+
: aggCall.getArgList().stream()
371+
.map(project.getNamedProjects()::get)
372+
.map(p -> Pair.of(p.getKey(), p.getValue()))
373+
.toList();
371374
}
372375

373376
private static Pair<AggregationBuilder, MetricParser> createAggregationBuilderAndParser(
374377
AggregateCall aggCall,
375-
List<org.apache.calcite.util.Pair<RexNode, String>> args,
378+
List<Pair<RexNode, String>> args,
376379
String aggFieldName,
377380
AggregateAnalyzer.AggregateBuilderHelper helper) {
378381
if (aggCall.isDistinct()) {
@@ -384,7 +387,7 @@ private static Pair<AggregationBuilder, MetricParser> createAggregationBuilderAn
384387

385388
private static Pair<AggregationBuilder, MetricParser> createDistinctAggregation(
386389
AggregateCall aggCall,
387-
List<org.apache.calcite.util.Pair<RexNode, String>> args,
390+
List<Pair<RexNode, String>> args,
388391
String aggFieldName,
389392
AggregateBuilderHelper helper) {
390393

@@ -403,7 +406,7 @@ private static Pair<AggregationBuilder, MetricParser> createDistinctAggregation(
403406

404407
private static Pair<AggregationBuilder, MetricParser> createRegularAggregation(
405408
AggregateCall aggCall,
406-
List<org.apache.calcite.util.Pair<RexNode, String>> args,
409+
List<Pair<RexNode, String>> args,
407410
String aggFieldName,
408411
AggregateBuilderHelper helper) {
409412

opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/TopHitsParser.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,9 @@ public List<Map<String, Object>> parse(Aggregation agg) {
9090
// LinkedHashMap["name" -> "A", "category" -> "Y"]
9191
// ]
9292
return Arrays.stream(hits)
93-
.<Map<String, Object>>map(
93+
.map(
9494
hit -> {
95-
Map map = new LinkedHashMap<>(hit.getSourceAsMap());
95+
Map<String, Object> map = new LinkedHashMap<>(hit.getSourceAsMap());
9696
hit.getFields().values().forEach(f -> map.put(f.getName(), f.getValue()));
9797
return map;
9898
})

0 commit comments

Comments
 (0)