Skip to content

Commit f6f8dc9

Browse files
Merge pull request #13964 from zhang2014/fix/agg_combinator
Try to fix IfAggCombinator with NullAggCombinator
2 parents 57af010 + 75af61e commit f6f8dc9

9 files changed

Lines changed: 181 additions & 5 deletions

src/AggregateFunctions/AggregateFunctionCount.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ namespace DB
88
{
99

1010
AggregateFunctionPtr AggregateFunctionCount::getOwnNullAdapter(
11-
const AggregateFunctionPtr &, const DataTypes & types, const Array & params) const
11+
const AggregateFunctionPtr &, const DataTypes & types, const Array & params, const AggregateFunctionProperties & /*properties*/) const
1212
{
1313
return std::make_shared<AggregateFunctionCountNotNullUnary>(types[0], params);
1414
}

src/AggregateFunctions/AggregateFunctionCount.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ class AggregateFunctionCount final : public IAggregateFunctionDataHelper<Aggrega
6969
}
7070

7171
AggregateFunctionPtr getOwnNullAdapter(
72-
const AggregateFunctionPtr &, const DataTypes & types, const Array & params) const override;
72+
const AggregateFunctionPtr &, const DataTypes & types, const Array & params, const AggregateFunctionProperties & /*properties*/) const override;
7373
};
7474

7575

src/AggregateFunctions/AggregateFunctionIf.cpp

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
#include <AggregateFunctions/AggregateFunctionIf.h>
22
#include <AggregateFunctions/AggregateFunctionCombinatorFactory.h>
33
#include "registerAggregateFunctions.h"
4+
#include "AggregateFunctionNull.h"
45

56

67
namespace DB
78
{
89

910
namespace ErrorCodes
1011
{
12+
extern const int LOGICAL_ERROR;
1113
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
1214
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
1315
}
@@ -40,6 +42,164 @@ class AggregateFunctionCombinatorIf final : public IAggregateFunctionCombinator
4042
}
4143
};
4244

