Skip to content

Commit c1d2d2d

Browse files
Merge pull request #12039 from ClickHouse/fix-nullable-tuple-compare
Fix nullable tuple compare
2 parents da26f64 + c97d071 commit c1d2d2d

3 files changed

Lines changed: 248 additions & 17 deletions

File tree

src/Functions/FunctionsComparison.h

Lines changed: 49 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <Columns/ColumnArray.h>
1313

1414
#include <DataTypes/DataTypesNumber.h>
15+
#include <DataTypes/DataTypeNullable.h>
1516
#include <DataTypes/DataTypeDateTime.h>
1617
#include <DataTypes/DataTypeDateTime64.h>
1718
#include <DataTypes/DataTypeDate.h>
@@ -931,16 +932,19 @@ class FunctionComparison : public IFunction
931932
if (0 == tuple_size)
932933
throw Exception("Comparison of zero-sized tuples is not implemented.", ErrorCodes::NOT_IMPLEMENTED);
933934

935+
ColumnsWithTypeAndName convolution_types(tuple_size);
936+
934937
Block tmp_block;
935938
for (size_t i = 0; i < tuple_size; ++i)
936939
{
937940
tmp_block.insert(x[i]);
938941
tmp_block.insert(y[i]);
939942

940943
auto impl = func_compare->build({x[i], y[i]});
944+
convolution_types[i].type = impl->getReturnType();
941945

942946
/// Comparison of the elements.
943-
tmp_block.insert({ nullptr, std::make_shared<DataTypeUInt8>(), "" });
947+
tmp_block.insert({ nullptr, impl->getReturnType(), "" });
944948
impl->execute(tmp_block, {i * 3, i * 3 + 1}, i * 3 + 2, input_rows_count);
945949
}
946950

@@ -952,14 +956,13 @@ class FunctionComparison : public IFunction
952956
}
953957

954958
/// Logical convolution.
955-
tmp_block.insert({ nullptr, std::make_shared<DataTypeUInt8>(), "" });
956959

957960
ColumnNumbers convolution_args(tuple_size);
958961
for (size_t i = 0; i < tuple_size; ++i)
959962
convolution_args[i] = i * 3 + 2;
960963

961-
ColumnsWithTypeAndName convolution_types(convolution_args.size(), { nullptr, std::make_shared<DataTypeUInt8>(), "" });
962964
auto impl = func_convolution->build(convolution_types);
965+
tmp_block.insert({ nullptr, impl->getReturnType(), "" });
963966

