|
1 | 1 | #include <AggregateFunctions/AggregateFunctionIf.h> |
2 | 2 | #include <AggregateFunctions/AggregateFunctionCombinatorFactory.h> |
3 | 3 | #include "registerAggregateFunctions.h" |
| 4 | +#include "AggregateFunctionNull.h" |
4 | 5 |
|
5 | 6 |
|
6 | 7 | namespace DB |
7 | 8 | { |
8 | 9 |
|
9 | 10 | namespace ErrorCodes |
10 | 11 | { |
| 12 | + extern const int LOGICAL_ERROR; |
11 | 13 | extern const int ILLEGAL_TYPE_OF_ARGUMENT; |
12 | 14 | extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; |
13 | 15 | } |
@@ -40,6 +42,164 @@ class AggregateFunctionCombinatorIf final : public IAggregateFunctionCombinator |
40 | 42 | } |
41 | 43 | }; |
42 | 44 |
|
| 45 | +/** There are two cases: for single argument and variadic. |
| 46 | + * Code for single argument is much more efficient. |
| 47 | + */ |
| 48 | +template <bool result_is_nullable, bool serialize_flag> |
| 49 | +class AggregateFunctionIfNullUnary final |
| 50 | + : public AggregateFunctionNullBase<result_is_nullable, serialize_flag, |
| 51 | + AggregateFunctionIfNullUnary<result_is_nullable, serialize_flag>> |
| 52 | +{ |
| 53 | +private: |
| 54 | + size_t num_arguments; |
| 55 | + |
| 56 | + using Base = AggregateFunctionNullBase<result_is_nullable, serialize_flag, |
| 57 | + AggregateFunctionIfNullUnary<result_is_nullable, serialize_flag>>; |
| 58 | +public: |
| 59 | + |
| 60 | + String getName() const override |
| 61 | + { |
| 62 | + return Base::getName() + "If"; |
| 63 | + } |
| 64 | + |
| 65 | + AggregateFunctionIfNullUnary(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params) |
| 66 | + : Base(std::move(nested_function_), arguments, params), num_arguments(arguments.size()) |
| 67 | + { |
| 68 | + if (num_arguments == 0) |
| 69 | + throw Exception("Aggregate function " + getName() + " require at least one argument", |
| 70 | + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); |
| 71 | + } |
| 72 | + |
| 73 | + static inline bool singleFilter(const IColumn ** columns, size_t row_num, size_t num_arguments) |
| 74 | + { |
| 75 | + const IColumn * filter_column = columns[num_arguments - 1]; |
| 76 | + if (const ColumnNullable * nullable_column = typeid_cast<const ColumnNullable *>(filter_column)) |
| 77 | + filter_column = nullable_column->getNestedColumnPtr().get(); |
| 78 | + |
| 79 | + return assert_cast<const ColumnUInt8 &>(*filter_column).getData()[row_num]; |
| 80 | + } |
| 81 | + |
| 82 | + void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override |
| 83 | + { |
| 84 | + const ColumnNullable * column = assert_cast<const ColumnNullable *>(columns[0]); |
| 85 | + const IColumn * nested_column = &column->getNestedColumn(); |
| 86 | + if (!column->isNullAt(row_num) && singleFilter(columns, row_num, num_arguments)) |
| 87 | + { |
| 88 | + this->setFlag(place); |
| 89 | + this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena); |
| 90 | + } |
| 91 | + } |
| 92 | +}; |
| 93 | + |
| 94 | +template <bool result_is_nullable, bool serialize_flag, bool null_is_skipped> |
| 95 | +class AggregateFunctionIfNullVariadic final |
| 96 | + : public AggregateFunctionNullBase<result_is_nullable, serialize_flag, |
| 97 | + AggregateFunctionIfNullVariadic<result_is_nullable, serialize_flag, null_is_skipped>> |
| 98 | +{ |
| 99 | +public: |
| 100 | + |
| 101 | + String getName() const override |
| 102 | + { |
| 103 | + return Base::getName() + "If"; |
| 104 | + } |
| 105 | + |
| 106 | + AggregateFunctionIfNullVariadic(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params) |
| 107 | + : Base(std::move(nested_function_), arguments, params), number_of_arguments(arguments.size()) |
| 108 | + { |
| 109 | + if (number_of_arguments == 1) |
| 110 | + throw Exception("Logical error: single argument is passed to AggregateFunctionIfNullVariadic", ErrorCodes::LOGICAL_ERROR); |
| 111 | + |
| 112 | + if (number_of_arguments > MAX_ARGS) |
| 113 | + throw Exception("Maximum number of arguments for aggregate function with Nullable types is " + toString(size_t(MAX_ARGS)), |
| 114 | + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); |
| 115 | + |
| 116 | + for (size_t i = 0; i < number_of_arguments; ++i) |
| 117 | + is_nullable[i] = arguments[i]->isNullable(); |
| 118 | + } |
| 119 | + |
| 120 | + static inline bool singleFilter(const IColumn ** columns, size_t row_num, size_t num_arguments) |
| 121 | + { |
| 122 | + return assert_cast<const ColumnUInt8 &>(*columns[num_arguments - 1]).getData()[row_num]; |
| 123 | + } |
| 124 | + |
| 125 | + void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override |
| 126 | + { |
| 127 | + /// This container stores the columns we really pass to the nested function. |
| 128 | + const IColumn * nested_columns[number_of_arguments]; |
| 129 | + |
| 130 | + for (size_t i = 0; i < number_of_arguments; ++i) |
| 131 | + { |
| 132 | + if (is_nullable[i]) |
| 133 | + { |
| 134 | + const ColumnNullable & nullable_col = assert_cast<const ColumnNullable &>(*columns[i]); |
| 135 | + if (null_is_skipped && nullable_col.isNullAt(row_num)) |
| 136 | + { |
| 137 | + /// If at least one column has a null value in the current row, |
| 138 | + /// we don't process this row. |
| 139 | + return; |
| 140 | + } |
| 141 | + nested_columns[i] = &nullable_col.getNestedColumn(); |
| 142 | + } |
| 143 | + else |
| 144 | + nested_columns[i] = columns[i]; |
| 145 | + } |
| 146 | + |
| 147 | + if (singleFilter(nested_columns, row_num, number_of_arguments)) |
| 148 | + { |
| 149 | + this->setFlag(place); |
| 150 | + this->nested_function->add(this->nestedPlace(place), nested_columns, row_num, arena); |
| 151 | + } |
| 152 | + } |
| 153 | + |
| 154 | +private: |
| 155 | + using Base = AggregateFunctionNullBase<result_is_nullable, serialize_flag, |
| 156 | + AggregateFunctionIfNullVariadic<result_is_nullable, serialize_flag, null_is_skipped>>; |
| 157 | + |
| 158 | + enum { MAX_ARGS = 8 }; |
| 159 | + size_t number_of_arguments = 0; |
| 160 | + std::array<char, MAX_ARGS> is_nullable; /// Plain array is better than std::vector due to one indirection less. |
| 161 | +}; |
| 162 | + |
| 163 | + |
| 164 | +AggregateFunctionPtr AggregateFunctionIf::getOwnNullAdapter( |
| 165 | + const AggregateFunctionPtr & nested_function, const DataTypes & arguments, |
| 166 | + const Array & params, const AggregateFunctionProperties & properties) const |
| 167 | +{ |
| 168 | + bool return_type_is_nullable = !properties.returns_default_when_only_null && getReturnType()->canBeInsideNullable(); |
| 169 | + size_t nullable_size = std::count_if(arguments.begin(), arguments.end(), [](const auto & element) { return element->isNullable(); }); |
| 170 | + return_type_is_nullable &= nullable_size != 1 || !arguments.back()->isNullable(); /// If only condition is nullable. we should non-nullable type. |
| 171 | + bool serialize_flag = return_type_is_nullable || properties.returns_default_when_only_null; |
| 172 | + |
| 173 | + if (arguments.size() <= 2 && arguments.front()->isNullable()) |
| 174 | + { |
| 175 | + if (return_type_is_nullable) |
| 176 | + { |
| 177 | + return std::make_shared<AggregateFunctionIfNullUnary<true, true>>(nested_func, arguments, params); |
| 178 | + } |
| 179 | + else |
| 180 | + { |
| 181 | + if (serialize_flag) |
| 182 | + return std::make_shared<AggregateFunctionIfNullUnary<false, true>>(nested_func, arguments, params); |
| 183 | + else |
| 184 | + return std::make_shared<AggregateFunctionIfNullUnary<false, false>>(nested_func, arguments, params); |
| 185 | + } |
| 186 | + } |
| 187 | + else |
| 188 | + { |
| 189 | + if (return_type_is_nullable) |
| 190 | + { |
| 191 | + return std::make_shared<AggregateFunctionIfNullVariadic<true, true, true>>(nested_function, arguments, params); |
| 192 | + } |
| 193 | + else |
| 194 | + { |
| 195 | + if (serialize_flag) |
| 196 | + return std::make_shared<AggregateFunctionIfNullVariadic<false, true, true>>(nested_function, arguments, params); |
| 197 | + else |
| 198 | + return std::make_shared<AggregateFunctionIfNullVariadic<false, true, false>>(nested_function, arguments, params); |
| 199 | + } |
| 200 | + } |
| 201 | +} |
| 202 | + |
43 | 203 | void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory & factory) |
44 | 204 | { |
45 | 205 | factory.registerCombinator(std::make_shared<AggregateFunctionCombinatorIf>()); |
|
0 commit comments