use crate::decoding::bit_reader::BitReader;
use crate::decoding::bit_reader_reverse::{BitReaderReversed, GetBitsError};
use alloc::vec::Vec;
/// An FSE decoding table together with the per-symbol data used to build it.
pub struct FSETable {
    /// One entry per decoder state; resized to `1 << accuracy_log` entries when the
    /// table is built.
    pub decode: Vec<Entry>,
    /// log2 of the table size. `0` means the table has not been built yet.
    pub accuracy_log: u8,
    /// Normalized probability of each symbol, indexed by symbol value. `-1` marks a
    /// "less than one" probability that occupies a single state.
    pub symbol_probabilities: Vec<i32>,
    /// Scratch space: how many states have been assigned to each symbol so far while
    /// computing baselines.
    symbol_counter: Vec<u32>,
}
impl Default for FSETable {
fn default() -> Self {
Self::new()
}
}
/// Errors that can occur while building an [`FSETable`].
#[derive(Debug, derive_more::Display, derive_more::From)]
#[cfg_attr(feature = "std", derive(derive_more::Error))]
#[non_exhaustive]
pub enum FSETableError {
/// The accuracy log (log2 of the table size) was zero.
#[display(fmt = "Acclog must be at least 1")]
AccLogIsZero,
/// The accuracy log is larger than the maximum allowed by the caller.
#[display(fmt = "Found FSE acc_log: {got} bigger than allowed maximum in this case: {max}")]
AccLogTooBig { got: u8, max: u8 },
/// Wraps a [`GetBitsError`] raised while reading the table description.
#[display(fmt = "{_0:?}")]
#[from]
GetBitsError(GetBitsError),
/// The probabilities did not sum to the table size `1 << accuracy_log`.
#[display(
fmt = "The counter ({got}) exceeded the expected sum: {expected_sum}. This means an error or corrupted data \n {symbol_probabilities:?}"
)]
ProbabilityCounterMismatch {
/// Sum of the probabilities that were read (a `-1` counts as 1).
got: u32,
/// The required sum, `1 << accuracy_log`.
expected_sum: u32,
/// The probabilities that were read, for diagnostics.
symbol_probabilities: Vec<i32>,
},
/// The distribution lists more symbols than fit in a `u8` symbol value.
#[display(fmt = "There are too many symbols in this distribution: {got}. Max: 256")]
TooManySymbols { got: usize },
}
/// Stateful decoder that walks the states of a borrowed [`FSETable`].
pub struct FSEDecoder<'table> {
    /// Table entry for the current state; its `symbol` is the one to emit next.
    pub state: Entry,
    /// The table whose `decode` entries the decoder moves between.
    table: &'table FSETable,
}
/// Errors returned by [`FSEDecoder`] state transitions.
#[derive(Debug, derive_more::Display, derive_more::From)]
#[cfg_attr(feature = "std", derive(derive_more::Error))]
#[non_exhaustive]
pub enum FSEDecoderError {
/// Wraps a [`GetBitsError`] raised while reading state bits.
#[display(fmt = "{_0:?}")]
#[from]
GetBitsError(GetBitsError),
/// The decoder's table is not usable, e.g. it has not been built yet.
#[display(fmt = "Tried to use an uninitialized table!")]
TableIsUninitialized,
}
/// One state of an FSE decoding table.
///
/// Being in this state emits `symbol`; the next state is `base_line` plus the value
/// of the next `num_bits` bits read from the stream.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct Entry {
    /// Lowest next-state index reachable from this state.
    pub base_line: u32,
    /// Number of bits to read to select the next state.
    pub num_bits: u8,
    /// Symbol emitted by this state.
    pub symbol: u8,
}
/// Added to the 4-bit accuracy-log field at the start of an FSE table description
/// to obtain the actual accuracy log (so the smallest encodable log is 5).
const ACC_LOG_OFFSET: u8 = 5;
/// Returns the 1-based position of the most significant set bit of `x`, which is
/// also the number of bits needed to represent `x` (e.g. `1 -> 1`, `4 -> 3`).
///
/// # Panics
///
/// Panics if `x` is zero.
fn highest_bit_set(x: u32) -> u32 {
    assert!(x > 0);
    let unused_high_bits = x.leading_zeros();
    u32::BITS - unused_high_bits
}
impl<'t> FSEDecoder<'t> {
    /// Creates a decoder that reads states from `table`.
    ///
    /// The initial state is the table's first entry, or an all-zero entry when the
    /// table is empty; call [`FSEDecoder::init_state`] before decoding.
    pub fn new(table: &'t FSETable) -> FSEDecoder<'_> {
        FSEDecoder {
            state: table.decode.first().copied().unwrap_or(Entry {
                base_line: 0,
                num_bits: 0,
                symbol: 0,
            }),
            table,
        }
    }

    /// Returns the symbol of the current state without advancing the state.
    pub fn decode_symbol(&self) -> u8 {
        self.state.symbol
    }

    /// Reads the initial state (`accuracy_log` bits) from `bits`.
    ///
    /// # Errors
    ///
    /// - [`FSEDecoderError::TableIsUninitialized`] if the table has not been built
    ///   (`accuracy_log == 0`), or if the state read is not an index into the
    ///   table's `decode` entries (table inconsistent with its accuracy log).
    /// - [`FSEDecoderError::GetBitsError`] if reading from `bits` fails.
    pub fn init_state(&mut self, bits: &mut BitReaderReversed<'_>) -> Result<(), FSEDecoderError> {
        if self.table.accuracy_log == 0 {
            return Err(FSEDecoderError::TableIsUninitialized);
        }
        let initial_state = bits.get_bits(self.table.accuracy_log)?;
        // Checked lookup: a table whose `decode` is shorter than `1 << accuracy_log`
        // must produce an error, not an index panic.
        self.state = self
            .table
            .decode
            .get(initial_state as usize)
            .copied()
            .ok_or(FSEDecoderError::TableIsUninitialized)?;
        Ok(())
    }

    /// Moves to the next state: reads `state.num_bits` bits from `bits` and adds
    /// them to `state.base_line` to get the next state index.
    ///
    /// # Errors
    ///
    /// - [`FSEDecoderError::TableIsUninitialized`] if the next state is not an index
    ///   into the table's `decode` entries (e.g. the table was never built).
    /// - [`FSEDecoderError::GetBitsError`] if reading from `bits` fails.
    pub fn update_state(
        &mut self,
        bits: &mut BitReaderReversed<'_>,
    ) -> Result<(), FSEDecoderError> {
        let add = bits.get_bits(self.state.num_bits)?;
        let new_state = self.state.base_line + add as u32;
        // For a correctly built table this never fails; for an empty table the
        // unchecked index used to panic.
        self.state = self
            .table
            .decode
            .get(new_state as usize)
            .copied()
            .ok_or(FSEDecoderError::TableIsUninitialized)?;
        Ok(())
    }
}
impl FSETable {
    /// Creates an empty table that has not been built (`accuracy_log == 0`).
    ///
    /// The per-symbol buffers are pre-sized for 256 symbols, the most a
    /// distribution may contain.
    pub fn new() -> FSETable {
        FSETable {
            symbol_probabilities: Vec::with_capacity(256),
            symbol_counter: Vec::with_capacity(256),
            decode: Vec::new(),
            accuracy_log: 0,
        }
    }