964967
impl->execute(tmp_block, convolution_args, tuple_size * 3, input_rows_count);
965968
block.getByPosition(result).column = tmp_block.getByPosition(tuple_size * 3).column;
@@ -978,49 +981,71 @@ class FunctionComparison : public IFunction
978981
size_t tuple_size,
979982
size_t input_rows_count)
980983
{
981-
ColumnsWithTypeAndName bin_args = {{ nullptr, std::make_shared<DataTypeUInt8>(), "" },
982-
{ nullptr, std::make_shared<DataTypeUInt8>(), "" }};
983-
984-
auto func_and_adaptor = func_and->build(bin_args);
985-
auto func_or_adaptor = func_or->build(bin_args);
986-
987984
Block tmp_block;
988985

989986
/// Pairwise comparison of the inequality of all elements; on the equality of all elements except the last.
987+
/// (x[i], y[i], x[i] < y[i], x[i] == y[i])
990988
for (size_t i = 0; i < tuple_size; ++i)
991989
{
992990
tmp_block.insert(x[i]);
993991
tmp_block.insert(y[i]);
994992

995-
tmp_block.insert({ nullptr, std::make_shared<DataTypeUInt8>(), "" });
993+
tmp_block.insert(ColumnWithTypeAndName()); // pos == i * 4 + 2
996994

997995
if (i + 1 != tuple_size)
998996
{
999997
auto impl_head = func_compare_head->build({x[i], y[i]});
998+
tmp_block.getByPosition(i * 4 + 2).type = impl_head->getReturnType();
1000999
impl_head->execute(tmp_block, {i * 4, i * 4 + 1}, i * 4 + 2, input_rows_count);
10011000

1002-
tmp_block.insert({ nullptr, std::make_shared<DataTypeUInt8>(), "" });
1001+
tmp_block.insert(ColumnWithTypeAndName()); // i * 4 + 3
10031002

10041003
auto impl_equals = func_equals->build({x[i], y[i]});
1004+
tmp_block.getByPosition(i * 4 + 3).type = impl_equals->getReturnType();
10051005
impl_equals->execute(tmp_block, {i * 4, i * 4 + 1}, i * 4 + 3, input_rows_count);
10061006

10071007
}
10081008
else
10091009
{
10101010
auto impl_tail = func_compare_tail->build({x[i], y[i]});
1011+
tmp_block.getByPosition(i * 4 + 2).type = impl_tail->getReturnType();
10111012
impl_tail->execute(tmp_block, {i * 4, i * 4 + 1}, i * 4 + 2, input_rows_count);
10121013
}
10131014
}
10141015

10151016
/// Combination. Complex code - make a drawing. It can be replaced by a recursive comparison of tuples.
1017+
/// Last column contains intermediate result.
1018+
/// Code is generally equivalent to:
1019+
/// res = `x < y`[tuple_size - 1];
1020+
/// for (int i = tuple_size - 2; i >= 0; --i)
1021+
/// res = (res && `x == y`[i]) || `x < y`[i];
10161022
size_t i = tuple_size - 1;
10171023
while (i > 0)
10181024
{
1019-
tmp_block.insert({ nullptr, std::make_shared<DataTypeUInt8>(), "" });
1020-
func_and_adaptor->execute(tmp_block, {tmp_block.columns() - 2, (i - 1) * 4 + 3}, tmp_block.columns() - 1, input_rows_count);
1021-
tmp_block.insert({ nullptr, std::make_shared<DataTypeUInt8>(), "" });
1022-
func_or_adaptor->execute(tmp_block, {tmp_block.columns() - 2, (i - 1) * 4 + 2}, tmp_block.columns() - 1, input_rows_count);
10231025
--i;
1026+
1027+
size_t and_lhs_pos = tmp_block.columns() - 1; // res
1028+
size_t and_rhs_pos = i * 4 + 3; // `x == y`[i]
1029+
tmp_block.insert(ColumnWithTypeAndName());
1030+
1031+
ColumnsWithTypeAndName and_args = {{ nullptr, tmp_block.getByPosition(and_lhs_pos).type, "" },
1032+
{ nullptr, tmp_block.getByPosition(and_rhs_pos).type, "" }};
1033+
1034+
auto func_and_adaptor = func_and->build(and_args);
1035+
tmp_block.getByPosition(tmp_block.columns() - 1).type = func_and_adaptor->getReturnType();
1036+
func_and_adaptor->execute(tmp_block, {and_lhs_pos, and_rhs_pos}, tmp_block.columns() - 1, input_rows_count);
1037+
1038+
size_t or_lhs_pos = tmp_block.columns() - 1; // (res && `x == y`[i])
1039+
size_t or_rhs_pos = i * 4 + 2; // `x < y`[i]
1040+
tmp_block.insert(ColumnWithTypeAndName());
1041+
1042+
ColumnsWithTypeAndName or_args = {{ nullptr, tmp_block.getByPosition(or_lhs_pos).type, "" },
1043+
{ nullptr, tmp_block.getByPosition(or_rhs_pos).type, "" }};
1044+
1045+
auto func_or_adaptor = func_or->build(or_args);
1046+
tmp_block.getByPosition(tmp_block.columns() - 1).type = func_or_adaptor->getReturnType();
1047+
func_or_adaptor->execute(tmp_block, {or_lhs_pos, or_rhs_pos}, tmp_block.columns() - 1, input_rows_count);
1048+
10241049
}
10251050

10261051
block.getByPosition(result).column = tmp_block.getByPosition(tmp_block.columns() - 1).column;
@@ -1109,13 +1134,20 @@ class FunctionComparison : public IFunction
11091134
auto adaptor = FunctionOverloadResolverAdaptor(std::make_unique<DefaultOverloadResolver>(
11101135
FunctionComparison<Op, Name>::create(context)));
11111136

1137+
bool has_nullable = false;
1138+
11121139
size_t size = left_tuple->getElements().size();
11131140
for (size_t i = 0; i < size; ++i)
11141141
{
11151142
ColumnsWithTypeAndName args = {{nullptr, left_tuple->getElements()[i], ""},
11161143
{nullptr, right_tuple->getElements()[i], ""}};
1117-
adaptor.build(args);
1144+
has_nullable = has_nullable || adaptor.build(args)->getReturnType()->isNullable();
11181145
}
1146+
1147+
/// If any element comparison is nullable, return type will also be nullable.
1148+
/// We useDefaultImplementationForNulls, but it doesn't work for tuples.
1149+
if (has_nullable)
1150+
return std::make_shared<DataTypeNullable>(std::make_shared<DataTypeUInt8>());
11191151
}
11201152

