Skip to content

Commit 6481443

Browse files
Abseil Team and copybara-github
authored and committed
Optimize SwissMap for ARM by 3-8% for all operations
https://pastebin.com/CmnzwUFN

The key idea is to avoid using 16 byte NEON and use 8 byte NEON, which has lower latency for BitMask::Match. Even though 16 byte NEON achieves higher throughput, in SwissMap it is very important to catch these Matches with low latency, as probing on average happens at most once. I also introduced NonIterableBitMask, as ARM has really great cbnz instructions, and the additional AND on the scalar mask cost 1 extra latency cycle.

PiperOrigin-RevId: 453216147
Change-Id: I842c50d323954f8383ae156491232ced55aacb78
1 parent 4841959 commit 6481443

4 files changed

Lines changed: 176 additions & 98 deletions

File tree

absl/base/config.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -898,4 +898,13 @@ static_assert(ABSL_INTERNAL_INLINE_NAMESPACE_STR[0] != 'h' ||
898898
#define ABSL_INTERNAL_HAVE_ARM_ACLE 1
899899
#endif
900900

901+
// ABSL_INTERNAL_HAVE_ARM_NEON is used for compile-time detection of NEON (ARM
902+
// SIMD).
903+
#ifdef ABSL_INTERNAL_HAVE_ARM_NEON
904+
#error ABSL_INTERNAL_HAVE_ARM_NEON cannot be directly set
905+
#elif defined(__ARM_NEON)
906+
#define ABSL_INTERNAL_HAVE_ARM_NEON 1
907+
#endif
908+
909+
901910
#endif // ABSL_BASE_CONFIG_H_

absl/container/internal/raw_hash_set.h

Lines changed: 146 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,14 @@
184184
#include <intrin.h>
185185
#endif
186186

187+
#ifdef __ARM_NEON
188+
#include <arm_neon.h>
189+
#endif
190+
191+
#ifdef __ARM_ACLE
192+
#include <arm_acle.h>
193+
#endif
194+
187195
#include <algorithm>
188196
#include <cmath>
189197
#include <cstdint>
@@ -211,10 +219,6 @@
211219
#include "absl/numeric/bits.h"
212220
#include "absl/utility/utility.h"
213221

214-
#ifdef ABSL_INTERNAL_HAVE_ARM_ACLE
215-
#include <arm_acle.h>
216-
#endif
217-
218222
namespace absl {
219223
ABSL_NAMESPACE_BEGIN
220224
namespace container_internal {
@@ -323,36 +327,15 @@ uint32_t TrailingZeros(T x) {
323327
// controlled by `SignificantBits` and `Shift`. `SignificantBits` is the number
324328
// of abstract bits in the bitset, while `Shift` is the log-base-two of the
325329
// width of an abstract bit in the representation.
326-
//
327-
// For example, when `SignificantBits` is 16 and `Shift` is zero, this is just
328-
// an ordinary 16-bit bitset occupying the low 16 bits of `mask`. When
329-
// `SignificantBits` is 8 and `Shift` is 3, abstract bits are represented as
330-
// the bytes `0x00` and `0x80`, and it occupies all 64 bits of the bitmask.
331-
//
332-
// For example:
333-
// for (int i : BitMask<uint32_t, 16>(0b101)) -> yields 0, 2
334-
// for (int i : BitMask<uint64_t, 8, 3>(0x0000000080800000)) -> yields 2, 3
330+
// This mask provides operations for any number of real bits set in an abstract
331+
// bit. To add iteration on top of that, implementation must guarantee no more
332+
// than one real bit is set in an abstract bit.
335333
template <class T, int SignificantBits, int Shift = 0>
336-
class BitMask {
337-
static_assert(std::is_unsigned<T>::value, "");
338-
static_assert(Shift == 0 || Shift == 3, "");
339-
334+
class NonIterableBitMask {
340335
public:
341-
// BitMask is an iterator over the indices of its abstract bits.
342-
using value_type = int;
343-
using iterator = BitMask;
344-
using const_iterator = BitMask;
345-
346-
explicit BitMask(T mask) : mask_(mask) {}
347-
BitMask& operator++() {
348-
mask_ &= (mask_ - 1);
349-
return *this;
350-
}
351-
explicit operator bool() const { return mask_ != 0; }
352-
uint32_t operator*() const { return LowestBitSet(); }
336+
explicit NonIterableBitMask(T mask) : mask_(mask) {}
353337

354-
BitMask begin() const { return *this; }
355-
BitMask end() const { return BitMask(0); }
338+
explicit operator bool() const { return this->mask_ != 0; }
356339

357340
// Returns the index of the lowest *abstract* bit set in `self`.
358341
uint32_t LowestBitSet() const {
@@ -376,15 +359,49 @@ class BitMask {
376359
return static_cast<uint32_t>(countl_zero(mask_ << extra_bits)) >> Shift;
377360
}
378361

362+
T mask_;
363+
};
364+
365+
// Mask that can be iterable
366+
//
367+
// For example, when `SignificantBits` is 16 and `Shift` is zero, this is just
368+
// an ordinary 16-bit bitset occupying the low 16 bits of `mask`. When
369+
// `SignificantBits` is 8 and `Shift` is 3, abstract bits are represented as
370+
// the bytes `0x00` and `0x80`, and it occupies all 64 bits of the bitmask.
371+
//
372+
// For example:
373+
// for (int i : BitMask<uint32_t, 16>(0b101)) -> yields 0, 2
374+
// for (int i : BitMask<uint64_t, 8, 3>(0x0000000080800000)) -> yields 2, 3
375+
template <class T, int SignificantBits, int Shift = 0>
376+
class BitMask : public NonIterableBitMask<T, SignificantBits, Shift> {
377+
using Base = NonIterableBitMask<T, SignificantBits, Shift>;
378+
static_assert(std::is_unsigned<T>::value, "");
379+
static_assert(Shift == 0 || Shift == 3, "");
380+
381+
public:
382+
explicit BitMask(T mask) : Base(mask) {}
383+
// BitMask is an iterator over the indices of its abstract bits.
384+
using value_type = int;
385+
using iterator = BitMask;
386+
using const_iterator = BitMask;
387+
388+
BitMask& operator++() {
389+
this->mask_ &= (this->mask_ - 1);
390+
return *this;
391+
}
392+
393+
uint32_t operator*() const { return Base::LowestBitSet(); }
394+
395+
BitMask begin() const { return *this; }
396+
BitMask end() const { return BitMask(0); }
397+
379398
private:
380399
friend bool operator==(const BitMask& a, const BitMask& b) {
381400
return a.mask_ == b.mask_;
382401
}
383402
friend bool operator!=(const BitMask& a, const BitMask& b) {
384403
return a.mask_ != b.mask_;
385404
}
386-
387-
T mask_;
388405
};
389406

390407
using h2_t = uint8_t;
@@ -433,7 +450,7 @@ static_assert(
433450
static_cast<int8_t>(ctrl_t::kSentinel) & 0x7F) != 0,
434451
"ctrl_t::kEmpty and ctrl_t::kDeleted must share an unset bit that is not "
435452
"shared by ctrl_t::kSentinel to make the scalar test for "
436-
"MatchEmptyOrDeleted() efficient");
453+
"MaskEmptyOrDeleted() efficient");
437454
static_assert(ctrl_t::kDeleted == static_cast<ctrl_t>(-2),
438455
"ctrl_t::kDeleted must be -2 to make the implementation of "
439456
"ConvertSpecialToEmptyAndFullToDeleted efficient");
@@ -538,20 +555,22 @@ struct GroupSse2Impl {
538555
}
539556

540557
// Returns a bitmask representing the positions of empty slots.
541-
BitMask<uint32_t, kWidth> MatchEmpty() const {
558+
NonIterableBitMask<uint32_t, kWidth> MaskEmpty() const {
542559
#ifdef ABSL_INTERNAL_HAVE_SSSE3
543560
// This only works because ctrl_t::kEmpty is -128.
544-
return BitMask<uint32_t, kWidth>(
561+
return NonIterableBitMask<uint32_t, kWidth>(
545562
static_cast<uint32_t>(_mm_movemask_epi8(_mm_sign_epi8(ctrl, ctrl))));
546563
#else
547-
return Match(static_cast<h2_t>(ctrl_t::kEmpty));
564+
auto match = _mm_set1_epi8(static_cast<h2_t>(ctrl_t::kEmpty));
565+
return NonIterableBitMask<uint32_t, kWidth>(
566+
static_cast<uint32_t>(_mm_movemask_epi8(_mm_cmpeq_epi8(match, ctrl))));
548567
#endif
549568
}
550569

551570
// Returns a bitmask representing the positions of empty or deleted slots.
552-
BitMask<uint32_t, kWidth> MatchEmptyOrDeleted() const {
571+
NonIterableBitMask<uint32_t, kWidth> MaskEmptyOrDeleted() const {
553572
auto special = _mm_set1_epi8(static_cast<uint8_t>(ctrl_t::kSentinel));
554-
return BitMask<uint32_t, kWidth>(static_cast<uint32_t>(
573+
return NonIterableBitMask<uint32_t, kWidth>(static_cast<uint32_t>(
555574
_mm_movemask_epi8(_mm_cmpgt_epi8_fixed(special, ctrl))));
556575
}
557576

@@ -579,6 +598,80 @@ struct GroupSse2Impl {
579598
};
580599
#endif // ABSL_INTERNAL_RAW_HASH_SET_HAVE_SSE2
581600

601+
#if defined(ABSL_INTERNAL_HAVE_ARM_NEON) && defined(ABSL_IS_LITTLE_ENDIAN)
602+
struct GroupAArch64Impl {
603+
static constexpr size_t kWidth = 8;
604+
605+
explicit GroupAArch64Impl(const ctrl_t* pos) {
606+
ctrl = vld1_u8(reinterpret_cast<const uint8_t*>(pos));
607+
}
608+
609+
BitMask<uint64_t, kWidth, 3> Match(h2_t hash) const {
610+
uint8x8_t dup = vdup_n_u8(hash);
611+
auto mask = vceq_u8(ctrl, dup);
612+
constexpr uint64_t msbs = 0x8080808080808080ULL;
613+
return BitMask<uint64_t, kWidth, 3>(
614+
vget_lane_u64(vreinterpret_u64_u8(mask), 0) & msbs);
615+
}
616+
617+
NonIterableBitMask<uint64_t, kWidth, 3> MaskEmpty() const {
618+
uint64_t mask =
619+
vget_lane_u64(vreinterpret_u64_u8(
620+
vceq_s8(vdup_n_s8(static_cast<h2_t>(ctrl_t::kEmpty)),
621+
vreinterpret_s8_u8(ctrl))),
622+
0);
623+
return NonIterableBitMask<uint64_t, kWidth, 3>(mask);
624+
}
625+
626+
NonIterableBitMask<uint64_t, kWidth, 3> MaskEmptyOrDeleted() const {
627+
uint64_t mask =
628+
vget_lane_u64(vreinterpret_u64_u8(vcgt_s8(
629+
vdup_n_s8(static_cast<int8_t>(ctrl_t::kSentinel)),
630+
vreinterpret_s8_u8(ctrl))),
631+
0);
632+
return NonIterableBitMask<uint64_t, kWidth, 3>(mask);
633+
}
634+
635+
uint32_t CountLeadingEmptyOrDeleted() const {
636+
uint64_t mask = vget_lane_u64(vreinterpret_u64_u8(ctrl), 0);
637+
assert(IsEmptyOrDeleted(static_cast<ctrl_t>(mask & 0xff)));
638+
constexpr uint64_t gaps = 0x00FEFEFEFEFEFEFEULL;
639+
#if defined(ABSL_INTERNAL_HAVE_ARM_ACLE)
640+
// cls: Count leading sign bits.
641+
// clsll(1ull << 63) -> 0
642+
// clsll((1ull << 63) | (1ull << 62)) -> 1
643+
// clsll((1ull << 63) | (1ull << 61)) -> 0
644+
// clsll(~0ull) -> 63
645+
// clsll(1) -> 62
646+
// clsll(3) -> 61
647+
// clsll(5) -> 60
648+
// Note that CountLeadingEmptyOrDeleted is called when first control block
649+
// is kDeleted or kEmpty. The implementation is similar to GroupPortableImpl
650+
// but avoids +1 and __clsll returns result not including the high bit. Thus
651+
// saves one cycle.
652+
// kEmpty = -128, // 0b10000000
653+
// kDeleted = -2, // 0b11111110
654+
// ~ctrl & (ctrl >> 7) will have the lowest bit set to 1. After rbit,
655+
// it will be the highest one.
656+
return (__clsll(__rbitll((~mask & (mask >> 7)) | gaps)) + 8) >> 3;
657+
#else
658+
return (TrailingZeros(((~mask & (mask >> 7)) | gaps) + 1) + 7) >> 3;
659+
#endif // ABSL_INTERNAL_HAVE_ARM_ACLE
660+
}
661+
662+
void ConvertSpecialToEmptyAndFullToDeleted(ctrl_t* dst) const {
663+
uint64_t mask = vget_lane_u64(vreinterpret_u64_u8(ctrl), 0);
664+
constexpr uint64_t msbs = 0x8080808080808080ULL;
665+
constexpr uint64_t lsbs = 0x0101010101010101ULL;
666+
auto x = mask & msbs;
667+
auto res = (~x + (x >> 7)) & ~lsbs;
668+
little_endian::Store64(dst, res);
669+
}
670+
671+
uint8x8_t ctrl;
672+
};
673+
#endif // ABSL_INTERNAL_HAVE_ARM_NEON && ABSL_IS_LITTLE_ENDIAN
674+
582675
struct GroupPortableImpl {
583676
static constexpr size_t kWidth = 8;
584677

@@ -605,14 +698,16 @@ struct GroupPortableImpl {
605698
return BitMask<uint64_t, kWidth, 3>((x - lsbs) & ~x & msbs);
606699
}
607700

608-
BitMask<uint64_t, kWidth, 3> MatchEmpty() const {
701+
NonIterableBitMask<uint64_t, kWidth, 3> MaskEmpty() const {
609702
constexpr uint64_t msbs = 0x8080808080808080ULL;
610-
return BitMask<uint64_t, kWidth, 3>((ctrl & (~ctrl << 6)) & msbs);
703+
return NonIterableBitMask<uint64_t, kWidth, 3>((ctrl & (~ctrl << 6)) &
704+
msbs);
611705
}
612706

613-
BitMask<uint64_t, kWidth, 3> MatchEmptyOrDeleted() const {
707+
NonIterableBitMask<uint64_t, kWidth, 3> MaskEmptyOrDeleted() const {
614708
constexpr uint64_t msbs = 0x8080808080808080ULL;
615-
return BitMask<uint64_t, kWidth, 3>((ctrl & (~ctrl << 7)) & msbs);
709+
return NonIterableBitMask<uint64_t, kWidth, 3>((ctrl & (~ctrl << 7)) &
710+
msbs);
616711
}
617712

618713
uint32_t CountLeadingEmptyOrDeleted() const {
@@ -631,39 +726,9 @@ struct GroupPortableImpl {
631726
uint64_t ctrl;
632727
};
633728

634-
#ifdef ABSL_INTERNAL_HAVE_ARM_ACLE
635-
struct GroupAArch64Impl : public GroupPortableImpl {
636-
static constexpr size_t kWidth = GroupPortableImpl::kWidth;
637-
638-
using GroupPortableImpl::GroupPortableImpl;
639-
640-
uint32_t CountLeadingEmptyOrDeleted() const {
641-
assert(IsEmptyOrDeleted(static_cast<ctrl_t>(ctrl & 0xff)));
642-
constexpr uint64_t gaps = 0x00FEFEFEFEFEFEFEULL;
643-
// cls: Count leading sign bits.
644-
// clsll(1ull << 63) -> 0
645-
// clsll((1ull << 63) | (1ull << 62)) -> 1
646-
// clsll((1ull << 63) | (1ull << 61)) -> 0
647-
// clsll(~0ull) -> 63
648-
// clsll(1) -> 62
649-
// clsll(3) -> 61
650-
// clsll(5) -> 60
651-
// Note that CountLeadingEmptyOrDeleted is called when first control block
652-
// is kDeleted or kEmpty. The implementation is similar to GroupPortableImpl
653-
// but avoids +1 and __clsll returns result not including the high bit. Thus
654-
// saves one cycle.
655-
// kEmpty = -128, // 0b10000000
656-
// kDeleted = -2, // 0b11111110
657-
// ~ctrl & (ctrl >> 7) will have the lowest bit set to 1. After rbit,
658-
// it will the highest one.
659-
return (__clsll(__rbitll((~ctrl & (ctrl >> 7)) | gaps)) + 8) >> 3;
660-
}
661-
};
662-
#endif
663-
664729
#ifdef ABSL_INTERNAL_HAVE_SSE2
665730
using Group = GroupSse2Impl;
666-
#elif defined(ABSL_INTERNAL_HAVE_ARM_ACLE)
731+
#elif defined(ABSL_INTERNAL_HAVE_ARM_NEON) && defined(ABSL_IS_LITTLE_ENDIAN)
667732
using Group = GroupAArch64Impl;
668733
#else
669734
using Group = GroupPortableImpl;
@@ -798,7 +863,7 @@ inline FindInfo find_first_non_full(const ctrl_t* ctrl, size_t hash,
798863
auto seq = probe(ctrl, hash, capacity);
799864
while (true) {
800865
Group g{ctrl + seq.offset()};
801-
auto mask = g.MatchEmptyOrDeleted();
866+
auto mask = g.MaskEmptyOrDeleted();
802867
if (mask) {
803868
#if !defined(NDEBUG)
804869
// We want to add entropy even when ASLR is not enabled.
@@ -1700,7 +1765,7 @@ class raw_hash_set {
17001765
PolicyTraits::element(slots_ + seq.offset(i)))))
17011766
return iterator_at(seq.offset(i));
17021767
}
1703-
if (ABSL_PREDICT_TRUE(g.MatchEmpty())) return end();
1768+
if (ABSL_PREDICT_TRUE(g.MaskEmpty())) return end();
17041769
seq.next();
17051770
assert(seq.index() <= capacity_ && "full table!");
17061771
}
@@ -1849,8 +1914,8 @@ class raw_hash_set {
18491914
--size_;
18501915
const size_t index = static_cast<size_t>(it.inner_.ctrl_ - ctrl_);
18511916
const size_t index_before = (index - Group::kWidth) & capacity_;
1852-
const auto empty_after = Group(it.inner_.ctrl_).MatchEmpty();
1853-
const auto empty_before = Group(ctrl_ + index_before).MatchEmpty();
1917+
const auto empty_after = Group(it.inner_.ctrl_).MaskEmpty();
1918+
const auto empty_before = Group(ctrl_ + index_before).MaskEmpty();
18541919

18551920
// We count how many consecutive non empties we have to the right and to the
18561921
// left of `it`. If the sum is >= kWidth then there is at least one probe
@@ -2091,7 +2156,7 @@ class raw_hash_set {
20912156
elem))
20922157
return true;
20932158
}
2094-
if (ABSL_PREDICT_TRUE(g.MatchEmpty())) return false;
2159+
if (ABSL_PREDICT_TRUE(g.MaskEmpty())) return false;
20952160
seq.next();
20962161
assert(seq.index() <= capacity_ && "full table!");
20972162
}
@@ -2127,7 +2192,7 @@ class raw_hash_set {
21272192
PolicyTraits::element(slots_ + seq.offset(i)))))
21282193
return {seq.offset(i), false};
21292194
}
2130-
if (ABSL_PREDICT_TRUE(g.MatchEmpty())) break;
2195+
if (ABSL_PREDICT_TRUE(g.MaskEmpty())) break;
21312196
seq.next();
21322197
assert(seq.index() <= capacity_ && "full table!");
21332198
}
@@ -2272,7 +2337,7 @@ struct HashtableDebugAccess<Set, absl::void_t<typename Set::raw_hash_set>> {
22722337
return num_probes;
22732338
++num_probes;
22742339
}
2275-
if (g.MatchEmpty()) return num_probes;
2340+
if (g.MaskEmpty()) return num_probes;
22762341
seq.next();
22772342
++num_probes;
22782343
}

0 commit comments

Comments
 (0)