Skip to content

Commit 484f49e

Browse files
authored
Support function argument coercion with Calcite (opensearch-project#3914)
* Change the use of SqlTypeFamily.STRING to SqlTypeFamily.CHARACTER as the string family contains binary, which is not expected for most functions Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Implement basic argument type coercion at RelNode level Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Conform type checkers with their definition in documentation - string as an input is removed if it is not in the document - string as an input is kept if it is in the document, even if it can be implicitly cast - use PPLOperandTypes as much as possible Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Implement type widening for comparator functions - Add COMPARATORS set to BuiltinFunctionName for identifying comparison operators - Implement widenArguments method in CoercionUtils to find widest compatible type - Apply type widening to comparator functions before applying type casting - Add detailed JavaDoc to explain coercion methods Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Update error messages of datetime functions with invalid args Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Simplify datetime-string compare logic with implicit coercion Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Refactor resolve with coercion Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Move down argument cast for reduce function Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Merge comparators and their IP variants so that coercion works for IP comparison - when not merging, ip comparing will also pass the type checker of Calcite's comparators Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Refactor ip comparator to comparator Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Revert "Refactor ip comparator to comparator" This reverts commit c539056. Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Revert "Merge comparators and their IP variants so that coercion works for IP comparison" This reverts commit bd9f3bb. 
Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Rule out ip from built-in comparator via its type checker Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Restrict CompareIP's parameter type Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Revert to previous implementation of CompareIpFunction to temporarily fix ip comparison pushdown problems (udt not correctly serialized; ip comparison is not converted to range query) Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Test argument coercion explain Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Fix error msg in CalcitePPLFunctionTypeTest Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> --------- Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
1 parent edb3a0d commit 484f49e

41 files changed

Lines changed: 666 additions & 338 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
import static org.opensearch.sql.ast.expression.SpanUnit.NONE;
1212
import static org.opensearch.sql.ast.expression.SpanUnit.UNKNOWN;
1313
import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.TYPE_FACTORY;
14-
import static org.opensearch.sql.utils.DateTimeUtils.findCastType;
15-
import static org.opensearch.sql.utils.DateTimeUtils.transferCompareForDateRelated;
1614

1715
import java.math.BigDecimal;
1816
import java.util.ArrayList;
@@ -30,7 +28,6 @@
3028
import org.apache.calcite.rel.type.RelDataTypeFactory;
3129
import org.apache.calcite.rex.RexBuilder;
3230
import org.apache.calcite.rex.RexCall;
33-
import org.apache.calcite.rex.RexLambda;
3431
import org.apache.calcite.rex.RexLambdaRef;
3532
import org.apache.calcite.rex.RexNode;
3633
import org.apache.calcite.sql.SqlIntervalQualifier;
@@ -215,11 +212,8 @@ public RexNode visitIn(In node, CalcitePlanContext context) {
215212

216213
@Override
217214
public RexNode visitCompare(Compare node, CalcitePlanContext context) {
218-
RexNode leftCandidate = analyze(node.getLeft(), context);
219-
RexNode rightCandidate = analyze(node.getRight(), context);
220-
SqlTypeName castTarget = findCastType(leftCandidate, rightCandidate);
221-
final RexNode left = transferCompareForDateRelated(leftCandidate, context, castTarget);
222-
final RexNode right = transferCompareForDateRelated(rightCandidate, context, castTarget);
215+
RexNode left = analyze(node.getLeft(), context);
216+
RexNode right = analyze(node.getRight(), context);
223217
return PPLFuncImpTable.INSTANCE.resolve(context.rexBuilder, node.getOperator(), left, right);
224218
}
225219

@@ -468,19 +462,6 @@ private List<RelDataType> modifyLambdaTypeByFunction(
468462
}
469463
}
470464

471-
private List<RexNode> castArgument(
472-
List<RexNode> originalArguments, String functionName, ExtendedRexBuilder rexBuilder) {
473-
switch (functionName.toUpperCase(Locale.ROOT)) {
474-
case "REDUCE":
475-
RexLambda call = (RexLambda) originalArguments.get(2);
476-
originalArguments.set(
477-
1, rexBuilder.makeCast(call.getType(), originalArguments.get(1), true, true));
478-
return originalArguments;
479-
default:
480-
return originalArguments;
481-
}
482-
}
483-
484465
@Override
485466
public RexNode visitFunction(Function node, CalcitePlanContext context) {
486467
List<UnresolvedExpression> args = node.getFuncArgs();
@@ -507,8 +488,6 @@ public RexNode visitFunction(Function node, CalcitePlanContext context) {
507488
}
508489
}
509490

510-
arguments = castArgument(arguments, node.getFuncName(), context.rexBuilder);
511-
512491
RexNode resolvedNode =
513492
PPLFuncImpTable.INSTANCE.resolve(
514493
context.rexBuilder, node.getFuncName(), arguments.toArray(new RexNode[0]));

core/src/main/java/org/opensearch/sql/calcite/utils/PPLOperandTypes.java

Lines changed: 93 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,40 +27,121 @@ private PPLOperandTypes() {}
2727
UDFOperandMetadata.wrap(
2828
(CompositeOperandTypeChecker) OperandTypes.INTEGER.or(OperandTypes.family()));
2929
public static final UDFOperandMetadata STRING =
30-
UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.STRING);
30+
UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.CHARACTER);
3131
public static final UDFOperandMetadata INTEGER =
3232
UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.INTEGER);
3333
public static final UDFOperandMetadata NUMERIC =
3434
UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.NUMERIC);
35+
36+
public static final UDFOperandMetadata NUMERIC_OPTIONAL_STRING =
37+
UDFOperandMetadata.wrap(
38+
(CompositeOperandTypeChecker)
39+
OperandTypes.NUMERIC.or(
40+
OperandTypes.family(SqlTypeFamily.NUMERIC, SqlTypeFamily.CHARACTER)));
41+
3542
public static final UDFOperandMetadata INTEGER_INTEGER =
3643
UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.INTEGER_INTEGER);
3744
public static final UDFOperandMetadata STRING_STRING =
38-
UDFOperandMetadata.wrap(OperandTypes.STRING_STRING);
45+
UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.CHARACTER_CHARACTER);
3946
public static final UDFOperandMetadata NUMERIC_NUMERIC =
4047
UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.NUMERIC_NUMERIC);
48+
public static final UDFOperandMetadata STRING_INTEGER =
49+
UDFOperandMetadata.wrap(OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER));
50+
4151
public static final UDFOperandMetadata NUMERIC_NUMERIC_NUMERIC =
4252
UDFOperandMetadata.wrap(
4353
OperandTypes.family(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC));
54+
public static final UDFOperandMetadata STRING_OR_INTEGER_INTEGER_INTEGER =
55+
UDFOperandMetadata.wrap(
56+
(CompositeOperandTypeChecker)
57+
OperandTypes.family(
58+
SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER)
59+
.or(
60+
OperandTypes.family(
61+
SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER)));
62+
63+
public static final UDFOperandMetadata OPTIONAL_DATE_OR_TIMESTAMP_OR_NUMERIC =
64+
UDFOperandMetadata.wrap(
65+
(CompositeOperandTypeChecker)
66+
OperandTypes.DATETIME.or(OperandTypes.NUMERIC).or(OperandTypes.family()));
4467

