scuffle_bytes_util/
bit_write.rs

1use std::io;
2
/// A writer that packs individual bits, most significant bit first, into bytes
/// and forwards every completed byte to an underlying [`io::Write`].
///
/// Bits that do not yet form a whole byte are held in an internal buffer; call
/// [`BitWriter::align`] or [`BitWriter::finish`] to zero-pad and emit them.
#[derive(Debug)]
#[must_use]
pub struct BitWriter<W> {
    /// Number of bits already stored in `current_byte`; always in `0..=7`.
    bit_pos: u8,
    /// Partially filled byte, filled from the MSB downwards. Every bit at or
    /// after `bit_pos` is kept zero, so the byte doubles as zero padding.
    current_byte: u8,
    writer: W,
}

impl<W: Default> Default for BitWriter<W> {
    fn default() -> Self {
        Self::new(W::default())
    }
}

impl<W: io::Write> BitWriter<W> {
    /// Writes a single bit to the stream.
    ///
    /// When this bit completes a byte, that byte is written to the underlying
    /// writer.
    ///
    /// # Errors
    ///
    /// Returns the underlying writer's error if emitting the completed byte
    /// fails. In that case the bit is *not* recorded and the state is left
    /// unchanged, so the call can simply be retried.
    pub fn write_bit(&mut self, bit: bool) -> io::Result<()> {
        let mask: u8 = 1 << (7 - self.bit_pos);
        let byte = if bit {
            self.current_byte | mask
        } else {
            self.current_byte & !mask
        };

        if self.bit_pos == 7 {
            // Commit only after the byte reached the writer: a failed
            // single-byte write wrote nothing, so we must not advance either
            // (previously `bit_pos` was left at 8 and the next call underflowed).
            self.writer.write_all(&[byte])?;
            self.current_byte = 0;
            self.bit_pos = 0;
        } else {
            self.current_byte = byte;
            self.bit_pos += 1;
        }

        Ok(())
    }

    /// Writes the low `count` bits of `bits`, most significant bit first.
    ///
    /// `count` values above 64 are clamped to 64. A `count` of 0 writes
    /// nothing (and then `bits` must be 0).
    ///
    /// # Errors
    ///
    /// - [`io::ErrorKind::InvalidData`] if `bits` does not fit in `count`
    ///   bits; nothing is written in that case.
    /// - Any error from the underlying writer. Bits emitted before the
    ///   failure remain written.
    pub fn write_bits(&mut self, bits: u64, count: u8) -> io::Result<()> {
        let count = count.min(64);

        // `bits >> 64` would overflow the shift, so a full-width write is
        // always in range.
        if count < 64 && bits >> count != 0 {
            return Err(io::Error::new(io::ErrorKind::InvalidData, "bits too large to write"));
        }

        for shift in (0..count).rev() {
            self.write_bit((bits >> shift) & 1 == 1)?;
        }

        Ok(())
    }

    /// Aligns the stream to a byte boundary (zero-padding any buffered bits)
    /// and returns the underlying writer.
    ///
    /// This does not call [`io::Write::flush`] on the underlying writer.
    ///
    /// # Errors
    ///
    /// Returns the underlying writer's error if the padded byte cannot be
    /// written.
    pub fn finish(mut self) -> io::Result<W> {
        self.align()?;
        Ok(self.writer)
    }

    /// Aligns the writer to the next byte boundary.
    ///
    /// If bits are buffered, the partial byte is zero-padded and written out;
    /// when already aligned this does nothing.
    ///
    /// # Errors
    ///
    /// Returns the underlying writer's error; the buffered bits are kept so
    /// the call can be retried.
    pub fn align(&mut self) -> io::Result<()> {
        if !self.is_aligned() {
            // Bits past `bit_pos` are always zero, so the buffered byte is
            // already correctly padded and can be emitted as-is.
            self.writer.write_all(&[self.current_byte])?;
            self.current_byte = 0;
            self.bit_pos = 0;
        }

        Ok(())
    }

