scuffle_bytes_util/
bit_write.rs1use std::io;
2
/// A bit-level writer that packs bits, most-significant bit first, into bytes
/// and forwards every completed byte to an inner [`io::Write`] implementation.
///
/// Bits are buffered until eight have accumulated. A trailing partial byte is
/// only emitted (zero-padded) by [`BitWriter::align`] or [`BitWriter::finish`].
#[derive(Debug)]
#[must_use]
pub struct BitWriter<W> {
    /// Number of bits already placed in `current_byte`; kept in `0..8`.
    bit_pos: u8,
    /// The byte being assembled. Positions not yet written are zero.
    current_byte: u8,
    /// Destination for completed bytes.
    writer: W,
}

impl<W: Default> Default for BitWriter<W> {
    /// Creates an empty bit writer around `W::default()`.
    fn default() -> Self {
        Self::new(W::default())
    }
}

impl<W: io::Write> BitWriter<W> {
    /// Writes a single bit.
    ///
    /// The bit goes into the most-significant free position of the byte under
    /// construction. When that byte is full it is written to the inner writer.
    ///
    /// # Errors
    ///
    /// Returns any error the inner writer reports while emitting a completed
    /// byte. In that case the completed byte is discarded and the writer is left
    /// byte-aligned, so it stays usable for subsequent writes.
    pub fn write_bit(&mut self, bit: bool) -> io::Result<()> {
        let mask: u8 = 1 << (7 - self.bit_pos);
        if bit {
            self.current_byte |= mask;
        } else {
            self.current_byte &= !mask;
        }

        self.bit_pos += 1;

        if self.bit_pos == 8 {
            // Reset *before* handing the byte over: if the inner write fails we must
            // not be left with `bit_pos == 8`, or the next call would evaluate `7 - 8`
            // (a debug-mode overflow panic, garbage output in release).
            let byte = self.current_byte;
            self.current_byte = 0;
            self.bit_pos = 0;
            self.writer.write_all(&[byte])?;
        }

        Ok(())
    }

    /// Writes the low `count` bits of `bits`, most-significant bit first.
    ///
    /// A `count` greater than 64 is clamped to 64.
    ///
    /// # Errors
    ///
    /// Returns an [`io::ErrorKind::InvalidData`] error, without writing anything,
    /// if `bits` has a set bit at or above position `count`. Otherwise propagates
    /// any error from the inner writer (see [`BitWriter::write_bit`]).
    pub fn write_bits(&mut self, bits: u64, count: u8) -> io::Result<()> {
        let count = count.min(64);

        // A 64-bit count accepts every value (and `bits >> 64` would overflow).
        if count < 64 && bits >> count != 0 {
            return Err(io::Error::new(io::ErrorKind::InvalidData, "bits too large to write"));
        }

        for shift in (0..count).rev() {
            self.write_bit((bits >> shift) & 1 == 1)?;
        }

        Ok(())
    }

    /// Pads any partial byte with zero bits, writes it, and returns the inner
    /// writer.
    ///
    /// This does not call [`io::Write::flush`] on the inner writer.
    ///
    /// # Errors
    ///
    /// Propagates any error from the inner writer while emitting the final byte.
    pub fn finish(mut self) -> io::Result<W> {
        self.align()?;
        Ok(self.writer)
    }

    /// Pads the current byte with zero bits up to the next byte boundary and
    /// emits it. Does nothing when the writer is already byte-aligned.
    ///
    /// # Errors
    ///
    /// Propagates any error from the inner writer while emitting the padded byte.
    pub fn align(&mut self) -> io::Result<()> {
        if !self.is_aligned() {
            self.write_bits(0, 8 - self.bit_pos())?;
        }

        Ok(())
    }
}

impl<W> BitWriter<W> {
    /// Creates a bit writer that emits completed bytes to `writer`.
    pub const fn new(writer: W) -> Self {
        Self {
            bit_pos: 0,
            current_byte: 0,
            writer,
        }
    }