4568
public static final UDFOperandMetadata DATETIME_OR_STRING =
4669
UDFOperandMetadata.wrap(
47-
(CompositeOperandTypeChecker) OperandTypes.DATETIME.or(OperandTypes.STRING));
70+
(CompositeOperandTypeChecker) OperandTypes.DATETIME.or(OperandTypes.CHARACTER));
71+
public static final UDFOperandMetadata TIME_OR_TIMESTAMP_OR_STRING =
72+
UDFOperandMetadata.wrap(
73+
(CompositeOperandTypeChecker)
74+
OperandTypes.CHARACTER.or(OperandTypes.TIME).or(OperandTypes.TIMESTAMP));
75+
public static final UDFOperandMetadata DATE_OR_TIMESTAMP_OR_STRING =
76+
UDFOperandMetadata.wrap(
77+
(CompositeOperandTypeChecker) OperandTypes.DATE_OR_TIMESTAMP.or(OperandTypes.CHARACTER));
78+
public static final UDFOperandMetadata DATETIME_OR_STRING_OR_INTEGER =
79+
UDFOperandMetadata.wrap(
80+
(CompositeOperandTypeChecker)
81+
OperandTypes.DATETIME.or(OperandTypes.CHARACTER).or(OperandTypes.INTEGER));
82+
83+
public static final UDFOperandMetadata DATETIME_OPTIONAL_INTEGER =
84+
UDFOperandMetadata.wrap(
85+
(CompositeOperandTypeChecker)
86+
OperandTypes.DATETIME.or(
87+
OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.INTEGER)));
88+
4889
public static final UDFOperandMetadata DATETIME_DATETIME =
4990
UDFOperandMetadata.wrap(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.DATETIME));
91+
public static final UDFOperandMetadata DATETIME_OR_STRING_STRING =
92+
UDFOperandMetadata.wrap(
93+
(CompositeOperandTypeChecker)
94+
OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.CHARACTER)
95+
.or(OperandTypes.CHARACTER_CHARACTER));
5096
public static final UDFOperandMetadata DATETIME_OR_STRING_DATETIME_OR_STRING =
5197
UDFOperandMetadata.wrap(
5298
(CompositeOperandTypeChecker)
53-
OperandTypes.STRING_STRING
99+
OperandTypes.CHARACTER_CHARACTER
54100
.or(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.DATETIME))
55-
.or(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.STRING))
56-
.or(OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.DATETIME)));
57-
public static final UDFOperandMetadata TIME_OR_TIMESTAMP_OR_STRING =
101+
.or(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.CHARACTER))
102+
.or(OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.DATETIME)));
103+
public static final UDFOperandMetadata STRING_TIMESTAMP =
104+
UDFOperandMetadata.wrap(
105+
OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.TIMESTAMP));
106+
public static final UDFOperandMetadata STRING_DATETIME =
107+
UDFOperandMetadata.wrap(OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.DATETIME));
108+
public static final UDFOperandMetadata DATETIME_INTERVAL =
109+
UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.DATETIME_INTERVAL);
110+
public static final UDFOperandMetadata TIME_TIME =
111+
UDFOperandMetadata.wrap(OperandTypes.family(SqlTypeFamily.TIME, SqlTypeFamily.TIME));
112+
113+
public static final UDFOperandMetadata TIMESTAMP_OR_STRING_STRING_STRING =
58114
UDFOperandMetadata.wrap(
59115
(CompositeOperandTypeChecker)
60-
OperandTypes.STRING.or(OperandTypes.TIME).or(OperandTypes.TIMESTAMP));
61-
public static final UDFOperandMetadata DATE_OR_TIMESTAMP_OR_STRING =
116+
OperandTypes.family(
117+
SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)
118+
.or(
119+
OperandTypes.family(
120+
SqlTypeFamily.CHARACTER,
121+
SqlTypeFamily.CHARACTER,
122+
SqlTypeFamily.CHARACTER)));
123+
public static final UDFOperandMetadata STRING_INTEGER_DATETIME_OR_STRING =
62124
UDFOperandMetadata.wrap(
63-
(CompositeOperandTypeChecker) OperandTypes.DATE_OR_TIMESTAMP.or(OperandTypes.STRING));
64-
public static final UDFOperandMetadata STRING_TIMESTAMP =
65-
UDFOperandMetadata.wrap(OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.TIMESTAMP));
125+
(CompositeOperandTypeChecker)
126+
OperandTypes.family(
127+
SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.CHARACTER)
128+
.or(
129+
OperandTypes.family(
130+
SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.DATETIME)));
131+
public static final UDFOperandMetadata INTERVAL_DATETIME_DATETIME =
132+
UDFOperandMetadata.wrap(
133+
(CompositeOperandTypeChecker)
134+
OperandTypes.family(
135+
SqlTypeFamily.CHARACTER, SqlTypeFamily.DATETIME, SqlTypeFamily.DATETIME)
136+
.or(
137+
OperandTypes.family(
138+
SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, SqlTypeFamily.DATETIME))
139+
.or(
140+
OperandTypes.family(
141+
SqlTypeFamily.CHARACTER, SqlTypeFamily.DATETIME, SqlTypeFamily.CHARACTER))
142+
.or(
143+
OperandTypes.family(
144+
SqlTypeFamily.CHARACTER,
145+
SqlTypeFamily.CHARACTER,
146+
SqlTypeFamily.CHARACTER)));
66147
}