45+
/** There are two cases: for single argument and variadic.
46+
* Code for single argument is much more efficient.
47+
*/
48+
template <bool result_is_nullable, bool serialize_flag>
49+
class AggregateFunctionIfNullUnary final
50+
: public AggregateFunctionNullBase<result_is_nullable, serialize_flag,
51+
AggregateFunctionIfNullUnary<result_is_nullable, serialize_flag>>
52+
{
53+
private:
54+
size_t num_arguments;
55+
56+
using Base = AggregateFunctionNullBase<result_is_nullable, serialize_flag,
57+
AggregateFunctionIfNullUnary<result_is_nullable, serialize_flag>>;
58+
public:
59+
60+
String getName() const override
61+
{
62+
return Base::getName() + "If";
63+
}
64+
65+
AggregateFunctionIfNullUnary(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
66+
: Base(std::move(nested_function_), arguments, params), num_arguments(arguments.size())
67+
{
68+
if (num_arguments == 0)
69+
throw Exception("Aggregate function " + getName() + " require at least one argument",
70+
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
71+
}
72+
73+
static inline bool singleFilter(const IColumn ** columns, size_t row_num, size_t num_arguments)
74+
{
75+
const IColumn * filter_column = columns[num_arguments - 1];
76+
if (const ColumnNullable * nullable_column = typeid_cast<const ColumnNullable *>(filter_column))
77+
filter_column = nullable_column->getNestedColumnPtr().get();
78+
79+
return assert_cast<const ColumnUInt8 &>(*filter_column).getData()[row_num];
80+
}
81+
82+
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override
83+
{
84+
const ColumnNullable * column = assert_cast<const ColumnNullable *>(columns[0]);
85+
const IColumn * nested_column = &column->getNestedColumn();
86+
if (!column->isNullAt(row_num) && singleFilter(columns, row_num, num_arguments))
87+
{
88+
this->setFlag(place);
89+
this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena);
90+
}
91+
}
92+
};
93+
94+
template <bool result_is_nullable, bool serialize_flag, bool null_is_skipped>
95+
class AggregateFunctionIfNullVariadic final
96+
: public AggregateFunctionNullBase<result_is_nullable, serialize_flag,
97+
AggregateFunctionIfNullVariadic<result_is_nullable, serialize_flag, null_is_skipped>>
98+
{
99+
public:
100+
101+
String getName() const override
102+
{
103+
return Base::getName() + "If";
104+
}
105+
106+
AggregateFunctionIfNullVariadic(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
107+
: Base(std::move(nested_function_), arguments, params), number_of_arguments(arguments.size())
108+
{
109+
if (number_of_arguments == 1)
110+
throw Exception("Logical error: single argument is passed to AggregateFunctionIfNullVariadic", ErrorCodes::LOGICAL_ERROR);
111+
112+
if (number_of_arguments > MAX_ARGS)
113+
throw Exception("Maximum number of arguments for aggregate function with Nullable types is " + toString(size_t(MAX_ARGS)),
114+
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
115+
116+
for (size_t i = 0; i < number_of_arguments; ++i)
117+
is_nullable[i] = arguments[i]->isNullable();
118+
}
119+
120+
static inline bool singleFilter(const IColumn ** columns, size_t row_num, size_t num_arguments)
121+
{
122+
return assert_cast<const ColumnUInt8 &>(*columns[num_arguments - 1]).getData()[row_num];
123+
}
124+
125+
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override
126+
{
127+
/// This container stores the columns we really pass to the nested function.
128+
const IColumn * nested_columns[number_of_arguments];
129+
130+
for (size_t i = 0; i < number_of_arguments; ++i)
131+
{
132+
if (is_nullable[i])
133+
{
134+
const ColumnNullable & nullable_col = assert_cast<const ColumnNullable &>(*columns[i]);
135+
if (null_is_skipped && nullable_col.isNullAt(row_num))
136+
{
137+
/// If at least one column has a null value in the current row,
138+
/// we don't process this row.
139+
return;
140+
}
141+
nested_columns[i] = &nullable_col.getNestedColumn();
142+
}
143+
else
144+
nested_columns[i] = columns[i];
145+
}
146+
147+
if (singleFilter(nested_columns, row_num, number_of_arguments))
148+
{
149+
this->setFlag(place);
150+
this->nested_function->add(this->nestedPlace(place), nested_columns, row_num, arena);
151+
}
152+
}
153+
154+
private:
155+
using Base = AggregateFunctionNullBase<result_is_nullable, serialize_flag,
156+
AggregateFunctionIfNullVariadic<result_is_nullable, serialize_flag, null_is_skipped>>;
157+
158+
enum { MAX_ARGS = 8 };
159+
size_t number_of_arguments = 0;
160+
std::array<char, MAX_ARGS> is_nullable; /// Plain array is better than std::vector due to one indirection less.
161+
};
162+
163+
164+
/// Builds the Nullable-aware wrapper used for this -If function instead of the generic one.
///
/// The result type is Nullable unless `properties.returns_default_when_only_null` is set,
/// the return type cannot be inside Nullable, or the condition is the only Nullable argument.
/// AggregateFunctionIfNullUnary is chosen when there are at most two arguments and the first
/// one is Nullable; AggregateFunctionIfNullVariadic is chosen otherwise.
AggregateFunctionPtr AggregateFunctionIf::getOwnNullAdapter(
    const AggregateFunctionPtr & nested_function, const DataTypes & arguments,
    const Array & params, const AggregateFunctionProperties & properties) const
{
    const size_t nullable_count = std::count_if(
        arguments.begin(), arguments.end(), [](const auto & type) { return type->isNullable(); });

    /// If only the condition is Nullable, the result must keep a non-nullable type.
    const bool only_condition_is_nullable = nullable_count == 1 && arguments.back()->isNullable();

    const bool return_type_is_nullable = !properties.returns_default_when_only_null
        && getReturnType()->canBeInsideNullable()
        && !only_condition_is_nullable;

    const bool serialize_flag = return_type_is_nullable || properties.returns_default_when_only_null;

    const bool single_nullable_value = arguments.size() <= 2 && arguments.front()->isNullable();

    if (single_nullable_value)
    {
        /// The unary wrapper applies the condition itself, so it is given the member `nested_func`
        /// rather than the `nested_function` argument.
        if (return_type_is_nullable)
            return std::make_shared<AggregateFunctionIfNullUnary<true, true>>(nested_func, arguments, params);

        if (serialize_flag)
            return std::make_shared<AggregateFunctionIfNullUnary<false, true>>(nested_func, arguments, params);

        return std::make_shared<AggregateFunctionIfNullUnary<false, false>>(nested_func, arguments, params);
    }

    if (return_type_is_nullable)
        return std::make_shared<AggregateFunctionIfNullVariadic<true, true, true>>(nested_function, arguments, params);

    if (serialize_flag)
        return std::make_shared<AggregateFunctionIfNullVariadic<false, true, true>>(nested_function, arguments, params);

    /// NOTE(review): serialize_flag is false on this path, yet the wrapper is instantiated with
    /// serialize_flag = true and null_is_skipped = false, unlike the unary <false, false> branch.
    /// Kept as is because changing it would alter the serialized state and NULL handling -
    /// confirm whether <false, false, true> was intended.
    return std::make_shared<AggregateFunctionIfNullVariadic<false, true, false>>(nested_function, arguments, params);
}
202+
43203
void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory & factory)
44204
{
45205
factory.registerCombinator(std::make_shared<AggregateFunctionCombinatorIf>());

src/AggregateFunctions/AggregateFunctionIf.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,10 @@ class AggregateFunctionIf final : public IAggregateFunctionHelper<AggregateFunct
109109
{
110110
return nested_func->isState();
111111
}
112+
113+
AggregateFunctionPtr getOwnNullAdapter(
114+
const AggregateFunctionPtr & nested_function, const DataTypes & arguments,
115+
const Array & params, const AggregateFunctionProperties & properties) const override;
112116
};
113117

114118
}

src/AggregateFunctions/AggregateFunctionNull.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ class AggregateFunctionCombinatorNull final : public IAggregateFunctionCombinato
6969

7070
assert(nested_function);
7171

72-
if (auto adapter = nested_function->getOwnNullAdapter(nested_function, arguments, params))
72+
if (auto adapter = nested_function->getOwnNullAdapter(nested_function, arguments, params, properties))
7373
return adapter;
7474

7575
/// If applied to aggregate function with -State combinator, we apply -Null combinator to its nested_function instead of itself.

src/AggregateFunctions/AggregateFunctionWindowFunnel.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,8 @@ class AggregateFunctionWindowFunnel final
241241
}
242242

