//! Implementations of radix keys and sorting functions.
use alloc::vec::Vec;
use core::mem;
use crate::Key;
/// Unsigned integers used as sorting keys for radix sort.
///
/// These keys can be sorted bitwise. Scalar types are converted into radix
/// keys by [`Key::to_radix_key`].
///
/// Mapping of floating point numbers onto unsigned integers is a bit more
/// involved, see the `Key` implementation for details.
///
/// [`Key::to_radix_key`]: fn.sort_by_key.html
pub trait RadixKey: Key {
    /// Sorts `slice` by the keys that `key_fn` extracts from its elements.
    ///
    /// This is only a dispatcher: it picks one of the other sorting functions
    /// based on the length of the slice. Slices of zero-sized elements and
    /// slices with fewer than two elements are returned untouched.
    ///
    /// `unopt` is forwarded unchanged to the selected sorting function.
    #[inline]
    fn radix_sort<T, F>(slice: &mut [T], mut key_fn: F, unopt: bool)
    where
        F: FnMut(&T) -> Self,
    {
        // Zero-sized types have no meaningful order, and anything shorter
        // than two elements is already sorted.
        if mem::size_of::<T>() == 0 || slice.len() < 2 {
            return;
        }

        // On wide targets, prefer the variant with compact `u32` offsets
        // whenever the slice length fits into a `u32`.
        #[cfg(any(target_pointer_width = "64", target_pointer_width = "128"))]
        {
            if slice.len() <= core::u32::MAX as usize {
                return Self::radix_sort_u32(slice, &mut key_fn, unopt);
            }
        }

        Self::radix_sort_usize(slice, &mut key_fn, unopt)
    }

    /// Sorting for slices with up to u32::MAX elements, which is a majority of
    /// cases. Uses u32 indices for histograms and offsets to save cache space.
    #[cfg(any(target_pointer_width = "64", target_pointer_width = "128"))]
    fn radix_sort_u32<T, F>(slice: &mut [T], key_fn: F, unopt: bool)
    where
        F: FnMut(&T) -> Self;

