1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
use super::super::blocks::literals_section::{LiteralsSection, LiteralsSectionType};
use super::bit_reader_reverse::{BitReaderReversed, GetBitsError};
use super::scratch::HuffmanScratch;
use crate::huff0::{HuffmanDecoder, HuffmanDecoderError, HuffmanTableError};
use alloc::vec::Vec;

/// Errors that can occur while decoding the literals section of a block.
#[derive(Debug, derive_more::Display, derive_more::From)]
#[cfg_attr(feature = "std", derive(derive_more::Error))]
#[non_exhaustive]
pub enum DecompressLiteralsError {
    /// A Compressed/Treeless literals section did not carry a compressed size.
    #[display(
        fmt = "compressed size was none even though it must be set to something for compressed literals"
    )]
    MissingCompressedSize,
    /// A Compressed/Treeless literals section did not carry a stream count.
    #[display(
        fmt = "num_streams was none even though it must be set to something (1 or 4) for compressed literals"
    )]
    MissingNumStreams,
    /// Reading bits from a Huffman bitstream failed.
    #[display(fmt = "{_0:?}")]
    #[from]
    GetBitsError(GetBitsError),
    /// Building the Huffman table from its description failed.
    #[display(fmt = "{_0:?}")]
    #[from]
    HuffmanTableError(HuffmanTableError),
    /// The Huffman decoder failed while decoding symbols.
    #[display(fmt = "{_0:?}")]
    #[from]
    HuffmanDecoderError(HuffmanDecoderError),
    /// A Treeless section asked to reuse a Huffman table that was never built.
    #[display(fmt = "Tried to reuse huffman table but it was never initialized")]
    UninitializedHuffmanTable,
    /// Fewer than the 6 bytes of the 4-stream jump table were available.
    #[display(fmt = "Need 6 bytes to decode jump header, got {got} bytes")]
    MissingBytesForJumpHeader { got: usize },
    /// The input is shorter than the number of bytes the section requires.
    #[display(fmt = "Need at least {needed} bytes to decode literals. Have: {got} bytes")]
    MissingBytesForLiterals { got: usize, needed: usize },
    /// A Huffman stream ended with a whole byte (or more) of zero padding, so the
    /// terminating 1-bit was not where it should be.
    #[display(
        fmt = "Padding at the end of the literals section was more than a byte long: {skipped_bits} bits. Probably caused by data corruption"
    )]
    ExtraPadding { skipped_bits: i32 },
    /// After decoding a Huffman stream, the bit reader did not stop at the expected position.
    #[display(fmt = "Bitstream was read till: {read_til}, should have been: {expected}")]
    BitstreamReadMismatch { read_til: isize, expected: isize },
    /// The number of decoded literals differs from the section's regenerated size.
    #[display(
        fmt = "Did not decode the expected number of literals: {decoded}, Should have been: {expected}"
    )]
    DecodedLiteralCountMismatch { decoded: usize, expected: usize },
}

pub fn decode_literals(
    section: &LiteralsSection,
    scratch: &mut HuffmanScratch,
    source: &[u8],
    target: &mut Vec<u8>,
) -> Result<u32, DecompressLiteralsError> {
    match section.ls_type {
        LiteralsSectionType::Raw => {
            target.extend(&source[0..section.regenerated_size as usize]);
            Ok(section.regenerated_size)
        }
        LiteralsSectionType::RLE => {
            target.resize(target.len() + section.regenerated_size as usize, source[0]);
            Ok(1)
        }
        LiteralsSectionType::Compressed | LiteralsSectionType::Treeless => {
            let bytes_read = decompress_literals(section, scratch, source, target)?;

            //return sum of used bytes
            Ok(bytes_read)
        }
    }
}

