Skip to content

Commit 886f85e

Browse files
Auto merge of #150067 - fereidani:string_retain, r=<try>
Alloc `String::retain` optimization
2 parents d9617c8 + 590ba93 commit 886f85e

File tree

1 file changed

+42
-38
lines changed

1 file changed

+42
-38
lines changed

library/alloc/src/string.rs

Lines changed: 42 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ use core::ops::Add;
5454
use core::ops::AddAssign;
5555
use core::ops::{self, Range, RangeBounds};
5656
use core::str::pattern::{Pattern, Utf8Pattern};
57-
use core::{fmt, hash, ptr, slice};
57+
use core::{fmt, hash, hint, ptr, slice};
5858

5959
#[cfg(not(no_global_oom_handling))]
6060
use crate::alloc::Allocator;
@@ -1645,53 +1645,57 @@ impl String {
16451645
where
16461646
F: FnMut(char) -> bool,
16471647
{
1648-
struct SetLenOnDrop<'a> {
1649-
s: &'a mut String,
1650-
idx: usize,
1651-
del_bytes: usize,
1648+
let len = self.len();
1649+
if len == 0 {
1650+
// Explicit check results in better optimization
1651+
return;
1652+
}
1653+
1654+
struct PanicGuard {
1655+
s: ptr::NonNull<String>,
1656+
read: usize,
1657+
write: usize,
16521658
}
16531659

1654-
impl<'a> Drop for SetLenOnDrop<'a> {
1660+
impl Drop for PanicGuard {
16551661
fn drop(&mut self) {
1656-
let new_len = self.idx - self.del_bytes;
1657-
debug_assert!(new_len <= self.s.len());
1658-
unsafe { self.s.vec.set_len(new_len) };
1662+
// SAFETY: This is guaranteed to be the only mutable reference to `s`.
1663+
let str = unsafe { &mut *self.s.as_ptr() };
1664+
debug_assert!(self.write <= str.len());
1665+
// SAFETY: Restore the string length to the number of bytes written so far.
1666+
unsafe { str.vec.set_len(self.write) }
16591667
}
16601668
}
16611669

1662-
let len = self.len();
1663-
let mut guard = SetLenOnDrop { s: self, idx: 0, del_bytes: 0 };
1664-
1665-
while guard.idx < len {
1666-
let ch =
1667-
// SAFETY: `guard.idx` is positive-or-zero and less that len so the `get_unchecked`
1668-
// is in bound. `self` is valid UTF-8 like string and the returned slice starts at
1669-
// a unicode code point so the `Chars` always return one character.
1670-
unsafe { guard.s.get_unchecked(guard.idx..len).chars().next().unwrap_unchecked() };
1671-
let ch_len = ch.len_utf8();
1672-
1673-
if !f(ch) {
1674-
guard.del_bytes += ch_len;
1675-
} else if guard.del_bytes > 0 {
1676-
// SAFETY: `guard.idx` is in bound and `guard.del_bytes` represent the number of
1677-
// bytes that are erased from the string so the resulting `guard.idx -
1678-
// guard.del_bytes` always represent a valid unicode code point.
1679-
//
1680-
// `guard.del_bytes` >= `ch.len_utf8()`, so taking a slice with `ch.len_utf8()` len
1681-
// is safe.
1682-
ch.encode_utf8(unsafe {
1683-
crate::slice::from_raw_parts_mut(
1684-
guard.s.as_mut_ptr().add(guard.idx - guard.del_bytes),
1685-
ch.len_utf8(),
1686-
)
1687-
});
1670+
// Faster read-path
1671+
let string_ptr = ptr::NonNull::from(&mut *self);
1672+
let data_ptr = self.vec.as_mut_ptr();
1673+
let mut chars = self.char_indices();
1674+
let (read, write) = loop {
1675+
let Some((write, ch)) = chars.next() else { return };
1676+
if hint::unlikely(!f(ch)) {
1677+
break (chars.offset(), write);
16881678
}
1679+
};
16891680

1690-
// Point idx to the next char
1691-
guard.idx += ch_len;
1681+
// Critical section starts here, at least one character is going to be removed.
1682+
let mut g = PanicGuard { s: string_ptr, read, write };
1683+
// Slower write-path
1684+
while let Some((read, ch)) = chars.next() {
1685+
let ch_len = chars.offset() - read;
1686+
if f(ch) {
1687+
// SAFETY: `g.read` is in bound because `g.write` <= `g.read` - `ch_len`,
1688+
// so taking a slice with `ch_len` is safe.
1689+
unsafe {
1690+
ptr::copy(data_ptr.add(g.read), data_ptr.add(g.write), ch_len);
1691+
}
1692+
g.write += ch_len;
1693+
}
1694+
g.read += ch_len;
16921695
}
16931696

1694-
drop(guard);
1697+
// All characters have been processed, set the final length by dropping the guard.
1698+
drop(g);
16951699
}
16961700

16971701
/// Inserts a character into this `String` at byte position `idx`.

0 commit comments

Comments
 (0)