11211153
return std::make_shared<DataTypeUInt8>();
@@ -1135,7 +1167,7 @@ class FunctionComparison : public IFunction
11351167
/// NOTE: Nullable types are special case.
11361168
/// (BTW, this function use default implementation for Nullable, so Nullable types cannot be here. Check just in case.)
11371169
/// NOTE: We consider NaN comparison to be implementation specific (and in our implementation NaNs are sometimes equal sometimes not).
1138-
if (left_type->equals(*right_type) && !left_type->isNullable() && col_left_untyped == col_right_untyped)
1170+
if (left_type->equals(*right_type) && !left_type->isNullable() && !isTuple(left_type) && col_left_untyped == col_right_untyped)
11391171
{
11401172
/// Always true: =, <=, >=
11411173
if constexpr (std::is_same_v<Op<int, int>, EqualsOp<int, int>>
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
single argument
2+
1
3+
0
4+
1
5+
0
6+
1
7+
0
8+
- 1
9+
1
10+
1
11+
1
12+
0
13+
0
14+
0
15+
0
16+
0
17+
0
18+
1
19+
1
20+
1
21+
- 2
22+
1
23+
1
24+
1
25+
0
26+
0
27+
0
28+
0
29+
0
30+
1
31+
1
32+
1
33+
1
34+
- 3
35+
1
36+
1
37+
1
38+
1
39+
1
40+
1
41+
- 4
42+
\N
43+
\N
44+
\N
45+
\N
46+
\N
47+
\N
48+
two arguments
49+
1
50+
1
51+
1
52+
1
53+
1
54+
1
55+
- 1
56+
0
57+
0
58+
0
59+
0
60+
0
61+
0
62+
- 2
63+
1
64+
1
65+
1
66+
1
67+
1
68+
1
69+
- 3
70+
\N
71+
\N
72+
\N
73+
\N
74+
\N
75+
1
76+
\N
77+
\N
78+
0
79+
many arguments
80+
1
81+
1
82+
0
83+
0
84+
1
85+
0
86+
1
87+
\N
88+
\N
89+
\N
90+
\N
91+
\N
92+
\N
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
select 'single argument';
2+
select tuple(number) = tuple(number) from numbers(1);
3+
select tuple(number) = tuple(number + 1) from numbers(1);
4+
select tuple(toNullable(number)) = tuple(number) from numbers(1);
5+
select tuple(toNullable(number)) = tuple(number + 1) from numbers(1);
6+
select tuple(toNullable(number)) = tuple(toNullable(number)) from numbers(1);
7+
select tuple(toNullable(number)) = tuple(toNullable(number + 1)) from numbers(1);
8+
select '- 1';
9+
select tuple(toNullable(number)) < tuple(number + 1) from numbers(1);
10+
select tuple(number) < tuple(toNullable(number + 1)) from numbers(1);
11+
select tuple(toNullable(number)) < tuple(toNullable(number + 1)) from numbers(1);
12+
13+
select tuple(toNullable(number)) > tuple(number + 1) from numbers(1);
14+
select tuple(number) > tuple(toNullable(number + 1)) from numbers(1);
15+
select tuple(toNullable(number)) > tuple(toNullable(number + 1)) from numbers(1);
16+
17+
select tuple(toNullable(number + 1)) < tuple(number) from numbers(1);
18+
select tuple(number + 1) < tuple(toNullable(number)) from numbers(1);
19+
select tuple(toNullable(number + 1)) < tuple(toNullable(number + 1)) from numbers(1);
20+
21+
select tuple(toNullable(number + 1)) > tuple(number) from numbers(1);
22+
select tuple(number + 1) > tuple(toNullable(number)) from numbers(1);
23+
select tuple(toNullable(number + 1)) > tuple(toNullable(number)) from numbers(1);
24+
25+
select '- 2';
26+
select tuple(toNullable(number)) <= tuple(number + 1) from numbers(1);
27+
select tuple(number) <= tuple(toNullable(number + 1)) from numbers(1);
28+
select tuple(toNullable(number)) <= tuple(toNullable(number + 1)) from numbers(1);
29+
30+
select tuple(toNullable(number)) >= tuple(number + 1) from numbers(1);
31+
select tuple(number) > tuple(toNullable(number + 1)) from numbers(1);
32+
select tuple(toNullable(number)) >= tuple(toNullable(number + 1)) from numbers(1);
33+
34+
select tuple(toNullable(number + 1)) <= tuple(number) from numbers(1);
35+
select tuple(number + 1) <= tuple(toNullable(number)) from numbers(1);
36+
select tuple(toNullable(number + 1)) <= tuple(toNullable(number + 1)) from numbers(1);
37+
38+
select tuple(toNullable(number + 1)) >= tuple(number) from numbers(1);
39+
select tuple(number + 1) >= tuple(toNullable(number)) from numbers(1);
40+
select tuple(toNullable(number + 1)) >= tuple(toNullable(number)) from numbers(1);
41+
42+
select '- 3';
43+
select tuple(toNullable(number)) <= tuple(number) from numbers(1);
44+
select tuple(number) <= tuple(toNullable(number)) from numbers(1);
45+
select tuple(toNullable(number)) <= tuple(toNullable(number)) from numbers(1);
46+
47+
select tuple(toNullable(number)) >= tuple(number) from numbers(1);
48+
select tuple(number) >= tuple(toNullable(number)) from numbers(1);
49+
select tuple(toNullable(number)) >= tuple(toNullable(number)) from numbers(1);
50+
51+
select '- 4';
52+
select tuple(number) = tuple(materialize(toUInt64OrNull(''))) from numbers(1);
53+
select tuple(materialize(toUInt64OrNull(''))) = tuple(materialize(toUInt64OrNull(''))) from numbers(1);
54+
select tuple(number) <= tuple(materialize(toUInt64OrNull(''))) from numbers(1);
55+
select tuple(materialize(toUInt64OrNull(''))) <= tuple(materialize(toUInt64OrNull(''))) from numbers(1);
56+
select tuple(number) >= tuple(materialize(toUInt64OrNull(''))) from numbers(1);
57+
select tuple(materialize(toUInt64OrNull(''))) >= tuple(materialize(toUInt64OrNull(''))) from numbers(1);
58+
59+
select 'two arguments';
60+
select tuple(toNullable(number), number) = tuple(number, number) from numbers(1);
61+
select tuple(toNullable(number), toNullable(number)) = tuple(number, number) from numbers(1);
62+
select tuple(toNullable(number), toNullable(number)) = tuple(toNullable(number), number) from numbers(1);
63+
select tuple(toNullable(number), toNullable(number)) = tuple(toNullable(number), toNullable(number)) from numbers(1);
64+
select tuple(number, toNullable(number)) = tuple(toNullable(number), toNullable(number)) from numbers(1);
65+
select tuple(number, toNullable(number)) = tuple(toNullable(number), number) from numbers(1);
66+
67+
select '- 1';
68+
select tuple(toNullable(number), number) < tuple(number, number) from numbers(1);
69+
select tuple(toNullable(number), toNullable(number)) < tuple(number, number) from numbers(1);
70+
select tuple(toNullable(number), toNullable(number)) < tuple(toNullable(number), number) from numbers(1);
71+
select tuple(toNullable(number), toNullable(number)) < tuple(toNullable(number), toNullable(number)) from numbers(1);
72+
select tuple(number, toNullable(number)) < tuple(toNullable(number), toNullable(number)) from numbers(1);
73+
select tuple(number, toNullable(number)) < tuple(toNullable(number), number) from numbers(1);
74+
75+
select '- 2';
76+
select tuple(toNullable(number), number) < tuple(number, number + 1) from numbers(1);
77+
select tuple(toNullable(number), toNullable(number)) < tuple(number, number + 1) from numbers(1);
78+
select tuple(toNullable(number), toNullable(number)) < tuple(toNullable(number + 1), number) from numbers(1);
79+
select tuple(toNullable(number), toNullable(number)) < tuple(toNullable(number + 1), toNullable(number)) from numbers(1);
80+
select tuple(number, toNullable(number)) < tuple(toNullable(number), toNullable(number + 1)) from numbers(1);
81+
select tuple(number, toNullable(number)) < tuple(toNullable(number), number + 1) from numbers(1);
82+
83+
select '- 3';
84+
select tuple(materialize(toUInt64OrNull('')), number) = tuple(number, number) from numbers(1);
85+
select tuple(materialize(toUInt64OrNull('')), number) = tuple(number, toUInt64OrNull('')) from numbers(1);
86+
select tuple(materialize(toUInt64OrNull('')), toUInt64OrNull('')) = tuple(toUInt64OrNull(''), toUInt64OrNull('')) from numbers(1);
87+
select tuple(number, materialize(toUInt64OrNull(''))) < tuple(number, number) from numbers(1);
88+
select tuple(number, materialize(toUInt64OrNull(''))) <= tuple(number, number) from numbers(1);
89+
select tuple(number, materialize(toUInt64OrNull(''))) < tuple(number + 1, number) from numbers(1);
90+
select tuple(number, materialize(toUInt64OrNull(''))) > tuple(number, number) from numbers(1);
91+
select tuple(number, materialize(toUInt64OrNull(''))) >= tuple(number, number) from numbers(1);
92+
select tuple(number, materialize(toUInt64OrNull(''))) > tuple(number + 1, number) from numbers(1);
93+
94+
select 'many arguments';
95+
select tuple(toNullable(number), number, number) = tuple(number, number, number) from numbers(1);
96+
select tuple(toNullable(number), materialize('a'), number) = tuple(number, materialize('a'), number) from numbers(1);
97+
select tuple(toNullable(number), materialize('a'), number) = tuple(number, materialize('a'), number + 1) from numbers(1);
98+
select tuple(toNullable(number), number, number) < tuple(number, number, number) from numbers(1);
99+
select tuple(toNullable(number), number, number) <= tuple(number, number, number) from numbers(1);
100+
select tuple(toNullable(number), materialize('a'), number) < tuple(number, materialize('a'), number) from numbers(1);
101+
select tuple(toNullable(number), materialize('a'), number) < tuple(number, materialize('a'), number + 1) from numbers(1);
102+
select tuple(toNullable(number), number, materialize(toUInt64OrNull(''))) = tuple(number, number, number) from numbers(1);
103+
select tuple(toNullable(number), materialize('a'), materialize(toUInt64OrNull(''))) = tuple(number, materialize('a'), number) from numbers(1);
104+
select tuple(toNullable(number), materialize('a'), materialize(toUInt64OrNull(''))) = tuple(number, materialize('a'), number + 1) from numbers(1);
105+
select tuple(toNullable(number), number, materialize(toUInt64OrNull(''))) <= tuple(number, number, number) from numbers(1);
106+
select tuple(toNullable(number), materialize('a'), materialize(toUInt64OrNull(''))) <= tuple(number, materialize('a'), number) from numbers(1);
107+
select tuple(toNullable(number), materialize('a'), materialize(toUInt64OrNull(''))) <= tuple(number, materialize('a'), number + 1) from numbers(1);

0 commit comments

Comments
 (0)