Skip to content

Commit d758163

Browse files
authored
Implement type checking for aggregation functions with Calcite (#4024)
* Remove getTypeChecker from FunctionImp interface Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Refactor registerExternalFunction to registerExternalOperator Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Do not register GEOIP function if got incompatible client Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Create scaffold for type checking of aggregation functions Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Add type checkers for aggregation functions Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Test type checking for aggregation functions Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> --------- Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
1 parent 964d8b5 commit d758163

5 files changed

Lines changed: 387 additions & 307 deletions

File tree

core/src/main/java/org/opensearch/sql/calcite/udf/udaf/TakeAggFunction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ public Object result(TakeAccumulator accumulator) {
2424
@Override
2525
public TakeAccumulator add(TakeAccumulator acc, Object... values) {
2626
Object candidateValue = values[0];
27-
int size = 0;
27+
int size;
2828
if (values.length > 1) {
2929
size = (int) values[1];
3030
} else {

core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java

Lines changed: 60 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.apache.calcite.rex.RexCall;
3131
import org.apache.calcite.rex.RexNode;
3232
import org.apache.calcite.schema.impl.AggregateFunctionImpl;
33+
import org.apache.calcite.sql.SqlAggFunction;
3334
import org.apache.calcite.sql.SqlIdentifier;
3435
import org.apache.calcite.sql.SqlKind;
3536
import org.apache.calcite.sql.parser.SqlParserPos;
@@ -77,27 +78,71 @@ public class UserDefinedFunctionUtils {
7778
public static Set<String> MULTI_FIELDS_RELEVANCE_FUNCTION_SET =
7879
ImmutableSet.of("simple_query_string", "query_string", "multi_match");
7980

80-
public static RelBuilder.AggCall TransferUserDefinedAggFunction(
81-
Class<? extends UserDefinedAggFunction> UDAF,
81+
/**
82+
* Creates a SqlUserDefinedAggFunction that wraps a Java class implementing an aggregate function.
83+
*
84+
* @param udafClass The Java class that implements the UserDefinedAggFunction interface
85+
* @param functionName The name of the function to be used in SQL statements
86+
* @param returnType A SqlReturnTypeInference that determines the return type of the function
87+
* @return A SqlUserDefinedAggFunction that can be used in SQL queries
88+
*/
89+
public static SqlUserDefinedAggFunction createUserDefinedAggFunction(
90+
Class<? extends UserDefinedAggFunction<?>> udafClass,
8291
String functionName,
83-
SqlReturnTypeInference returnType,
92+
SqlReturnTypeInference returnType) {
93+
return new SqlUserDefinedAggFunction(
94+
new SqlIdentifier(functionName, SqlParserPos.ZERO),
95+
SqlKind.OTHER_FUNCTION,
96+
returnType,
97+
null,
98+
null,
99+
AggregateFunctionImpl.create(udafClass),
100+
false,
101+
false,
102+
Optionality.FORBIDDEN);
103+
}
104+
105+
/**
106+
* Creates an aggregate call using the provided SqlAggFunction and arguments.
107+
*
108+
* @param aggFunction The aggregate function to call
109+
* @param fields The primary fields to aggregate
110+
* @param argList Additional arguments for the aggregate function, appended after {@code fields}
111+
* @param relBuilder The RelBuilder instance used for building relational expressions
112+
* @return An AggCall object representing the aggregate function call
113+
*/
114+
public static RelBuilder.AggCall makeAggregateCall(
115+
SqlAggFunction aggFunction,
84116
List<RexNode> fields,
85117
List<RexNode> argList,
86118
RelBuilder relBuilder) {
87-
SqlUserDefinedAggFunction sqlUDAF =
88-
new SqlUserDefinedAggFunction(
89-
new SqlIdentifier(functionName, SqlParserPos.ZERO),
90-
SqlKind.OTHER_FUNCTION,
91-
returnType,
92-
null,
93-
null,
94-
AggregateFunctionImpl.create(UDAF),
95-
false,
96-
false,
97-
Optionality.FORBIDDEN);
98119
List<RexNode> addArgList = new ArrayList<>(fields);
99120
addArgList.addAll(argList);
100-
return relBuilder.aggregateCall(sqlUDAF, addArgList);
121+
return relBuilder.aggregateCall(aggFunction, addArgList);
122+
}
123+
124+
/**
125+
* Creates a User Defined Aggregate Function (UDAF) and returns an AggCall that can
126+
* be used in query plans.
127+
*
128+
* @param udafClass The class implementing the aggregate function behavior
129+
* @param functionName The name of the aggregate function
130+
* @param returnType The return type inference for determining the result type
131+
* @param fields The primary fields to aggregate
132+
* @param argList Additional arguments for the aggregate function
133+
* @param relBuilder The RelBuilder instance used for building relational expressions
134+
* @return An AggCall object representing the aggregate function call
135+
*/
136+
public static RelBuilder.AggCall createAggregateFunction(
137+
Class<? extends UserDefinedAggFunction<?>> udafClass,
138+
String functionName,
139+
SqlReturnTypeInference returnType,
140+
List<RexNode> fields,
141+
List<RexNode> argList,
142+
RelBuilder relBuilder) {
143+
SqlUserDefinedAggFunction udaf =
144+
createUserDefinedAggFunction(udafClass, functionName, returnType);
145+
return makeAggregateCall(udaf, fields, argList, relBuilder);
101146
}
102147

103148
public static SqlReturnTypeInference getReturnTypeInferenceForArray() {

0 commit comments

Comments
 (0)