    /// Makes `self` a copy of `other`, reusing `self`'s existing allocations.
    pub fn reinit_from(&mut self, other: &Self) {
        self.reset();
        self.symbol_counter.extend_from_slice(&other.symbol_counter);
        self.symbol_probabilities
            .extend_from_slice(&other.symbol_probabilities);
        self.decode.extend_from_slice(&other.decode);
        self.accuracy_log = other.accuracy_log;
    }

    /// Clears all contents and marks the table as not built (`accuracy_log == 0`).
    /// Allocated capacity is kept.
    pub fn reset(&mut self) {
        self.symbol_counter.clear();
        self.symbol_probabilities.clear();
        self.decode.clear();
        self.accuracy_log = 0;
    }

    /// Reads an FSE table description from the start of `source` and builds the
    /// decoding table from it.
    ///
    /// Returns the number of bytes of `source` the description occupied; a partially
    /// used final byte counts as read.
    ///
    /// # Errors
    ///
    /// - [`FSETableError::AccLogTooBig`] if the described accuracy log exceeds `max_log`.
    /// - [`FSETableError::GetBitsError`] if reading from `source` fails (presumably
    ///   because it ends before the description does).
    /// - [`FSETableError::TooManySymbols`] if the description lists more than 256 symbols.
    /// - [`FSETableError::ProbabilityCounterMismatch`] if the probabilities do not
    ///   sum to the table size.
    ///
    /// On error `accuracy_log` is reset to 0. A decoder then reports the table as
    /// uninitialized instead of pairing the rejected accuracy log with whatever
    /// decoding table was built previously.
    pub fn build_decoder(&mut self, source: &[u8], max_log: u8) -> Result<usize, FSETableError> {
        self.accuracy_log = 0;
        match self.read_probabilities(source, max_log) {
            Ok(bytes_read) => {
                self.build_decoding_table();
                Ok(bytes_read)
            }
            Err(err) => {
                // read_probabilities stores the accuracy log before validating the
                // rest of the description; undo that.
                self.accuracy_log = 0;
                Err(err)
            }
        }
    }