/// Decode a Huffman-coded (Compressed or Treeless) literals section into `target`.
///
/// How the Huffman table is obtained:
/// - For `Compressed` sections, the table description at the start of the payload is parsed
///   into `scratch.table`.
/// - `Treeless` sections reuse the table already held by `scratch`.
///
/// The remaining payload is decoded either as a single stream, or as four streams preceded by
/// a 6-byte jump table. Decoded literals are appended to `target`.
///
/// # Returns
///
/// The number of bytes of `source` consumed.
///
/// # Errors
///
/// Returns an error when the header fields are missing, the input is truncated, a stream is
/// malformed, or the number of decoded literals differs from `section.regenerated_size`.
///
/// # Panics
///
/// Panics if `num_streams` is neither 4 nor 1.
fn decompress_literals(
    section: &LiteralsSection,
    scratch: &mut HuffmanScratch,
    source: &[u8],
    target: &mut Vec<u8>,
) -> Result<u32, DecompressLiteralsError> {
    use DecompressLiteralsError as err;

    let compressed_size = section.compressed_size.ok_or(err::MissingCompressedSize)? as usize;
    let num_streams = section.num_streams.ok_or(err::MissingNumStreams)?;

    // Only count literals produced by this call: `target` is appended to, like in the
    // Raw/RLE paths, and may already contain data.
    let start_len = target.len();
    target.reserve(section.regenerated_size as usize);

    // Reject truncated input instead of panicking on an out-of-bounds slice.
    let source = source
        .get(..compressed_size)
        .ok_or(err::MissingBytesForLiterals {
            got: source.len(),
            needed: compressed_size,
        })?;
    let mut bytes_read = 0;

    match section.ls_type {
        LiteralsSectionType::Compressed => {
            //read Huffman tree description
            bytes_read += scratch.table.build_decoder(source)?;
            vprintln!("Built huffman table using {} bytes", bytes_read);
        }
        LiteralsSectionType::Treeless => {
            if scratch.table.max_num_bits == 0 {
                return Err(err::UninitializedHuffmanTable);
            }
        }
        _ => { /* nothing to do, huffman tree has been provided by previous block */ }
    }

    let source = &source[bytes_read as usize..];

    if num_streams == 4 {
        //build jumptable
        if source.len() < 6 {
            return Err(err::MissingBytesForJumpHeader { got: source.len() });
        }
        let jump1 = source[0] as usize + ((source[1] as usize) << 8);
        let jump2 = jump1 + source[2] as usize + ((source[3] as usize) << 8);
        let jump3 = jump2 + source[4] as usize + ((source[5] as usize) << 8);
        bytes_read += 6;
        let source = &source[6..];

        // jump1 <= jump2 <= jump3, so this single check makes all slices below in-bounds.
        if source.len() < jump3 {
            return Err(err::MissingBytesForLiterals {
                got: source.len(),
                needed: jump3,
            });
        }

        //decode 4 streams
        let stream1 = &source[..jump1];
        let stream2 = &source[jump1..jump2];
        let stream3 = &source[jump2..jump3];
        let stream4 = &source[jump3..];

        for stream in &[stream1, stream2, stream3, stream4] {
            decode_huffman_stream(scratch, stream, target)?;
        }

        bytes_read += source.len() as u32;
    } else {
        //just decode the one stream
        assert!(num_streams == 1);
        decode_huffman_stream(scratch, source, target)?;
        bytes_read += source.len() as u32;
    }

    let decoded = target.len() - start_len;
    if decoded != section.regenerated_size as usize {
        return Err(err::DecodedLiteralCountMismatch {
            decoded,
            expected: section.regenerated_size as usize,
        });
    }

    Ok(bytes_read)
}

/// Decode one reverse-read Huffman bitstream with `scratch.table`, appending every decoded
/// symbol to `target`.
///
/// The trailing zero padding and the first 1-bit are skipped before decoding starts. After
/// decoding, the stream must end exactly where the decoder expects. Both the single-stream
/// and four-stream paths use this helper, so they apply the same validation.
fn decode_huffman_stream(
    scratch: &HuffmanScratch,
    stream: &[u8],
    target: &mut Vec<u8>,
) -> Result<(), DecompressLiteralsError> {
    // Decoding stops once the reader is `max_num_bits` bits past the start of the
    // stream; a well-formed stream ends exactly there.
    let end_position = -(scratch.table.max_num_bits as isize);
    let mut decoder = HuffmanDecoder::new(&scratch.table);
    let mut br = BitReaderReversed::new(stream);

    //skip the 0 padding at the end of the last byte of the bit stream and throw away the first 1 found
    let mut skipped_bits = 0;
    loop {
        let val = br.get_bits(1)?;
        skipped_bits += 1;
        if val == 1 || skipped_bits > 8 {
            break;
        }
    }
    if skipped_bits > 8 {
        //if more than 7 bits are 0, this is not the correct end of the bitstream. Either a bug or corrupted data
        return Err(DecompressLiteralsError::ExtraPadding { skipped_bits });
    }

    decoder.init_state(&mut br)?;
    while br.bits_remaining() > end_position {
        target.push(decoder.decode_symbol());
        decoder.next_state(&mut br)?;
    }
    if br.bits_remaining() != end_position {
        return Err(DecompressLiteralsError::BitstreamReadMismatch {
            read_til: br.bits_remaining(),
            expected: end_position,
        });
    }
    Ok(())
}