    /// Returns how many bits of the current, not yet emitted byte have been
    /// written (always in `0..8`).
    #[inline(always)]
    #[must_use]
    pub const fn bit_pos(&self) -> u8 {
        self.bit_pos % 8
    }

    /// Returns `true` when the writer sits on a byte boundary, i.e. no partial
    /// byte is pending.
    #[inline(always)]
    #[must_use]
    pub const fn is_aligned(&self) -> bool {
        self.bit_pos() == 0
    }

    /// Returns a reference to the inner writer.
    ///
    /// A partially assembled byte has not been written to it yet.
    #[inline(always)]
    #[must_use]
    pub const fn get_ref(&self) -> &W {
        &self.writer
    }
}

impl<W: io::Write> io::Write for BitWriter<W> {
    /// Writes whole bytes.
    ///
    /// When the writer is byte-aligned the buffer is forwarded directly to the
    /// inner writer, which may accept only part of it. Otherwise every byte is
    /// shifted in bit by bit and the whole buffer length is reported. If the
    /// inner writer fails part-way, an error is returned even though earlier
    /// bytes may already have been emitted.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        if self.is_aligned() {
            return self.writer.write(buf);
        }

        for &byte in buf {
            self.write_bits(u64::from(byte), 8)?;
        }

        Ok(buf.len())
    }

    /// Flushes the inner writer.
    ///
    /// A partially filled byte is *not* emitted; call [`BitWriter::align`]
    /// first if it must be written out.
    fn flush(&mut self) -> io::Result<()> {
        self.writer.flush()
    }
}
123
#[cfg(test)]
#[cfg_attr(all(test, coverage_nightly), coverage(off))]
mod tests {
    use io::Write;

    use super::*;

    #[test]
    fn test_bit_writer() {
        let mut bit_writer = BitWriter::<Vec<u8>>::default();

        bit_writer.write_bits(0b11111111, 8).unwrap();
        assert_eq!(bit_writer.bit_pos(), 0);
        assert!(bit_writer.is_aligned());

        bit_writer.write_bits(0b0000, 4).unwrap();
        assert_eq!(bit_writer.bit_pos(), 4);
        assert!(!bit_writer.is_aligned());
        bit_writer.align().unwrap();
        assert_eq!(bit_writer.bit_pos(), 0);
        assert!(bit_writer.is_aligned());

        bit_writer.write_bits(0b1010, 4).unwrap();
        assert_eq!(bit_writer.bit_pos(), 4);
        assert!(!bit_writer.is_aligned());

        bit_writer.write_bits(0b101010101010, 12).unwrap();
        assert_eq!(bit_writer.bit_pos(), 0);
        assert!(bit_writer.is_aligned());

        bit_writer.write_bit(true).unwrap();
        assert_eq!(bit_writer.bit_pos(), 1);
        assert!(!bit_writer.is_aligned());

        let err = bit_writer.write_bits(0b10000, 4).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert_eq!(err.to_string(), "bits too large to write");

        assert_eq!(
            bit_writer.finish().unwrap(),
            vec![0b11111111, 0b00000000, 0b10101010, 0b10101010, 0b10000000]
        );
    }

    #[test]
    fn test_flush_buffer() {
        let mut bit_writer = BitWriter::<Vec<u8>>::default();

        bit_writer.write_bits(0b11111111, 8).unwrap();
        assert_eq!(bit_writer.bit_pos(), 0);
        assert!(bit_writer.is_aligned());
        assert_eq!(bit_writer.get_ref(), &[0b11111111], "underlying writer should have one byte");

        bit_writer.write_bits(0b0000, 4).unwrap();
        assert_eq!(bit_writer.bit_pos(), 4);
        assert!(!bit_writer.is_aligned());
        assert_eq!(bit_writer.get_ref(), &[0b11111111], "underlying writer should have one byte");

        bit_writer.write_bits(0b1010, 4).unwrap();
        assert_eq!(bit_writer.bit_pos(), 0);
        assert!(bit_writer.is_aligned());
        assert_eq!(
            bit_writer.get_ref(),
            &[0b11111111, 0b00001010],
            "underlying writer should have two bytes"
        );
    }