    /// Writes one whole byte while the stream is *not* byte aligned.
    ///
    /// The high `8 - bit_pos` bits of `byte` complete the buffered byte, which
    /// is emitted; the low `bit_pos` bits become the new buffer, so `bit_pos`
    /// is unchanged. State is only updated after the emit succeeds.
    fn write_unaligned_byte(&mut self, byte: u8) -> io::Result<()> {
        debug_assert!(!self.is_aligned());
        let offset = self.bit_pos;
        self.writer.write_all(&[self.current_byte | (byte >> offset)])?;
        self.current_byte = byte << (8 - offset);
        Ok(())
    }
}

impl<W> BitWriter<W> {
    /// Creates a new `BitWriter` that writes into `writer`, starting on a
    /// byte boundary.
    pub const fn new(writer: W) -> Self {
        Self {
            bit_pos: 0,
            current_byte: 0,
            writer,
        }
    }

    /// Returns the current bit position within the buffered byte (0-7).
    #[inline(always)]
    #[must_use]
    pub const fn bit_pos(&self) -> u8 {
        self.bit_pos
    }

    /// Checks if the writer is aligned to the byte boundary, i.e. no bits are
    /// buffered.
    #[inline(always)]
    #[must_use]
    pub const fn is_aligned(&self) -> bool {
        self.bit_pos == 0
    }

    /// Returns a reference to the underlying writer.
    ///
    /// Buffered bits that do not yet form a whole byte are not visible here.
    #[inline(always)]
    #[must_use]
    pub const fn get_ref(&self) -> &W {
        &self.writer
    }
}

impl<W: io::Write> io::Write for BitWriter<W> {
    /// Writes whole bytes into the bit stream.
    ///
    /// When aligned, this delegates directly to the underlying writer.
    /// Otherwise each byte is shifted across the current bit offset. If the
    /// underlying writer fails after at least one byte was accepted, the count
    /// accepted so far is returned, per the [`io::Write::write`] contract. The
    /// error itself is only reported when nothing was written.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        if self.is_aligned() {
            return self.writer.write(buf);
        }

        for (accepted, &byte) in buf.iter().enumerate() {
            if let Err(err) = self.write_unaligned_byte(byte) {
                return if accepted == 0 { Err(err) } else { Ok(accepted) };
            }
        }

        Ok(buf.len())
    }

    /// Flushes the underlying writer.
    ///
    /// Buffered bits that do not yet form a whole byte are *not* written; use
    /// `align` for that.
    fn flush(&mut self) -> io::Result<()> {
        self.writer.flush()
    }
}
123
#[cfg(test)]
#[cfg_attr(all(test, coverage_nightly), coverage(off))]
mod tests {
    use io::Write;

    use super::*;

    #[test]
    fn test_bit_writer() {
        let mut bit_writer = BitWriter::<Vec<u8>>::default();

        bit_writer.write_bits(0b11111111, 8).unwrap();
        assert_eq!(bit_writer.bit_pos(), 0);
        assert!(bit_writer.is_aligned());

        bit_writer.write_bits(0b0000, 4).unwrap();
        assert_eq!(bit_writer.bit_pos(), 4);
        assert!(!bit_writer.is_aligned());
        bit_writer.align().unwrap();
        assert_eq!(bit_writer.bit_pos(), 0);
        assert!(bit_writer.is_aligned());

        bit_writer.write_bits(0b1010, 4).unwrap();
        assert_eq!(bit_writer.bit_pos(), 4);
        assert!(!bit_writer.is_aligned());

        bit_writer.write_bits(0b101010101010, 12).unwrap();
        assert_eq!(bit_writer.bit_pos(), 0);
        assert!(bit_writer.is_aligned());

        bit_writer.write_bit(true).unwrap();
        assert_eq!(bit_writer.bit_pos(), 1);
        assert!(!bit_writer.is_aligned());

        // 0b10000 needs 5 bits, so it does not fit in 4.
        let err = bit_writer.write_bits(0b10000, 4).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert_eq!(err.to_string(), "bits too large to write");

        assert_eq!(
            bit_writer.finish().unwrap(),
            vec![0b11111111, 0b00000000, 0b10101010, 0b10101010, 0b10000000]
        );
    }