    /// Sorting function for slices with up to usize::MAX elements.
    fn radix_sort_usize<T, F>(slice: &mut [T], key_fn: F, unopt: bool)
    where
        F: FnMut(&T) -> Self;
}
// Generates a least-significant-digit radix sort function called `$name`.
//
// * `$radix_key_type` is the unsigned integer type returned by the key function.
// * `$offset_type` is the integer type used for histogram counts and bucket
//   offsets. It must be able to represent the length of the input slice.
//
// The generated function sorts `input` by the key `key_fn` returns for each
// element. Elements are distributed between `input` and a temporary buffer of
// the same length (see `DoubleBuffer`), one pass per key digit. When `unopt` is
// `true`, the "skip digits that are identical in all keys" optimization is
// turned off.
//
// An empty `input` is returned immediately.
//
// The generated function panics when `key_fn` is detected to be unreliable,
// i.e. it returned different keys for the same element on repeated calls. The
// input slice still contains a permutation of its original elements then.
macro_rules! sort_impl {
    ($name:ident, $radix_key_type:ty, $offset_type:ty) => {
        #[inline(never)] // Don't inline, the offset array needs a lot of stack
        fn $name<T, F>(input: &mut [T], mut key_fn: F, unopt: bool)
            where F: FnMut(&T) -> $radix_key_type
        {
            // This implementation is radix 256, so the size of a digit is 8 bits / one byte.
            // You can experiment with different digit sizes by changing this constant, but
            // according to my benchmarks, the overhead from arbitrary shifting and masking
            // will be higher than what you save by having less digits.
            const DIGIT_BITS: usize = 8;

            const RADIX_KEY_BITS: usize = mem::size_of::<$radix_key_type>() * 8;

            // Have one bucket for each possible value of the digit
            const BUCKET_COUNT: usize = 1 << DIGIT_BITS;

            const DIGIT_COUNT: usize = (RADIX_KEY_BITS + DIGIT_BITS - 1) / DIGIT_BITS;

            /// Extracts the digit from the key, starting with the least significant digit.
            /// The digit is used as a bucket index.
            #[inline(always)]
            fn extract_digit(key: $radix_key_type, digit: usize) -> usize {
                const DIGIT_MASK: $radix_key_type = ((1 << DIGIT_BITS) - 1) as $radix_key_type;
                ((key >> (digit * DIGIT_BITS)) & DIGIT_MASK) as usize
            }

            let len = input.len();

            // There is nothing to sort in an empty slice. Returning early also keeps the
            // digit-skip check below from needing a "last element" that does not exist.
            if len == 0 {
                return;
            }

            let digit_skip_enabled: bool = !unopt;

            // In the worst case (u128 key, input len > u32::MAX) uses 32 KiB on the stack.
            let mut offsets = [[0 as $offset_type; BUCKET_COUNT]; DIGIT_COUNT];
            let mut skip_digit = [false; DIGIT_COUNT];

            { // Calculate bucket offsets for each digit.

                // Calculate histograms/bucket sizes and store in `offsets`.
                for t in input.iter() {
                    let key = key_fn(t);
                    for digit in 0..DIGIT_COUNT {
                        offsets[digit][extract_digit(key, digit)] += 1;
                    }
                }

                if digit_skip_enabled {
                    // For each digit, check if all the elements are in the same bucket.
                    // If so, we can skip the whole digit. Instead of checking all the buckets,
                    // we pick a key and check whether the bucket contains all the elements.
                    if let Some(last) = input.last() {
                        let last_key = key_fn(last);
                        for digit in 0..DIGIT_COUNT {
                            let last_bucket = extract_digit(last_key, digit);
                            let skip = offsets[digit][last_bucket] == len as $offset_type;
                            skip_digit[digit] = skip;
                        }
                    }
                }

                // Turn the histogram/bucket sizes into bucket offsets by calculating a prefix sum.
                // Sizes:   |---b1---|-b2-|---b3---|----b4----|
                // Offsets: 0        b1   b1+b2    b1+b2+b3
                for digit in 0..DIGIT_COUNT {
                    if !(digit_skip_enabled && skip_digit[digit]) {
                        let mut offset_acc: $offset_type = 0;
                        for count in offsets[digit].iter_mut() {
                            let offset = offset_acc;
                            offset_acc += *count;
                            *count = offset;
                        }
                    }
                }

                // The `offsets` array now contains bucket offsets for each digit.
            }

            // Drop impl of DoubleBuffer ensures that `input` is consistent,
            // e.g. in case of panic in the key function.
            let mut buffer = DoubleBuffer::new(input);

            // This is the main sorting loop. We sort the elements by each digit of the key,
            // starting from the least-significant. After sorting by the last, most significant
            // digit, our elements are sorted.
            for digit in 0..DIGIT_COUNT {
                if !(digit_skip_enabled && skip_digit[digit]) {

                    // Copy the offsets. We need the original later for a consistency check.
                    // As we write elements into each bucket, we increment the bucket offset
                    // so that it points to the next empty slot.
                    let mut working_offsets: [$offset_type; BUCKET_COUNT] = offsets[digit];

                    for r_pos in 0..len {
                        let t: &T = unsafe {
                            // SAFETY: r_pos is in (0..len)
                            buffer.read(r_pos)
                        };

                        let key = key_fn(t);
                        let bucket = extract_digit(key, digit);
                        let offset = &mut working_offsets[bucket];

                        unsafe {
                            // Make sure the offset is in bounds. An unreliable key function, which
                            // returns different keys for the same element when called repeatedly,
                            // can cause offsets to go out of bounds.
                            let w_pos = usize::min(*offset as usize, len - 1);

                            // SAFETY: w_pos is in (0..len), `len` is non-zero here.
                            buffer.write(w_pos, t);
                        }

                        // Increment the offset of the bucket. Use wrapping add in case the
                        // key function is unreliable and the bucket overflowed.
                        *offset = offset.wrapping_add(1);
                    }

                    // Check that each bucket had the same number of insertions as we expected.
                    // If this is not true, then the key function is unreliable and the write buffer
                    // is not consistent: some elements were overwritten, some were not written to
                    // and contain garbage.
                    //
                    // If the key function is unreliable, but the sizes of buckets ended up being
                    // the same, it would not get detected. This is sound, the only consequence is
                    // that the elements won't be sorted right.
                    {
                        // The `working_offsets` array now contains the end offset of each bucket.
                        // If the bucket is full, the working offset is now equal to the original
                        // offset of the next bucket. The working offset of the last bucket should
                        // be equal to the number of elements.
                        let bucket_sizes_match =
                            working_offsets[0..BUCKET_COUNT-1] == offsets[digit][1..BUCKET_COUNT] &&
                            working_offsets[BUCKET_COUNT-1] == len as $offset_type;

                        if !bucket_sizes_match {
                            // The bucket sizes do not match expected sizes, the key function is
                            // unreliable (programming mistake).
                            //
                            // Drop impl of the double buffer will make sure that the input slice is
                            // consistent. This would happen automatically, but I'm making it
                            // explicit for clarity.
                            drop(buffer);
                            panic!("The key function is not reliable: when called repeatedly, \
                                it returned different keys for the same element.")
                        }
                    }

                    unsafe {
                        // SAFETY: we just ensured that the write buffer is consistent.
                        buffer.swap();
                    }
                }
            }

            // In case the result ended up in the temporary buffer, the Drop impl will copy it over
            // to the input slice. This would happen automatically, but I'm making it explicit for
            // clarity.
            drop(buffer);
        }
    }
}
// Implements `RadixKey` for each of the whitespace-separated types passed in.
// Every impl gets the `usize`-offset sorting function, and on targets with wide
// pointers also the variant that uses `u32` offsets.
macro_rules! radix_key_impl {
    ($($key:ty)*) => {
        $(
            impl RadixKey for $key {
                #[cfg(any(target_pointer_width = "64", target_pointer_width = "128"))]
                sort_impl!(radix_sort_u32, $key, u32);

                sort_impl!(radix_sort_usize, $key, usize);
            }
        )*
    };
}