    /// Builds the decoding table from an already-known distribution.
    ///
    /// `probs[s]` is the normalized probability of symbol `s`. A value of `-1` marks a
    /// "less than one" probability, which occupies a single state.
    ///
    /// # Errors
    ///
    /// - [`FSETableError::AccLogIsZero`] if `acc_log` is 0.
    /// - [`FSETableError::AccLogTooBig`] if `acc_log` is above 31, because the table
    ///   size must fit in a `u32`.
    /// - [`FSETableError::TooManySymbols`] if `probs` has more than 256 entries.
    /// - [`FSETableError::ProbabilityCounterMismatch`] if the probabilities (with
    ///   `-1` counted as 1) do not sum to `1 << acc_log`.
    ///
    /// On error the table is left unchanged.
    pub fn build_from_probabilities(
        &mut self,
        acc_log: u8,
        probs: &[i32],
    ) -> Result<(), FSETableError> {
        if acc_log == 0 {
            return Err(FSETableError::AccLogIsZero);
        }
        // State counts are handed to `calc_baseline_and_numbits` as `u32`.
        let expected_sum = 1u32
            .checked_shl(u32::from(acc_log))
            .ok_or(FSETableError::AccLogTooBig {
                got: acc_log,
                max: 31,
            })?;
        // Symbols are stored as `u8` in the decoding table.
        if probs.len() > 256 {
            return Err(FSETableError::TooManySymbols { got: probs.len() });
        }
        // build_decoding_table relies on the distribution exactly filling the table;
        // apply the same check read_probabilities applies to decoded headers.
        let got = probs.iter().fold(0u32, |sum, &p| {
            let states = match p {
                -1 => 1,
                p if p > 0 => p as u32,
                _ => 0,
            };
            sum.saturating_add(states)
        });
        if got != expected_sum {
            return Err(FSETableError::ProbabilityCounterMismatch {
                got,
                expected_sum,
                symbol_probabilities: probs.to_vec(),
            });
        }

        self.symbol_probabilities.clear();
        self.symbol_probabilities.extend_from_slice(probs);
        self.accuracy_log = acc_log;
        self.build_decoding_table();
        Ok(())
    }

    /// Fills `decode` (resized to `1 << accuracy_log`) from `symbol_probabilities`.
    ///
    /// Expects `symbol_probabilities` to sum (with `-1` counted as 1) to the table
    /// size and to have at most 256 entries.
    fn build_decoding_table(&mut self) {
        let table_size = 1usize << self.accuracy_log;
        self.decode.clear();
        self.decode.resize(
            table_size,
            Entry {
                base_line: 0,
                num_bits: 0,
                symbol: 0,
            },
        );

        // "Less than one" (-1) symbols each get one state, allocated downward from the
        // end of the table. They always read `accuracy_log` bits from a baseline of 0.
        let mut negative_idx = table_size;
        for (symbol, &prob) in self.symbol_probabilities.iter().enumerate() {
            if prob == -1 {
                negative_idx -= 1;
                self.decode[negative_idx] = Entry {
                    base_line: 0,
                    num_bits: self.accuracy_log,
                    symbol: symbol as u8,
                };
            }
        }

        // Spread the remaining symbols over the lower part of the table, skipping the
        // states reserved for -1 symbols.
        let mut position = 0;
        for (symbol, &prob) in self.symbol_probabilities.iter().enumerate() {
            if prob <= 0 {
                continue;
            }
            for _ in 0..prob {
                self.decode[position].symbol = symbol as u8;
                position = next_position(position, table_size);
                while position >= negative_idx {
                    position = next_position(position, table_size);
                }
            }
        }

        // Assign baselines and bit counts, numbering each symbol's states in table order.
        self.symbol_counter.clear();
        self.symbol_counter
            .resize(self.symbol_probabilities.len(), 0);
        for entry in &mut self.decode[..negative_idx] {
            let symbol = usize::from(entry.symbol);
            let prob = self.symbol_probabilities[symbol];
            let state_number = self.symbol_counter[symbol];
            let (base_line, num_bits) =
                calc_baseline_and_numbits(table_size as u32, prob as u32, state_number);
            assert!(num_bits <= self.accuracy_log);
            self.symbol_counter[symbol] += 1;
            entry.base_line = base_line;
            entry.num_bits = num_bits;
        }
    }