    #[test]
    fn test_flush_buffer() {
        let mut bit_writer = BitWriter::<Vec<u8>>::default();

        bit_writer.write_bits(0b11111111, 8).unwrap();
        assert_eq!(bit_writer.bit_pos(), 0);
        assert!(bit_writer.is_aligned());
        assert_eq!(bit_writer.get_ref(), &[0b11111111], "underlying writer should have one byte");

        bit_writer.write_bits(0b0000, 4).unwrap();
        assert_eq!(bit_writer.bit_pos(), 4);
        assert!(!bit_writer.is_aligned());
        assert_eq!(bit_writer.get_ref(), &[0b11111111], "underlying writer should have one byte");

        bit_writer.write_bits(0b1010, 4).unwrap();
        assert_eq!(bit_writer.bit_pos(), 0);
        assert!(bit_writer.is_aligned());
        assert_eq!(
            bit_writer.get_ref(),
            &[0b11111111, 0b00001010],
            "underlying writer should have two bytes"
        );
    }

    #[test]
    fn test_io_write() {
        let mut inner = Vec::new();
        let mut bit_writer = BitWriter::new(&mut inner);

        bit_writer.write_bits(0b11111111, 8).unwrap();
        assert_eq!(bit_writer.bit_pos(), 0);
        assert!(bit_writer.is_aligned());
        // A completed byte is handed to the underlying writer immediately
        assert_eq!(bit_writer.get_ref().as_slice(), &[255]);

        bit_writer.write_all(&[1, 2, 3]).unwrap();
        assert_eq!(bit_writer.bit_pos(), 0);
        assert!(bit_writer.is_aligned());
        // since we did an io::Write on an aligned bit_writer
        // we should have written directly to the underlying
        // writer
        assert_eq!(bit_writer.get_ref().as_slice(), &[255, 1, 2, 3]);

        bit_writer.write_bit(true).unwrap();

        bit_writer.write_bits(0b1010, 4).unwrap();

        bit_writer
            .write_all(&[0b11111111, 0b00000000, 0b11111111, 0b00000000])
            .unwrap();

        // The writer was not aligned (5 bits buffered), so every byte is shifted
        // across the byte boundary; the last 5 bits stay buffered until `finish`.
        assert_eq!(
            bit_writer.get_ref().as_slice(),
            &[255, 1, 2, 3, 0b11010111, 0b11111000, 0b00000111, 0b11111000]
        );

        bit_writer.finish().unwrap();

        assert_eq!(
            inner,
            vec![255, 1, 2, 3, 0b11010111, 0b11111000, 0b00000111, 0b11111000, 0b00000000]
        );
    }

    #[test]
    fn test_flush() {
        let mut inner = Vec::new();
        let mut bit_writer = BitWriter::new(&mut inner);

        bit_writer.write_bits(0b10100000, 8).unwrap();

        bit_writer.flush().unwrap();

        assert_eq!(bit_writer.get_ref().as_slice(), &[0b10100000]);
        assert_eq!(bit_writer.bit_pos(), 0);
        assert!(bit_writer.is_aligned());
    }

    #[test]
    fn test_write_bits_full_width() {
        let mut bit_writer = BitWriter::<Vec<u8>>::default();

        // 64 bits is the maximum width; larger counts are clamped to 64.
        bit_writer.write_bits(u64::MAX, 64).unwrap();
        bit_writer.write_bits(u64::MAX, 200).unwrap();
        assert!(bit_writer.is_aligned());

        // A zero-width write is a no-op, but only the value 0 fits in 0 bits.
        bit_writer.write_bits(0, 0).unwrap();
        let err = bit_writer.write_bits(1, 0).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);

        assert_eq!(bit_writer.finish().unwrap(), vec![0xFF; 16]);
    }
}
246}