243243
AggregateFunctionPtr getOwnNullAdapter(
244-
const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params) const override
244+
const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params,
245+
const AggregateFunctionProperties & /*properties*/) const override
245246
{
246247
return std::make_shared<AggregateFunctionNullVariadic<false, false, false>>(nested_function, arguments, params);
247248
}

src/AggregateFunctions/IAggregateFunction.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ using ConstAggregateDataPtr = const char *;
3333

3434
class IAggregateFunction;
3535
using AggregateFunctionPtr = std::shared_ptr<IAggregateFunction>;
36+
struct AggregateFunctionProperties;
3637

3738
/** Aggregate functions interface.
3839
* Instances of classes with this interface do not contain the data itself for aggregation,
@@ -186,7 +187,8 @@ class IAggregateFunction
186187
* arguments and params are for nested_function.
187188
*/
188189
virtual AggregateFunctionPtr getOwnNullAdapter(
189-
const AggregateFunctionPtr & /*nested_function*/, const DataTypes & /*arguments*/, const Array & /*params*/) const
190+
const AggregateFunctionPtr & /*nested_function*/, const DataTypes & /*arguments*/,
191+
const Array & /*params*/, const AggregateFunctionProperties & /*properties*/) const
190192
{
191193
return nullptr;
192194
}
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
\N Nullable(UInt8)
2+
\N Nullable(UInt8)
3+
0 UInt8
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
-- Value nullable
2+
SELECT anyIf(CAST(number, 'Nullable(UInt8)'), number = 3) AS a, toTypeName(a) FROM numbers(2);
3+
-- Value and condition nullable
4+
SELECT anyIf(number, number = 3) AS a, toTypeName(a) FROM (SELECT CAST(number, 'Nullable(UInt8)') AS number FROM numbers(2));
5+
-- Condition nullable
6+
SELECT anyIf(CAST(number, 'UInt8'), number = 3) AS a, toTypeName(a) FROM (SELECT CAST(number, 'Nullable(UInt8)') AS number FROM numbers(2));

0 commit comments

Comments
 (0)