    /// Decodes the accuracy log and the normalized symbol probabilities from the
    /// start of `source` into `self`.
    ///
    /// Returns the number of bytes consumed, rounded up to whole bytes.
    fn read_probabilities(&mut self, source: &[u8], max_log: u8) -> Result<usize, FSETableError> {
        self.symbol_probabilities.clear();
        let mut br = BitReader::new(source);

        // The first 4 bits hold `accuracy_log - ACC_LOG_OFFSET`.
        self.accuracy_log = ACC_LOG_OFFSET + (br.get_bits(4)? as u8);
        if self.accuracy_log > max_log {
            return Err(FSETableError::AccLogTooBig {
                got: self.accuracy_log,
                max: max_log,
            });
        }
        if self.accuracy_log == 0 {
            return Err(FSETableError::AccLogIsZero);
        }

        let probability_sum: u32 = 1 << self.accuracy_log;
        let mut probability_counter: u32 = 0;
        while probability_counter < probability_sum {
            // The field width depends on how much probability is still unassigned.
            // Values below `low_threshold` are stored with one bit fewer.
            let max_remaining_value = probability_sum - probability_counter + 1;
            let bits_to_read = highest_bit_set(max_remaining_value);
            let unchecked_value = br.get_bits(bits_to_read as usize)? as u32;

            let low_threshold = ((1 << bits_to_read) - 1) - max_remaining_value;
            let mask = (1 << (bits_to_read - 1)) - 1;
            let small_value = unchecked_value & mask;

            let value = if small_value < low_threshold {
                // Short encoding: the top bit belongs to the next field.
                br.return_bits(1);
                small_value
            } else if unchecked_value > mask {
                unchecked_value - low_threshold
            } else {
                unchecked_value
            };

            let prob = (value as i32) - 1;
            self.symbol_probabilities.push(prob);
            match prob {
                // "Less than one" still occupies one state.
                -1 => probability_counter += 1,
                0 => {
                    // A zero probability is followed by 2-bit counts of further
                    // zero-probability symbols; a count of 3 means another count follows.
                    loop {
                        let skip_amount = br.get_bits(2)? as usize;
                        let new_len = self.symbol_probabilities.len() + skip_amount;
                        self.symbol_probabilities.resize(new_len, 0);
                        if skip_amount != 3 {
                            break;
                        }
                    }
                }
                p if p > 0 => probability_counter += p as u32,
                _ => unreachable!("prob is value - 1 with value >= 0, so it is never below -1"),
            }

            // Bail out as soon as the symbol limit is exceeded. Otherwise a long run of
            // zero-probability repeats keeps growing the buffer until input runs out.
            if self.symbol_probabilities.len() > 256 {
                return Err(FSETableError::TooManySymbols {
                    got: self.symbol_probabilities.len(),
                });
            }
        }

        if probability_counter != probability_sum {
            return Err(FSETableError::ProbabilityCounterMismatch {
                got: probability_counter,
                expected_sum: probability_sum,
                symbol_probabilities: self.symbol_probabilities.clone(),
            });
        }

        // Round the number of bits read up to whole bytes.
        Ok((br.bits_read() + 7) / 8)
    }
}
/// Returns the position after `p` in the symbol-spreading walk over a table of
/// `table_size` entries.
///
/// Steps by `table_size/2 + table_size/8 + 3` and wraps with `& (table_size - 1)`.
/// That mask only acts as a modulo when `table_size` is a power of two.
fn next_position(p: usize, table_size: usize) -> usize {
    let step = (table_size >> 1) + (table_size >> 3) + 3;
    let wrap_mask = table_size - 1;
    (p + step) & wrap_mask
}
/// Computes `(baseline, num_bits)` for the `state_number`-th state (0-based, in
/// table order) of a symbol owning `num_states_symbol` of the `num_states_total`
/// table states.
///
/// The table is split into as many equal slices as the symbol's state count
/// rounded up to a power of two. The first few states (one per missing slice) cover
/// two slices each and read one extra bit. They are placed after all single-slice
/// states.
fn calc_baseline_and_numbits(
    num_states_total: u32,
    num_states_symbol: u32,
    state_number: u32,
) -> (u32, u8) {
    let slices = if num_states_symbol.is_power_of_two() {
        num_states_symbol
    } else {
        1 << highest_bit_set(num_states_symbol)
    };
    let doubles = slices - num_states_symbol;
    let singles = num_states_symbol - doubles;
    let width = num_states_total / slices;
    let bits = highest_bit_set(width) - 1;

    match state_number.checked_sub(doubles) {
        // Double-width states, laid out after every single-width slice.
        None => {
            let baseline = singles * width + state_number * width * 2;
            (baseline, bits as u8 + 1)
        }
        // Single-width states, starting from the bottom of the table.
        Some(single_index) => (single_index * width, bits as u8),
    }
}