Skip to content

Commit 47e6a2a

Browse files
authored
Support CASE function with Calcite (#3558)
* Support CASE function with Calcite Signed-off-by: Lantao Jin <ltjin@amazon.com> * add anonymizer Signed-off-by: Lantao Jin <ltjin@amazon.com> * address comment and fix varchar literal bug Signed-off-by: Lantao Jin <ltjin@amazon.com> * add doc Signed-off-by: Lantao Jin <ltjin@amazon.com> * fix doctest Signed-off-by: Lantao Jin <ltjin@amazon.com> * fix doctest Signed-off-by: Lantao Jin <ltjin@amazon.com> --------- Signed-off-by: Lantao Jin <ltjin@amazon.com>
1 parent 69b9e82 commit 47e6a2a

19 files changed

Lines changed: 530 additions & 23 deletions

File tree

core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ public Expression visitCase(Case node, AnalysisContext context) {
340340
}
341341

342342
Expression defaultResult =
343-
(node.getElseClause() == null) ? null : analyze(node.getElseClause(), context);
343+
node.getElseClause().map(elseClause -> analyze(elseClause, context)).orElse(null);
344344
CaseClause caseClause = new CaseClause(whens, defaultResult);
345345

346346
// To make this simple, require all result type same regardless of implicit convert

core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ public UnresolvedExpression caseWhen(UnresolvedExpression elseClause, When... wh
293293
*/
294294
public UnresolvedExpression caseWhen(
295295
UnresolvedExpression caseValueExpr, UnresolvedExpression elseClause, When... whenClauses) {
296-
return new Case(caseValueExpr, Arrays.asList(whenClauses), elseClause);
296+
return new Case(caseValueExpr, Arrays.asList(whenClauses), Optional.ofNullable(elseClause));
297297
}
298298

299299
public UnresolvedExpression cast(UnresolvedExpression expr, Literal type) {

core/src/main/java/org/opensearch/sql/ast/expression/Case.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import com.google.common.collect.ImmutableList;
99
import java.util.List;
10+
import java.util.Optional;
1011
import lombok.AllArgsConstructor;
1112
import lombok.EqualsAndHashCode;
1213
import lombok.Getter;
@@ -31,7 +32,7 @@ public class Case extends UnresolvedExpression {
3132
private final List<When> whenClauses;
3233

3334
/** Expression that represents ELSE statement result. */
34-
private final UnresolvedExpression elseClause;
35+
private final Optional<UnresolvedExpression> elseClause;
3536

3637
@Override
3738
public List<? extends Node> getChild() {
@@ -40,10 +41,7 @@ public List<? extends Node> getChild() {
4041
children.add(caseValue);
4142
}
4243
children.addAll(whenClauses);
43-
44-
if (elseClause != null) {
45-
children.add(elseClause);
46-
}
44+
elseClause.ifPresent(children::add);
4745
return children.build();
4846
}
4947

core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.TransferUserDefinedFunction;
1212

1313
import java.math.BigDecimal;
14+
import java.util.ArrayList;
1415
import java.util.List;
1516
import java.util.stream.Collectors;
1617
import lombok.RequiredArgsConstructor;
@@ -31,6 +32,7 @@
3132
import org.opensearch.sql.ast.expression.Alias;
3233
import org.opensearch.sql.ast.expression.And;
3334
import org.opensearch.sql.ast.expression.Between;
35+
import org.opensearch.sql.ast.expression.Case;
3436
import org.opensearch.sql.ast.expression.Cast;
3537
import org.opensearch.sql.ast.expression.Compare;
3638
import org.opensearch.sql.ast.expression.EqualTo;
@@ -89,7 +91,18 @@ public RexNode visitLiteral(Literal node, CalcitePlanContext context) {
8991
case NULL:
9092
return rexBuilder.makeNullLiteral(typeFactory.createSqlType(SqlTypeName.NULL));
9193
case STRING:
92-
return rexBuilder.makeLiteral(value.toString());
94+
if (value.toString().length() == 1) {
95+
// To align Spark/PostgreSQL, Char(1) is useful, such as cast('1' to boolean) should
96+
// return true
97+
return rexBuilder.makeLiteral(
98+
value.toString(), typeFactory.createSqlType(SqlTypeName.CHAR));
99+
} else {
100+
// Specific the type to VARCHAR and allowCast to true, or the STRING will be optimized to
101+
// CHAR(n)
102+
// which leads to incorrect return type in deriveReturnType of some functions/operators
103+
return rexBuilder.makeLiteral(
104+
value.toString(), typeFactory.createSqlType(SqlTypeName.VARCHAR), true);
105+
}
93106
case INTEGER:
94107
return rexBuilder.makeExactLiteral(new BigDecimal((Integer) value));
95108
case LONG:
@@ -431,6 +444,19 @@ public RexNode visitCast(Cast node, CalcitePlanContext context) {
431444
return context.rexBuilder.makeCast(nullableType, expr, true, true);
432445
}
433446

447+
@Override
448+
public RexNode visitCase(Case node, CalcitePlanContext context) {
449+
List<RexNode> caseOperands = new ArrayList<>();
450+
for (When when : node.getWhenClauses()) {
451+
caseOperands.add(analyze(when.getCondition(), context));
452+
caseOperands.add(analyze(when.getResult(), context));
453+
}
454+
RexNode elseExpr =
455+
node.getElseClause().map(e -> analyze(e, context)).orElse(context.relBuilder.literal(null));
456+
caseOperands.add(elseExpr);
457+
return context.rexBuilder.makeCall(SqlStdOperatorTable.CASE, caseOperands);
458+
}
459+
434460
/*
435461
* Unsupported Expressions of PPL with Calcite for OpenSearch 3.0.0-beta
436462
*/

docs/user/ppl/functions/condition.rst

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,3 +192,49 @@ Example::
192192
| False | Nanette | Bates |
193193
| False | Dale | Adams |
194194
+--------+-----------+----------+
195+
196+
CASE
197+
------
198+
199+
Description
200+
>>>>>>>>>>>
201+
202+
Usage: case(condition1, expr1, condition2, expr2, ... conditionN, exprN else default) return expr1 if condition1 is true, or return expr2 if condition2 is true, ... if no condition is true, then return the value of ELSE clause. If the ELSE clause is not defined, it returns NULL.
203+
204+
Argument type: all the supported data type, (NOTE : there is no comma before "else")
205+
206+
Return type: any
207+
208+
Example::
209+
210+
os> source=accounts | eval result = case(age > 35, firstname, age < 30, lastname else employer) | fields result, firstname, lastname, age, employer
211+
fetched rows / total rows = 4/4
212+
+--------+-----------+----------+-----+----------+
213+
| result | firstname | lastname | age | employer |
214+
|--------+-----------+----------+-----+----------|
215+
| Pyrami | Amber | Duke | 32 | Pyrami |
216+
| Hattie | Hattie | Bond | 36 | Netagy |
217+
| Bates | Nanette | Bates | 28 | Quility |
218+
| null | Dale | Adams | 33 | null |
219+
+--------+-----------+----------+-----+----------+
220+
221+
os> source=accounts | eval result = case(age > 35, firstname, age < 30, lastname) | fields result, firstname, lastname, age
222+
fetched rows / total rows = 4/4
223+
+--------+-----------+----------+-----+
224+
| result | firstname | lastname | age |
225+
|--------+-----------+----------+-----|
226+
| null | Amber | Duke | 32 |
227+
| Hattie | Hattie | Bond | 36 |
228+
| Bates | Nanette | Bates | 28 |
229+
| null | Dale | Adams | 33 |
230+
+--------+-----------+----------+-----+
231+
232+
os> source=accounts | where true = case(age > 35, false, age < 30, false else true) | fields firstname, lastname, age
233+
fetched rows / total rows = 2/2
234+
+-----------+----------+-----+
235+
| firstname | lastname | age |
236+
|-----------+----------+-----|
237+
| Amber | Duke | 32 |
238+
| Dale | Adams | 33 |
239+
+-----------+----------+-----+
240+
Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.calcite.standalone;
7+
8+
import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_WEBLOGS;
9+
import static org.opensearch.sql.util.MatcherUtils.rows;
10+
import static org.opensearch.sql.util.MatcherUtils.schema;
11+
import static org.opensearch.sql.util.MatcherUtils.verifyDataRows;
12+
import static org.opensearch.sql.util.MatcherUtils.verifySchema;
13+
14+
import java.io.IOException;
15+
import org.json.JSONObject;
16+
import org.junit.jupiter.api.Test;
17+
import org.opensearch.client.Request;
18+
import org.opensearch.sql.legacy.TestsConstants;
19+
20+
public class CalcitePPLCaseFunctionIT extends CalcitePPLIntegTestCase {
21+
@Override
22+
public void init() throws IOException {
23+
super.init();
24+
loadIndex(Index.WEBLOG);
25+
appendDataForBadResponse();
26+
}
27+
28+
private void appendDataForBadResponse() throws IOException {
29+
Request request1 = new Request("PUT", "/" + TEST_INDEX_WEBLOGS + "/_doc/7?refresh=true");
30+
request1.setJsonEntity(
31+
"{\"host\": \"::1\", \"method\": \"GET\", \"url\": \"/history/apollo/\", \"response\":"
32+
+ " \"301\", \"bytes\": \"6245\"}");
33+
client().performRequest(request1);
34+
Request request2 =
35+
new Request("PUT", "/" + TestsConstants.TEST_INDEX_WEBLOGS + "/_doc/8?refresh=true");
36+
request2.setJsonEntity(
37+
"{\"host\": \"0.0.0.2\", \"method\": \"GET\", \"url\":"
38+
+ " \"/shuttle/missions/sts-73/mission-sts-73.html\", \"response\": \"500\", \"bytes\":"
39+
+ " \"4085\"}");
40+
client().performRequest(request2);
41+
Request request3 =
42+
new Request("PUT", "/" + TestsConstants.TEST_INDEX_WEBLOGS + "/_doc/9?refresh=true");
43+
request3.setJsonEntity(
44+
"{\"host\": \"::3\", \"method\": \"GET\", \"url\": \"/shuttle/countdown/countdown.html\","
45+
+ " \"response\": \"403\", \"bytes\": \"3985\"}");
46+
client().performRequest(request3);
47+
Request request4 =
48+
new Request("PUT", "/" + TestsConstants.TEST_INDEX_WEBLOGS + "/_doc/10?refresh=true");
49+
request4.setJsonEntity(
50+
"{\"host\": \"1.2.3.5\", \"method\": \"GET\", \"url\": \"/history/voyager2/\","
51+
+ " \"response\": null, \"bytes\": \"4321\"}");
52+
client().performRequest(request4);
53+
}
54+
55+
@Test
56+
public void testCaseWhenWithCast() {
57+
JSONObject actual =
58+
executeQuery(
59+
String.format(
60+
"""
61+
source=%s
62+
| eval status =
63+
case(
64+
cast(response as int) >= 200 AND cast(response as int) < 300, "Success",
65+
cast(response as int) >= 300 AND cast(response as int) < 400, "Redirection",
66+
cast(response as int) >= 400 AND cast(response as int) < 500, "Client Error",
67+
cast(response as int) >= 500 AND cast(response as int) < 600, "Server Error"
68+
else concat("Incorrect HTTP status code for", url))
69+
| where status != "Success"
70+
""",
71+
TEST_INDEX_WEBLOGS));
72+
verifySchema(
73+
actual,
74+
schema("host", "ip"),
75+
schema("method", "string"),
76+
schema("url", "string"),
77+
schema("response", "string"),
78+
schema("bytes", "string"),
79+
schema("status", "string"));
80+
verifyDataRows(
81+
actual,
82+
rows("::1", "GET", "6245", "301", "/history/apollo/", "Redirection"),
83+
rows(
84+
"0.0.0.2",
85+
"GET",
86+
"4085",
87+
"500",
88+
"/shuttle/missions/sts-73/mission-sts-73.html",
89+
"Server Error"),
90+
rows("::3", "GET", "3985", "403", "/shuttle/countdown/countdown.html", "Client Error"),
91+
rows(
92+
"1.2.3.5",
93+
"GET",
94+
"4321",
95+
null,
96+
"/history/voyager2/",
97+
"Incorrect HTTP status code for/history/voyager2/"));
98+
}
99+
100+
@Test
101+
public void testCaseWhenNoElse() {
102+
JSONObject actual =
103+
executeQuery(
104+
String.format(
105+
"""
106+
source=%s
107+
| eval status =
108+
case(
109+
cast(response as int) >= 200 AND cast(response as int) < 300, "Success",
110+
cast(response as int) >= 300 AND cast(response as int) < 400, "Redirection",
111+
cast(response as int) >= 400 AND cast(response as int) < 500, "Client Error",
112+
cast(response as int) >= 500 AND cast(response as int) < 600, "Server Error")
113+
| where isnull(status) OR status != "Success"
114+
""",
115+
TEST_INDEX_WEBLOGS));
116+
verifySchema(
117+
actual,
118+
schema("host", "ip"),
119+
schema("method", "string"),
120+
schema("url", "string"),
121+
schema("response", "string"),
122+
schema("bytes", "string"),
123+
schema("status", "string"));
124+
verifyDataRows(
125+
actual,
126+
rows("::1", "GET", "6245", "301", "/history/apollo/", "Redirection"),
127+
rows(
128+
"0.0.0.2",
129+
"GET",
130+
"4085",
131+
"500",
132+
"/shuttle/missions/sts-73/mission-sts-73.html",
133+
"Server Error"),
134+
rows("::3", "GET", "3985", "403", "/shuttle/countdown/countdown.html", "Client Error"),
135+
rows("1.2.3.5", "GET", "4321", null, "/history/voyager2/", null));
136+
}
137+
138+
@Test
139+
public void testCaseWhenWithIn() {
140+
JSONObject actual =
141+
executeQuery(
142+
String.format(
143+
"""
144+
source=%s
145+
| eval status =
146+
case(
147+
response in ('200'), "Success",
148+
response in ('300', '301'), "Redirection",
149+
response in ('400', '403'), "Client Error",
150+
response in ('500', '505'), "Server Error"
151+
else concat("Incorrect HTTP status code for", url))
152+
| where status != "Success"
153+
""",
154+
TEST_INDEX_WEBLOGS));
155+
verifySchema(
156+
actual,
157+
schema("host", "ip"),
158+
schema("method", "string"),
159+
schema("url", "string"),
160+
schema("response", "string"),
161+
schema("bytes", "string"),
162+
schema("status", "string"));
163+
verifyDataRows(
164+
actual,
165+
rows("::1", "GET", "6245", "301", "/history/apollo/", "Redirection"),
166+
rows(
167+
"0.0.0.2",
168+
"GET",
169+
"4085",
170+
"500",
171+
"/shuttle/missions/sts-73/mission-sts-73.html",
172+
"Server Error"),
173+
rows("::3", "GET", "3985", "403", "/shuttle/countdown/countdown.html", "Client Error"),
174+
rows(
175+
"1.2.3.5",
176+
"GET",
177+
"4321",
178+
null,
179+
"/history/voyager2/",
180+
"Incorrect HTTP status code for/history/voyager2/"));
181+
}
182+
183+
@Test
184+
public void testCaseWhenInFilter() {
185+
JSONObject actual =
186+
executeQuery(
187+
String.format(
188+
"""
189+
source=%s
190+
| where not true =
191+
case(
192+
response in ('200'), true,
193+
response in ('300', '301'), false,
194+
response in ('400', '403'), false,
195+
response in ('500', '505'), false
196+
else false)
197+
""",
198+
TEST_INDEX_WEBLOGS));
199+
verifySchema(
200+
actual,
201+
schema("host", "ip"),
202+
schema("method", "string"),
203+
schema("url", "string"),
204+
schema("response", "string"),
205+
schema("bytes", "string"));
206+
verifyDataRows(
207+
actual,
208+
rows("::1", "GET", "6245", "301", "/history/apollo/"),
209+
rows("0.0.0.2", "GET", "4085", "500", "/shuttle/missions/sts-73/mission-sts-73.html"),
210+
rows("::3", "GET", "3985", "403", "/shuttle/countdown/countdown.html"),
211+
rows("1.2.3.5", "GET", "4321", null, "/history/voyager2/"));
212+
}
213+
214+
@Test
215+
public void testCaseWhenInSubquery() {
216+
JSONObject actual =
217+
executeQuery(
218+
String.format(
219+
"""
220+
source=%s
221+
| where response in [
222+
source = %s
223+
| eval new_response = case(
224+
response in ('200'), "201",
225+
response in ('300', '301'), "301",
226+
response in ('400', '403'), "403",
227+
response in ('500', '505'), "500"
228+
else concat("Incorrect HTTP status code for", url))
229+
| fields new_response
230+
]
231+
""",
232+
TEST_INDEX_WEBLOGS, TEST_INDEX_WEBLOGS));
233+
verifySchema(
234+
actual,
235+
schema("host", "ip"),
236+
schema("method", "string"),
237+
schema("url", "string"),
238+
schema("response", "string"),
239+
schema("bytes", "string"));
240+
verifyDataRows(
241+
actual,
242+
rows("::1", "GET", "6245", "301", "/history/apollo/"),
243+
rows("0.0.0.2", "GET", "4085", "500", "/shuttle/missions/sts-73/mission-sts-73.html"),
244+
rows("::3", "GET", "3985", "403", "/shuttle/countdown/countdown.html"));
245+
}
246+
}

0 commit comments

Comments
 (0)