    #[test]
    fn test_io_write() {
        let mut inner = Vec::new();
        let mut bit_writer = BitWriter::new(&mut inner);

        bit_writer.write_bits(0b11111111, 8).unwrap();
        assert_eq!(bit_writer.bit_pos(), 0);
        assert!(bit_writer.is_aligned());
        assert_eq!(bit_writer.get_ref().as_slice(), &[255]);

        // While aligned, whole bytes pass straight through to the inner writer.
        bit_writer.write_all(&[1, 2, 3]).unwrap();
        assert_eq!(bit_writer.bit_pos(), 0);
        assert!(bit_writer.is_aligned());
        assert_eq!(bit_writer.get_ref().as_slice(), &[255, 1, 2, 3]);

        // Leave the writer 5 bits into a byte so the next bytes must be shifted.
        bit_writer.write_bit(true).unwrap();

        bit_writer.write_bits(0b1010, 4).unwrap();

        bit_writer
            .write_all(&[0b11111111, 0b00000000, 0b11111111, 0b00000000])
            .unwrap();

        // 5 pending bits + 32 written bits = 4 complete bytes with 5 bits left over.
        assert_eq!(
            bit_writer.get_ref().as_slice(),
            &[255, 1, 2, 3, 0b11010111, 0b11111000, 0b00000111, 0b11111000]
        );

        bit_writer.finish().unwrap();

        assert_eq!(
            inner,
            vec![255, 1, 2, 3, 0b11010111, 0b11111000, 0b00000111, 0b11111000, 0b00000000]
        );
    }

    #[test]
    fn test_flush() {
        let mut inner = Vec::new();
        let mut bit_writer = BitWriter::new(&mut inner);

        bit_writer.write_bits(0b10100000, 8).unwrap();

        bit_writer.flush().unwrap();

        assert_eq!(bit_writer.get_ref().as_slice(), &[0b10100000]);
        assert_eq!(bit_writer.bit_pos(), 0);
        assert!(bit_writer.is_aligned());
    }

    #[test]
    fn test_flush_keeps_partial_byte() {
        let mut bit_writer = BitWriter::<Vec<u8>>::default();

        bit_writer.write_bits(0b101, 3).unwrap();
        bit_writer.flush().unwrap();

        assert!(bit_writer.get_ref().is_empty(), "flush must not emit a partial byte");
        assert_eq!(bit_writer.bit_pos(), 3);
        assert!(!bit_writer.is_aligned());
        assert_eq!(bit_writer.finish().unwrap(), vec![0b10100000]);
    }

    #[test]
    fn test_write_bits_full_width() {
        let mut bit_writer = BitWriter::<Vec<u8>>::default();

        bit_writer.write_bits(0x0102_0304_0506_0708, 64).unwrap();
        assert!(bit_writer.is_aligned());
        assert_eq!(bit_writer.get_ref(), &[1, 2, 3, 4, 5, 6, 7, 8]);

        // Counts above 64 are clamped to 64.
        bit_writer.write_bits(1, 100).unwrap();
        assert_eq!(
            bit_writer.finish().unwrap(),
            vec![1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1]
        );
    }

    #[test]
    fn test_write_bits_zero_count() {
        let mut bit_writer = BitWriter::<Vec<u8>>::default();

        bit_writer.write_bits(0, 0).unwrap();
        assert!(bit_writer.is_aligned());

        let err = bit_writer.write_bits(1, 0).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);

        assert!(bit_writer.finish().unwrap().is_empty());
    }

    #[test]
    fn test_align_when_aligned_is_noop() {
        let mut bit_writer = BitWriter::<Vec<u8>>::default();

        bit_writer.write_bits(0xAB, 8).unwrap();
        bit_writer.align().unwrap();
        assert_eq!(bit_writer.get_ref(), &[0xAB]);
        assert_eq!(bit_writer.finish().unwrap(), vec![0xAB]);
    }
}