core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import java.util.Locale;
1010
import java.util.Map;
1111
import java.util.Optional;
12+
import java.util.Set;
1213
import lombok.AllArgsConstructor;
1314
import lombok.Getter;
1415
import lombok.RequiredArgsConstructor;
@@ -380,4 +381,13 @@ public static Optional<BuiltinFunctionName> ofWindowFunction(String functionName
380381
return Optional.ofNullable(
381382
WINDOW_FUNC_MAPPING.getOrDefault(functionName.toLowerCase(Locale.ROOT), null));
382383
}
384+
385+
public static final Set<BuiltinFunctionName> COMPARATORS =
386+
Set.of(
387+
BuiltinFunctionName.EQUAL,
388+
BuiltinFunctionName.NOTEQUAL,
389+
BuiltinFunctionName.LESS,
390+
BuiltinFunctionName.LTE,
391+
BuiltinFunctionName.GREATER,
392+
BuiltinFunctionName.GTE);
383393
}

core/src/main/java/org/opensearch/sql/expression/function/CalciteFuncSignature.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,12 @@
1111
/** Function signature is composed by function name and arguments list. */
1212
public record CalciteFuncSignature(FunctionName functionName, PPLTypeChecker typeChecker) {
1313

14-
public boolean match(FunctionName functionName, List<RelDataType> paramTypeList) {
14+
public boolean match(FunctionName functionName, List<RelDataType> argTypes) {
1515
if (!functionName.equals(this.functionName())) return false;
1616
// For complex type checkers (e.g., OperandTypes.COMPARABLE_UNORDERED_COMPARABLE_UNORDERED),
1717
// the typeChecker will be null because only simple family-based type checks are currently
1818
// supported.
1919
if (typeChecker == null) return true;
20-
return typeChecker.checkOperandTypes(paramTypeList);
20+
return typeChecker.checkOperandTypes(argTypes);
2121
}
2222
}
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.expression.function;
7+
8+
import java.util.ArrayList;
9+
import java.util.List;
10+
import javax.annotation.Nullable;
11+
import org.apache.calcite.rex.RexBuilder;
12+
import org.apache.calcite.rex.RexNode;
13+
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
14+
import org.opensearch.sql.data.type.ExprCoreType;
15+
import org.opensearch.sql.data.type.ExprType;
16+
import org.opensearch.sql.data.type.WideningTypeRule;
17+
import org.opensearch.sql.exception.ExpressionEvaluationException;
18+
19+
public class CoercionUtils {
20+
21+
/**
22+
* Casts the arguments to the types specified in the typeChecker. Returns null if no combination
23+
* of parameter types matches the arguments or if casting fails.
24+
*
25+
* @param builder RexBuilder to create casts
26+
* @param typeChecker PPLTypeChecker that provides the parameter types
27+
* @param arguments List of RexNode arguments to be cast
28+
* @return List of cast RexNode arguments or null if casting fails
29+
*/
30+
public static @Nullable List<RexNode> castArguments(
31+
RexBuilder builder, PPLTypeChecker typeChecker, List<RexNode> arguments) {
32+
List<List<ExprType>> paramTypeCombinations = typeChecker.getParameterTypes();
33+
34+
// TODO: var args?
35+
36+
for (List<ExprType> paramTypes : paramTypeCombinations) {
37+
List<RexNode> castedArguments = castArguments(builder, paramTypes, arguments);
38+
if (castedArguments != null) {
39+
return castedArguments;
40+
}
41+
}
42+
return null;
43+
}
44+
45+
/**
46+
* Widen the arguments to the widest type found among them. If no widest type can be determined,
47+
* returns null.
48+
*
49+
* @param builder RexBuilder to create casts
50+
* @param arguments List of RexNode arguments to be widened
51+
* @return List of widened RexNode arguments or null if no widest type can be determined
52+
*/
53+
public static @Nullable List<RexNode> widenArguments(
54+
RexBuilder builder, List<RexNode> arguments) {
55+
// TODO: Add test on e.g. IP
56+
ExprType widestType = findWidestType(arguments);
57+
if (widestType == null) {
58+
return null; // No widest type found, return null
59+
}
60+
return arguments.stream().map(arg -> cast(builder, widestType, arg)).toList();
61+
}
62+
63+
/**
64+
* Casts the arguments to the types specified in paramTypes. Returns null if the number of
65+
* parameters does not match or if casting fails.
66+
*/
67+
private static @Nullable List<RexNode> castArguments(
68+
RexBuilder builder, List<ExprType> paramTypes, List<RexNode> arguments) {
69+
if (paramTypes.size() != arguments.size()) {
70+
return null; // Skip if the number of parameters does not match
71+
}
72+
73+
List<RexNode> castedArguments = new ArrayList<>();
74+
for (int i = 0; i < paramTypes.size(); i++) {
75+
ExprType toType = paramTypes.get(i);
76+
RexNode arg = arguments.get(i);
77+
78+
RexNode castedArg = cast(builder, toType, arg);
79+
80+
if (castedArg == null) {
81+
return null;
82+
}
83+
castedArguments.add(castedArg);
84+
}
85+
return castedArguments;
86+
}
87+
88+
private static @Nullable RexNode cast(RexBuilder builder, ExprType targetType, RexNode arg) {
89+
ExprType argType = OpenSearchTypeFactory.convertRelDataTypeToExprType(arg.getType());
90+
if (!argType.shouldCast(targetType)) {
91+
return arg;
92+
}
93+
94+
if (WideningTypeRule.distance(argType, targetType) != WideningTypeRule.IMPOSSIBLE_WIDENING) {
95+
return builder.makeCast(OpenSearchTypeFactory.convertExprTypeToRelDataType(targetType), arg);
96+
}
97+
return null;
98+
}
99+
100+
/**
101+
* Finds the widest type among the given arguments. The widest type is determined by applying the
102+
* widening type rule to each pair of types in the arguments.
103+
*
104+
* @param arguments List of RexNode arguments to find the widest type from
105+
* @return the widest ExprType if found, otherwise null
106+
*/
107+
private static @Nullable ExprType findWidestType(List<RexNode> arguments) {
108+
if (arguments.isEmpty()) {
109+
return null; // No arguments to process
110+
}
111+
ExprType widestType =
112+
OpenSearchTypeFactory.convertRelDataTypeToExprType(arguments.getFirst().getType());
113+
if (arguments.size() == 1) {
114+
return widestType;
115+
}
116+
117+
// Iterate pairwise through the arguments and find the widest type
118+
for (int i = 1; i < arguments.size(); i++) {
119+
var type = OpenSearchTypeFactory.convertRelDataTypeToExprType(arguments.get(i).getType());
120+
try {
121+
if (areDateAndTime(widestType, type)) {
122+
// If one is date and the other is time, we consider timestamp as the widest type
123+
widestType = ExprCoreType.TIMESTAMP;
124+
} else {
125+
widestType = WideningTypeRule.max(widestType, type);
126+
}
127+
} catch (ExpressionEvaluationException e) {
128+
// the two types are not compatible, return null
129+
return null;
130+
}
131+
}
132+
return widestType;
133+
}
134+
135+
private static boolean areDateAndTime(ExprType type1, ExprType type2) {
136+
return (type1 == ExprCoreType.DATE && type2 == ExprCoreType.TIME)
137+
|| (type1 == ExprCoreType.TIME && type2 == ExprCoreType.DATE);
138+
}
139+
}

0 commit comments

Comments
 (0)