|
| 1 | +#pragma once |
| 2 | + |
| 3 | +#include <algorithm> |
| 4 | +#include <vector> |
| 5 | + |
| 6 | +#include <base/types.h> |
| 7 | +#include <Common/Exception.h> |
| 8 | +#include <Common/Volnitsky.h> |
| 9 | +#include <Columns/ColumnString.h> |
| 10 | + |
| 11 | + |
| 12 | +namespace DB |
| 13 | +{ |
| 14 | + |
/// Error codes referenced by this header. Only declared here (`extern`);
/// the definition is not part of this file.
namespace ErrorCodes
{
    extern const int LIMIT_EXCEEDED;
}
| 19 | + |
| 20 | +struct HighlightImpl |
| 21 | +{ |
| 22 | + static constexpr size_t DEFAULT_MAX_MATCHES_PER_ROW = 10000; |
| 23 | + struct Interval |
| 24 | + { |
| 25 | + size_t begin; |
| 26 | + size_t end; |
| 27 | + }; |
| 28 | + |
| 29 | + /// Sort and merge overlapping/adjacent intervals in-place. |
| 30 | + /// Uses <= for merge condition so that adjacent intervals like [0,5)+[5,10) merge into [0,10). |
| 31 | + static void mergeIntervals(std::vector<Interval> & intervals) |
| 32 | + { |
| 33 | + if (intervals.size() <= 1) |
| 34 | + return; |
| 35 | + |
| 36 | + std::sort(intervals.begin(), intervals.end(), [](const Interval & a, const Interval & b) |
| 37 | + { |
| 38 | + return a.begin < b.begin || (a.begin == b.begin && a.end > b.end); |
| 39 | + }); |
| 40 | + |
| 41 | + size_t write = 0; |
| 42 | + for (size_t read = 1; read < intervals.size(); ++read) |
| 43 | + { |
| 44 | + if (intervals[read].begin <= intervals[write].end) |
| 45 | + intervals[write].end = std::max(intervals[write].end, intervals[read].end); |
| 46 | + else |
| 47 | + intervals[++write] = intervals[read]; |
| 48 | + } |
| 49 | + intervals.resize(write + 1); |
| 50 | + } |
| 51 | + |
| 52 | + struct NeedleSearcher |
| 53 | + { |
| 54 | + VolnitskyCaseInsensitive searcher; |
| 55 | + size_t needle_size; |
| 56 | + }; |
| 57 | + |
| 58 | + static void execute( |
| 59 | + const ColumnString::Chars & haystack_data, |
| 60 | + const ColumnString::Offsets & haystack_offsets, |
| 61 | + const std::vector<std::string_view> & needles, |
| 62 | + const String & open_tag, |
| 63 | + const String & close_tag, |
| 64 | + ColumnString::Chars & res_data, |
| 65 | + ColumnString::Offsets & res_offsets, |
| 66 | + size_t input_rows_count, |
| 67 | + UInt64 max_matches_per_row = DEFAULT_MAX_MATCHES_PER_ROW) |
| 68 | + { |
| 69 | + /// Pre-allocate output buffers — conservative estimate to avoid over-allocation |
| 70 | + /// with many needles: at most one tag pair per row on average. |
| 71 | + const size_t tag_overhead = open_tag.size() + close_tag.size(); |
| 72 | + res_data.reserve(haystack_data.size() + input_rows_count * tag_overhead); |
| 73 | + res_offsets.resize(input_rows_count); |
| 74 | + |
| 75 | + /// Build searcher instances once outside the row loop, paired with needle sizes. |
| 76 | + /// We use VolnitskyCaseInsensitive with haystack_size_hint=0, which means |
| 77 | + /// each search() call decides internally whether to use the hash table |
| 78 | + /// or fall back to ASCIICaseInsensitiveStringSearcher for short haystacks. |
| 79 | + std::vector<NeedleSearcher> searchers; |
| 80 | + searchers.reserve(needles.size()); |
| 81 | + for (const auto & needle : needles) |
| 82 | + if (!needle.empty()) |
| 83 | + searchers.push_back({VolnitskyCaseInsensitive(needle.data(), needle.size(), 0), needle.size()}); |
| 84 | + |
| 85 | + /// Reusable intervals buffer across rows |
| 86 | + std::vector<Interval> intervals; |
| 87 | + intervals.reserve(64); |
| 88 | + |
| 89 | + ColumnString::Offset res_offset = 0; |
| 90 | + ColumnString::Offset prev_haystack_offset = 0; |
| 91 | + |
| 92 | + for (size_t i = 0; i < input_rows_count; ++i) |
| 93 | + { |
| 94 | + const size_t cur_size = haystack_offsets[i] - prev_haystack_offset; |
| 95 | + |
| 96 | + if (cur_size > 0) |
| 97 | + { |
| 98 | + const UInt8 * cur_data = &haystack_data[prev_haystack_offset]; |
| 99 | + |
| 100 | + /// Phase 1: find all matches |
| 101 | + intervals.clear(); |
| 102 | + findAllMatches(cur_data, cur_size, searchers, intervals, max_matches_per_row); |
| 103 | + |
| 104 | + if (intervals.empty()) |
| 105 | + { |
| 106 | + /// No matches — copy as-is |
| 107 | + append(res_data, res_offset, cur_data, cur_size); |
| 108 | + } |
| 109 | + else |
| 110 | + { |
| 111 | + /// Phase 2: merge overlapping intervals |
| 112 | + mergeIntervals(intervals); |
| 113 | + |
| 114 | + /// Phase 3: build output with tags |
| 115 | + buildOutput(cur_data, cur_size, intervals, open_tag, close_tag, res_data, res_offset); |
| 116 | + } |
| 117 | + } |
| 118 | + |
| 119 | + res_offsets[i] = res_offset; |
| 120 | + prev_haystack_offset = haystack_offsets[i]; |
| 121 | + } |
| 122 | + } |
| 123 | + |
| 124 | +private: |
| 125 | + /// Phase 1: For each needle, find all occurrence positions in the haystack. |
| 126 | + static void findAllMatches( |
| 127 | + const UInt8 * haystack, |
| 128 | + size_t haystack_size, |
| 129 | + const std::vector<NeedleSearcher> & searchers, |
| 130 | + std::vector<Interval> & intervals, |
| 131 | + UInt64 max_matches_per_row) |
| 132 | + { |
| 133 | + const UInt8 * haystack_end = haystack + haystack_size; |
| 134 | + |
| 135 | + for (const auto & [searcher, needle_size] : searchers) |
| 136 | + { |
| 137 | + const UInt8 * pos = haystack; |
| 138 | + while (pos < haystack_end) |
| 139 | + { |
| 140 | + const UInt8 * match = searcher.search(pos, haystack_end - pos); |
| 141 | + if (match == haystack_end) |
| 142 | + break; |
| 143 | + |
| 144 | + const size_t offset = match - haystack; |
| 145 | + intervals.push_back({offset, offset + needle_size}); |
| 146 | + pos = match + 1; |
| 147 | + |
| 148 | + if (intervals.size() > max_matches_per_row) |
| 149 | + throw Exception( |
| 150 | + ErrorCodes::LIMIT_EXCEEDED, |
| 151 | + "Too many highlight matches per row: {}, max: {}. " |
| 152 | + "You can increase this limit with the `highlight_max_matches_per_row` setting", |
| 153 | + intervals.size(), max_matches_per_row); |
| 154 | + } |
| 155 | + } |
| 156 | + } |
| 157 | + |
| 158 | + /// Phase 3: Build the output string by interleaving non-matched text with tagged matched text. |
| 159 | + static void buildOutput( |
| 160 | + const UInt8 * haystack, |
| 161 | + size_t haystack_size, |
| 162 | + const std::vector<Interval> & intervals, |
| 163 | + const String & open_tag, |
| 164 | + const String & close_tag, |
| 165 | + ColumnString::Chars & res_data, |
| 166 | + ColumnString::Offset & res_offset) |
| 167 | + { |
| 168 | + size_t cursor = 0; |
| 169 | + for (const auto & interval : intervals) |
| 170 | + { |
| 171 | + /// Copy non-matched text before this interval |
| 172 | + if (interval.begin > cursor) |
| 173 | + append(res_data, res_offset, haystack + cursor, interval.begin - cursor); |
| 174 | + |
| 175 | + /// Insert open tag |
| 176 | + if (!open_tag.empty()) |
| 177 | + append(res_data, res_offset, reinterpret_cast<const UInt8 *>(open_tag.data()), open_tag.size()); |
| 178 | + |
| 179 | + /// Copy matched text (preserving original case) |
| 180 | + append(res_data, res_offset, haystack + interval.begin, interval.end - interval.begin); |
| 181 | + |
| 182 | + /// Insert close tag |
| 183 | + if (!close_tag.empty()) |
| 184 | + append(res_data, res_offset, reinterpret_cast<const UInt8 *>(close_tag.data()), close_tag.size()); |
| 185 | + |
| 186 | + cursor = interval.end; |
| 187 | + } |
| 188 | + |
| 189 | + /// Copy remaining text after the last interval |
| 190 | + if (cursor < haystack_size) |
| 191 | + append(res_data, res_offset, haystack + cursor, haystack_size - cursor); |
| 192 | + } |
| 193 | + |
| 194 | + static inline void append( |
| 195 | + ColumnString::Chars & data, |
| 196 | + ColumnString::Offset & offset, |
| 197 | + const void * src, |
| 198 | + size_t size) |
| 199 | + { |
| 200 | + data.resize(data.size() + size); |
| 201 | + memcpy(&data[offset], src, size); |
| 202 | + offset += size; |
| 203 | + } |
| 204 | +}; |
| 205 | + |
| 206 | +} |
0 commit comments