radix_key_impl! { u8 u16 u32 u64 u128 }
/// Double buffer. Allocates a temporary memory the size of the slice, so that
/// elements can be freely reordered from buffer to buffer.
///
/// # Consistency
///
/// For the purposes of this struct, buffer in a consistent state contains a
/// permutation of elements from the original slice. In other words, elements
/// can be reordered, but not duplicated or lost.
///
/// `read_buf` is always consistent. Before calling `swap`, the caller must
/// ensure that `write_buf` is also consistent.
///
/// # Drop behavior
///
/// Drop impl ensures that the slice this buffer was constructed with is left in
/// a consistent state. If the input slice ended up as `write_buf`, the
/// temporary memory (which is now `read_buf` and therefore consistent) is
/// copied into the slice and the buffers are swapped.
struct DoubleBuffer<'a, T> {
    /// The slice being sorted. It is held to keep the borrow alive for as long
    /// as the raw pointers are in use and to prevent access from the outside.
    /// It is only touched again in the Drop impl.
    slice: &'a mut [T],
    /// Owner of the temporary memory. Its length stays zero, so dropping it
    /// only releases the allocation and never drops any element.
    _aux: Vec<T>,
    /// Read buffer is read-only and always consistent.
    read_buf: *const T,
    /// Write buffer is write-only. Elements can be present multiple times or
    /// not at all. The caller must ensure that by the time of calling `swap`
    /// it has a complete set of elements and each element is present exactly
    /// once.
    write_buf: *mut T,
}

impl<'a, T> DoubleBuffer<'a, T> {
    /// Creates a double buffer over `slice`. The slice starts out as the read
    /// buffer, freshly allocated (uninitialized) memory of the same length is
    /// the write buffer.
    fn new(slice: &'a mut [T]) -> DoubleBuffer<'a, T> {
        let mut aux = Vec::with_capacity(slice.len());
        // The slice becomes the write buffer after the first `swap`, so its
        // pointer must be derived from the exclusive borrow. A pointer obtained
        // through `as_ptr()` must never be written through.
        let read_buf = slice.as_mut_ptr() as *const T;
        let write_buf = aux.as_mut_ptr();
        DoubleBuffer {
            // Hold on to the &mut slice to make sure it outlives the pointer
            // and to prevent writes from the outside
            slice,
            // Hold on to the Vec to make sure it outlives the pointer
            _aux: aux,
            read_buf,
            write_buf,
        }
    }

    /// Returns a ref to an element from the read buffer.
    ///
    /// # Safety
    ///
    /// Caller must ensure that `index` is in (0..len).
    #[inline(always)]
    unsafe fn read(&self, index: usize) -> &T {
        &*self.read_buf.add(index)
    }

    /// Copies the referenced element into the write buffer. The previous
    /// content of the slot is overwritten without being dropped.
    ///
    /// # Safety
    ///
    /// Caller must ensure that `index` is in (0..len).
    #[inline(always)]
    unsafe fn write(&self, index: usize, t: &T) {
        self.write_buf
            .add(index)
            .copy_from_nonoverlapping(t as *const T, 1);
    }

    /// Swaps the read and write buffers.
    ///
    /// # Safety
    ///
    /// Caller must ensure that the write buffer is consistent before calling
    /// this function.
    unsafe fn swap(&mut self) {
        // The cast is ok, we have an exclusive access to both buffers
        // (&mut [T] and Vec<T>) and both pointers were derived from mutable
        // access. The user guarantees that the write buffer is consistent and
        // therefore it's safe to read from it and use it as a read buffer.
        let temp = self.write_buf as *const T;
        self.write_buf = self.read_buf as *mut T;
        self.read_buf = temp;
    }
}

impl<'a, T> Drop for DoubleBuffer<'a, T> {
    /// Leaves the input slice in a consistent state, see the "Drop behavior"
    /// section of the struct documentation.
    fn drop(&mut self) {
        let len = self.slice.len();
        // Re-derive the pointer from the exclusive borrow. The copy below goes
        // through this fresh pointer, the old pointers into the slice are only
        // compared by address from here on.
        let input_ptr = self.slice.as_mut_ptr();
        let input_slice_is_write = self.write_buf == input_ptr;
        if input_slice_is_write {
            // Input slice is the write buffer, copy the consistent state from the read buffer
            unsafe {
                // SAFETY: `read_buf` is the temporary memory, which is always
                // consistent here, does not overlap the slice and holds `len`
                // elements.
                input_ptr.copy_from_nonoverlapping(self.read_buf, len);
                self.swap();
            }
        }
    }
}