scuffle_rtmp/chunk/
reader.rs

1//! Types and functions for reading RTMP chunks.
2
3use std::cmp::min;
4use std::collections::HashMap;
5use std::io::{self, Cursor, Seek, SeekFrom};
6
7use byteorder::{BigEndian, LittleEndian, ReadBytesExt};
8use bytes::BytesMut;
9use num_traits::FromPrimitive;
10use scuffle_bytes_util::IoResultExt;
11
12use super::error::ChunkReadError;
13use super::{Chunk, ChunkBasicHeader, ChunkMessageHeader, ChunkType, INIT_CHUNK_SIZE, MAX_CHUNK_SIZE};
14use crate::messages::MessageType;
15
// These constants bound the amount of memory we use for buffered chunk state.
// Normal clients never come close to these limits; they exist to stop a
// malicious peer from exhausting memory with crafted chunk streams.
const MAX_PARTIAL_CHUNK_SIZE: usize = 10 * 1024 * 1024; // 10MB (should be more than enough)
const MAX_PREVIOUS_CHUNK_HEADERS: usize = 100; // 100 chunks
const MAX_PARTIAL_CHUNK_COUNT: usize = 4; // 4 chunks
22
/// A chunk reader.
///
/// This is used to read chunks from a stream.
pub struct ChunkReader {
    /// According to the spec chunk streams are identified by the chunk stream
    /// ID. In this case that is our key.
    /// We then have a chunk header (since some chunks refer to the previous
    /// chunk header on the same chunk stream).
    previous_chunk_headers: HashMap<u32, ChunkMessageHeader>,

    /// Accumulated payload bytes of messages that span multiple chunks.
    /// Technically according to the spec, we can have multiple message streams
    /// in a single chunk stream. Because of this the key of this map is a tuple
    /// (chunk stream id, message stream id).
    partial_chunks: HashMap<(u32, u32), BytesMut>,

    /// This is the max chunk size that the client has specified.
    /// By default this is 128 bytes ([`INIT_CHUNK_SIZE`]).
    max_chunk_size: usize,
}
42
43impl Default for ChunkReader {
44    fn default() -> Self {
45        Self {
46            previous_chunk_headers: HashMap::with_capacity(MAX_PREVIOUS_CHUNK_HEADERS),
47            partial_chunks: HashMap::with_capacity(MAX_PARTIAL_CHUNK_COUNT),
48            max_chunk_size: INIT_CHUNK_SIZE,
49        }
50    }
51}
52
53impl ChunkReader {
54    /// Call when a client requests a chunk size change.
55    ///
56    /// Returns `false` if the chunk size is out of bounds.
57    /// The connection should be closed in this case.
58    pub fn update_max_chunk_size(&mut self, chunk_size: usize) -> bool {
59        // We need to make sure that the chunk size is within the allowed range.
60        // Returning false here should close the connection.
61        if !(INIT_CHUNK_SIZE..=MAX_CHUNK_SIZE).contains(&chunk_size) {
62            false
63        } else {
64            self.max_chunk_size = chunk_size;
65            true
66        }
67    }
68
69    /// This function is used to read a chunk from the buffer.
70    ///
71    /// Returns:
72    /// - `Ok(None)` if the buffer does not contain enough data to read a full chunk.
73    /// - `Ok(Some(Chunk))` if a full chunk is read.
74    /// - `Err(ChunkReadError)` if there is an error decoding a chunk. The connection should be closed.
75    ///
76    /// # See also
77    ///
78    /// - [`Chunk`]
79    /// - [`ChunkReadError`]
80    pub fn read_chunk(&mut self, buffer: &mut BytesMut) -> Result<Option<Chunk>, crate::error::RtmpError> {
81        // We do this in a loop because we may have multiple chunks in the buffer,
82        // And those chunks may be partial chunks thus we need to keep reading until we
83        // have a full chunk or we run out of data.
84        loop {
85            // The cursor is an advanced cursor that is a reference to the buffer.
86            // This means the cursor does not advance the reader's position.
87            // Thus allowing us to backtrack if we need to read more data.
88            let mut cursor = std::io::Cursor::new(buffer.as_ref());
89
90            let Some(header) = self.read_header(&mut cursor)? else {
91                // Returning none here means that the buffer is empty and we need to wait for
92                // more data.
93                return Ok(None);
94            };
95
96            let Some(message_header) = self.read_message_header(&header, &mut cursor)? else {
97                // Returning none here means that the buffer is empty and we need to wait for
98                // more data.
99                return Ok(None);
100            };
101
102            let Some((payload_range_start, payload_range_end)) =
103                self.get_payload_range(&header, &message_header, &mut cursor)?
104            else {
105                // Returning none here means that the buffer is empty and we need to wait
106                // for more data.
107                return Ok(None);
108            };
109
110            // Since we were reading from an advanced cursor, our reads did not actually
111            // advance the reader's position. We need to manually advance the reader's
112            // position to the cursor's position.
113            let position = cursor.position() as usize;
114            if position > buffer.len() {
115                // In some cases we dont have enough data yet to read the chunk.
116                // We return Ok(None) here and the loop will continue.
117                return Ok(None);
118            }
119
120            let data = buffer.split_to(position);
121
122            // We freeze the chunk data and slice it to get the payload.
123            // Data before the slice is the header data, and data after the slice is the
124            // next chunk We don't need to keep the header data, because we already decoded
125            // it into struct form. The payload_range_end should be the same as the cursor's
126            // position.
127            let payload = data.freeze().slice(payload_range_start..payload_range_end);
128
129            // We need to check here if the chunk header is already stored in our map.
130            // This isnt a spec check but it is a check to make sure that we dont have too
131            // many previous chunk headers stored in memory.
132            let count = if self.previous_chunk_headers.contains_key(&header.chunk_stream_id) {
133                self.previous_chunk_headers.len()
134            } else {
135                self.previous_chunk_headers.len() + 1
136            };
137
138            // If this is hit, then we have too many previous chunk headers stored in
139            // memory. And the client is probably trying to DoS us.
140            // We return an error and the connection will be closed.
141            if count > MAX_PREVIOUS_CHUNK_HEADERS {
142                return Err(crate::error::RtmpError::ChunkRead(
143                    ChunkReadError::TooManyPreviousChunkHeaders,
144                ));
145            }
146
147            // We insert the chunk header into our map.
148            self.previous_chunk_headers
149                .insert(header.chunk_stream_id, message_header.clone());
150
151            // It is possible in theory to get a chunk message that requires us to change
152            // the max chunk size. However the size of that message is smaller than the
153            // default max chunk size. Therefore we can ignore this case.
154            // Since if we get such a message we will read it and the payload.len() will be
155            // equal to the message length. and thus we will return the chunk.
156
157            // Check if the payload is the same as the message length.
158            // If this is true we have a full chunk and we can return it.
159            if payload.len() == message_header.msg_length as usize {
160                return Ok(Some(Chunk {
161                    basic_header: header,
162                    message_header,
163                    payload,
164                }));
165            } else {
166                // Otherwise we generate a key using the chunk stream id and the message stream
167                // id. We then get the partial chunk from the map using the key.
168                let key = (header.chunk_stream_id, message_header.msg_stream_id);
169                let partial_chunk = match self.partial_chunks.get_mut(&key) {
170                    Some(partial_chunk) => partial_chunk,
171                    None => {
172                        // If it does not exists we create a new one.
173                        // If we have too many partial chunks we return an error.
174                        // Since the client is probably trying to DoS us.
175                        // The connection will be closed.
176                        if self.partial_chunks.len() >= MAX_PARTIAL_CHUNK_COUNT {
177                            return Err(crate::error::RtmpError::ChunkRead(ChunkReadError::TooManyPartialChunks));
178                        }
179
180                        // Insert a new empty BytesMut into the map.
181                        self.partial_chunks.insert(key, BytesMut::new());
182                        // Get the partial chunk we just inserted.
183                        self.partial_chunks.get_mut(&key).expect("we just inserted it")
184                    }
185                };
186
187                // We extend the partial chunk with the payload.
188                // And get the new length of the partial chunk.
189                let length = {
190                    // If the length of a single chunk is larger than the max partial chunk size
191                    // we return an error. The client is probably trying to DoS us.
192                    if partial_chunk.len() + payload.len() > MAX_PARTIAL_CHUNK_SIZE {
193                        return Err(crate::error::RtmpError::ChunkRead(ChunkReadError::PartialChunkTooLarge(
194                            partial_chunk.len() + payload.len(),
195                        )));
196                    }
197
198                    // Extend the partial chunk with the payload.
199                    partial_chunk.extend_from_slice(&payload[..]);
200
201                    // Return the new length of the partial chunk.
202                    partial_chunk.len()
203                };
204
205                // If we have a full chunk we return it.
206                if length == message_header.msg_length as usize {
207                    return Ok(Some(Chunk {
208                        basic_header: header,
209                        message_header,
210                        payload: self.partial_chunks.remove(&key).unwrap().freeze(),
211                    }));
212                }
213
214                // If we don't have a full chunk we just let the loop continue.
215                // Usually this will result in returning Ok(None) from one of
216                // the above checks. However there is a edge case that we have
217                // enough data in our buffer to read the next chunk and the
218                // client is waiting for us to send a response. Meaning if we
219                // just return Ok(None) here We would deadlock the connection,
220                // and it will eventually timeout. So we need to loop again here
221                // to check if we have enough data to read the next chunk.
222            }
223        }
224    }
225
226    /// Internal function used to read the basic chunk header.
227    fn read_header(&self, cursor: &mut Cursor<&[u8]>) -> Result<Option<ChunkBasicHeader>, io::Error> {
228        // The first byte of the basic header is the format of the chunk and the stream
229        // id. Mapping the error to none means that this isn't a real error but we dont
230        // have enough data.
231        let Some(byte) = cursor.read_u8().eof_to_none()? else {
232            return Ok(None);
233        };
234        // The format is the first 2 bits of the byte. We shift the byte 6 bits to the
235        // right to get the format.
236        let format = (byte >> 6) & 0b00000011;
237
238        // We do not check that the format is valid.
239        // It should not be possible to get an invalid chunk type
240        // because, we bitshift the byte 6 bits to the right. Leaving 2 bits which can
241        // only be 0, 1 or 2 or 3 which is the only valid chunk types.
242        let format = ChunkType::from_u8(format).expect("unreachable");
243
244        // We then parse the chunk stream id.
245        let chunk_stream_id = match (byte & 0b00111111) as u32 {
246            // If the chunk stream id is 0 we read the next byte and add 64 to it.
247            0 => {
248                let Some(first_byte) = cursor.read_u8().eof_to_none()? else {
249                    return Ok(None);
250                };
251
252                64 + first_byte as u32
253            }
254            // If it is 1 we read the next 2 bytes and add 64 to it and multiply the 2nd byte by
255            // 256.
256            1 => {
257                let Some(first_byte) = cursor.read_u8().eof_to_none()? else {
258                    return Ok(None);
259                };
260                let Some(second_byte) = cursor.read_u8().eof_to_none()? else {
261                    return Ok(None);
262                };
263
264                64 + first_byte as u32 + second_byte as u32 * 256
265            }
266            // Any other value means that the chunk stream id is the value of the byte.
267            csid => csid,
268        };
269
270        // We then read the message header.
271        let header = ChunkBasicHeader { chunk_stream_id, format };
272
273        Ok(Some(header))
274    }
275
276    /// Internal function used to read the message header.
277    fn read_message_header(
278        &self,
279        header: &ChunkBasicHeader,
280        cursor: &mut Cursor<&[u8]>,
281    ) -> Result<Option<ChunkMessageHeader>, crate::error::RtmpError> {
282        // Each format has a different message header length.
283        match header.format {
284            // Type0 headers have the most information and can be compared to keyframes in video.
285            // They do not reference any previous chunks. They contain the full message header.
286            ChunkType::Type0 => {
287                // The first 3 bytes are the timestamp.
288                let Some(timestamp) = cursor.read_u24::<BigEndian>().eof_to_none()? else {
289                    return Ok(None);
290                };
291                // Followed by a 3 byte message length. (this is the length of the entire
292                // payload not just this chunk)
293                let Some(msg_length) = cursor.read_u24::<BigEndian>().eof_to_none()? else {
294                    return Ok(None);
295                };
296                if msg_length as usize > MAX_PARTIAL_CHUNK_SIZE {
297                    return Err(crate::error::RtmpError::ChunkRead(ChunkReadError::PartialChunkTooLarge(
298                        msg_length as usize,
299                    )));
300                }
301
302                // We then have a 1 byte message type id.
303                let Some(msg_type_id) = cursor.read_u8().eof_to_none()? else {
304                    return Ok(None);
305                };
306                let msg_type_id = MessageType::from(msg_type_id);
307
308                // We then read the message stream id. (According to spec this is stored in
309                // LittleEndian, no idea why.)
310                let Some(msg_stream_id) = cursor.read_u32::<LittleEndian>().eof_to_none()? else {
311                    return Ok(None);
312                };
313
314                // Sometimes the timestamp is larger than 3 bytes.
315                // If the timestamp is 0xFFFFFF we read the next 4 bytes as the timestamp.
316                // I am not exactly sure why they did it this way.
317                // Why not just use 3 bytes for the timestamp, and if the 3 bytes are set to
318                // 0xFFFFFF just read 1 additional byte and then shift it 24 bits.
319                // Like if timestamp == 0xFFFFFF { timestamp |= cursor.read_u8() << 24; }
320                // This would save 3 bytes in the header and would be more
321                // efficient but I guess the Spec writers are smarter than me.
322                let (timestamp, was_extended_timestamp) = if timestamp == 0xFFFFFF {
323                    let Some(timestamp) = cursor.read_u32::<BigEndian>().eof_to_none()? else {
324                        return Ok(None);
325                    };
326                    (timestamp, true)
327                } else {
328                    (timestamp, false)
329                };
330
331                Ok(Some(ChunkMessageHeader {
332                    timestamp,
333                    msg_length,
334                    msg_type_id,
335                    msg_stream_id,
336                    was_extended_timestamp,
337                }))
338            }
339            // For ChunkType 1 we have a delta timestamp, message length and message type id.
340            // The message stream id is the same as the previous chunk.
341            ChunkType::Type1 => {
342                // The first 3 bytes are the delta timestamp.
343                let Some(timestamp_delta) = cursor.read_u24::<BigEndian>().eof_to_none()? else {
344                    return Ok(None);
345                };
346                // Followed by a 3 byte message length. (this is the length of the entire
347                // payload not just this chunk)
348                let Some(msg_length) = cursor.read_u24::<BigEndian>().eof_to_none()? else {
349                    return Ok(None);
350                };
351                if msg_length as usize > MAX_PARTIAL_CHUNK_SIZE {
352                    return Err(crate::error::RtmpError::ChunkRead(ChunkReadError::PartialChunkTooLarge(
353                        msg_length as usize,
354                    )));
355                }
356
357                // We then have a 1 byte message type id.
358                let Some(msg_type_id) = cursor.read_u8().eof_to_none()? else {
359                    return Ok(None);
360                };
361                let msg_type_id = MessageType::from(msg_type_id);
362
363                // Again as mentioned above we sometimes have a delta timestamp larger than 3
364                // bytes.
365                let (timestamp_delta, was_extended_timestamp) = if timestamp_delta == 0xFFFFFF {
366                    let Some(timestamp) = cursor.read_u32::<BigEndian>().eof_to_none()? else {
367                        return Ok(None);
368                    };
369                    (timestamp, true)
370                } else {
371                    (timestamp_delta, false)
372                };
373
374                // We get the previous chunk header.
375                // If the previous chunk header is not found we return an error. (this is a real
376                // error)
377                let previous_header =
378                    self.previous_chunk_headers
379                        .get(&header.chunk_stream_id)
380                        .ok_or(crate::error::RtmpError::ChunkRead(
381                            ChunkReadError::MissingPreviousChunkHeader(header.chunk_stream_id),
382                        ))?;
383
384                // We calculate the timestamp by adding the delta timestamp to the previous
385                // timestamp. We need to make sure this does not overflow.
386                let timestamp = previous_header.timestamp.checked_add(timestamp_delta).unwrap_or_else(|| {
387                    tracing::warn!(
388						"Timestamp overflow detected. Previous timestamp: {}, delta timestamp: {}, using previous timestamp.",
389						previous_header.timestamp,
390						timestamp_delta
391					);
392
393                    previous_header.timestamp
394                });
395
396                Ok(Some(ChunkMessageHeader {
397                    timestamp,
398                    msg_length,
399                    msg_type_id,
400                    was_extended_timestamp,
401                    // The message stream id is the same as the previous chunk.
402                    msg_stream_id: previous_header.msg_stream_id,
403                }))
404            }
405            // ChunkType2 headers only have a delta timestamp.
406            // The message length, message type id and message stream id are the same as the
407            // previous chunk.
408            ChunkType::Type2 => {
409                // We read the delta timestamp.
410                let Some(timestamp_delta) = cursor.read_u24::<BigEndian>().eof_to_none()? else {
411                    return Ok(None);
412                };
413
414                // Again if the delta timestamp is larger than 3 bytes we read the next 4 bytes
415                // as the timestamp.
416                let (timestamp_delta, was_extended_timestamp) = if timestamp_delta == 0xFFFFFF {
417                    let Some(timestamp) = cursor.read_u32::<BigEndian>().eof_to_none()? else {
418                        return Ok(None);
419                    };
420                    (timestamp, true)
421                } else {
422                    (timestamp_delta, false)
423                };
424
425                // We get the previous chunk header.
426                // If the previous chunk header is not found we return an error. (this is a real
427                // error)
428                let previous_header =
429                    self.previous_chunk_headers
430                        .get(&header.chunk_stream_id)
431                        .ok_or(crate::error::RtmpError::ChunkRead(
432                            ChunkReadError::MissingPreviousChunkHeader(header.chunk_stream_id),
433                        ))?;
434
435                // We calculate the timestamp by adding the delta timestamp to the previous
436                // timestamp.
437                let timestamp = previous_header.timestamp + timestamp_delta;
438
439                Ok(Some(ChunkMessageHeader {
440                    timestamp,
441                    msg_length: previous_header.msg_length,
442                    msg_type_id: previous_header.msg_type_id,
443                    msg_stream_id: previous_header.msg_stream_id,
444                    was_extended_timestamp,
445                }))
446            }
447            // ChunkType3 headers are the same as the previous chunk header.
448            ChunkType::Type3 => {
449                // We get the previous chunk header.
450                // If the previous chunk header is not found we return an error. (this is a real
451                // error)
452                let previous_header = self
453                    .previous_chunk_headers
454                    .get(&header.chunk_stream_id)
455                    .ok_or(crate::error::RtmpError::ChunkRead(
456                        ChunkReadError::MissingPreviousChunkHeader(header.chunk_stream_id),
457                    ))?
458                    .clone();
459
460                // Now this is truely stupid.
461                // If the PREVIOUS HEADER is extended then we now waste an additional 4 bytes to
462                // read the timestamp. Why not just read the timestamp in the previous header if
463                // it is extended? I guess the spec writers had some reason and its obviously
464                // way above my knowledge.
465                if previous_header.was_extended_timestamp {
466                    // Not a real error, we just dont have enough data.
467                    // We dont have to store this value since it is the same as the previous header.
468                    if cursor.read_u32::<BigEndian>().eof_to_none()?.is_none() {
469                        return Ok(None);
470                    };
471                }
472
473                Ok(Some(previous_header))
474            }
475        }
476    }
477
478    /// Internal function to get the payload range of a chunk.
479    fn get_payload_range(
480        &self,
481        header: &ChunkBasicHeader,
482        message_header: &ChunkMessageHeader,
483        cursor: &mut Cursor<&'_ [u8]>,
484    ) -> Result<Option<(usize, usize)>, crate::error::RtmpError> {
485        // We find out if the chunk is a partial chunk (and if we have already read some
486        // of it).
487        let key = (header.chunk_stream_id, message_header.msg_stream_id);
488
489        // Check how much we still need to read (if we have already read some of the
490        // chunk)
491        let remaining_read_length =
492            message_header.msg_length as usize - self.partial_chunks.get(&key).map(|data| data.len()).unwrap_or(0);
493
494        // We get the min between our max chunk size and the remaining read length.
495        // This is the amount of bytes we need to read.
496        let need_read_length = min(remaining_read_length, self.max_chunk_size);
497
498        // We get the current position in the cursor.
499        let pos = cursor.position() as usize;
500
501        // We seek forward to where the payload starts.
502        if cursor
503            .seek(SeekFrom::Current(need_read_length as i64))
504            .eof_to_none()?
505            .is_none()
506        {
507            return Ok(None);
508        };
509
510        // We then return the range of the payload.
511        // Which would be the pos to the pos + need_read_length.
512        Ok(Some((pos, pos + need_read_length)))
513    }
514}
515
516#[cfg(test)]
517#[cfg_attr(all(test, coverage_nightly), coverage(off))]
518mod tests {
519    use byteorder::WriteBytesExt;
520    use bytes::{BufMut, BytesMut};
521
522    use super::*;
523
524    #[test]
525    fn test_reader_error_display() {
526        let error = ChunkReadError::MissingPreviousChunkHeader(123);
527        assert_eq!(format!("{error}"), "missing previous chunk header: 123");
528
529        let error = ChunkReadError::TooManyPartialChunks;
530        assert_eq!(format!("{error}"), "too many partial chunks");
531
532        let error = ChunkReadError::TooManyPreviousChunkHeaders;
533        assert_eq!(format!("{error}"), "too many previous chunk headers");
534
535        let error = ChunkReadError::PartialChunkTooLarge(100);
536        assert_eq!(format!("{error}"), "partial chunk too large: 100");
537    }
538
539    #[test]
540    fn test_reader_chunk_size_out_of_bounds() {
541        let mut reader = ChunkReader::default();
542        assert!(!reader.update_max_chunk_size(MAX_CHUNK_SIZE + 1));
543    }
544
545    #[test]
546    fn test_incomplete_header() {
547        let mut buf = BytesMut::new();
548        buf.extend_from_slice(&[0b00_000000]);
549
550        let reader = ChunkReader::default();
551        let err = reader.read_header(&mut Cursor::new(&buf));
552        assert!(matches!(err, Ok(None)));
553    }
554
555    #[test]
556    fn test_reader_chunk_type0_single_sized() {
557        let mut buf = BytesMut::new();
558
559        #[rustfmt::skip]
560        buf.extend_from_slice(&[
561            3, // chunk type 0, chunk stream id 3
562            0x00, 0x00, 0x00, // timestamp
563            0x00, 0x00, 0x80, // message length (128) (max chunk size is set to 128)
564            0x09, // message type id (video)
565            0x00, 0x01, 0x00, 0x00, // message stream id
566        ]);
567
568        for i in 0..128 {
569            (&mut buf).writer().write_u8(i as u8).unwrap();
570        }
571
572        let mut unpacker = ChunkReader::default();
573        let chunk = unpacker.read_chunk(&mut buf).expect("read chunk").expect("chunk");
574        assert_eq!(chunk.basic_header.chunk_stream_id, 3);
575        assert_eq!(chunk.message_header.msg_type_id.0, 0x09);
576        assert_eq!(chunk.message_header.timestamp, 0);
577        assert_eq!(chunk.message_header.msg_length, 128);
578        assert_eq!(chunk.message_header.msg_stream_id, 0x0100); // since it's little endian, it's 0x0100
579        assert_eq!(chunk.payload.len(), 128);
580    }
581
    #[test]
    fn test_reader_chunk_type0_double_sized() {
        // A 256-byte message at the default 128-byte max chunk size arrives as
        // two chunks; the reader must buffer the first half and only produce
        // the full chunk once the second half has been fed in.
        let mut buf = BytesMut::new();
        #[rustfmt::skip]
        buf.extend_from_slice(&[
            3, // chunk type 0, chunk stream id 3
            0x00, 0x00, 0x00, // timestamp
            0x00, 0x01, 0x00, // message length (256) (max chunk size is set to 128)
            0x09, // message type id (video)
            0x00, 0x01, 0x00, 0x00, // message stream id
        ]);

        // First 128 payload bytes (half of the message).
        for i in 0..128 {
            (&mut buf).writer().write_u8(i as u8).unwrap();
        }

        let mut unpacker = ChunkReader::default();

        // Keep a copy of the bytes so the identical chunk can be fed again below.
        let chunk = buf.as_ref().to_vec();

        // We should not have enough data to read the chunk
        // But the chunk is valid, so we should not get an error
        assert!(unpacker.read_chunk(&mut buf).expect("read chunk").is_none());

        // We just feed the same data again in this test to see if the Unpacker merges
        // the chunks Which it should do
        buf.extend_from_slice(&chunk);

        let chunk = unpacker.read_chunk(&mut buf).expect("read chunk").expect("chunk");

        assert_eq!(chunk.basic_header.chunk_stream_id, 3);
        assert_eq!(chunk.message_header.msg_type_id.0, 0x09);
        assert_eq!(chunk.message_header.timestamp, 0);
        assert_eq!(chunk.message_header.msg_length, 256);
        assert_eq!(chunk.message_header.msg_stream_id, 0x0100); // since it's little endian, it's 0x0100
        assert_eq!(chunk.payload.len(), 256);
    }
619
    #[test]
    fn test_reader_chunk_mutli_streams() {
        // Two interleaved 256-byte messages on different chunk streams (csid 3
        // and csid 4). Whichever stream completes first should be returned
        // first, regardless of which one started first.
        let mut buf = BytesMut::new();

        #[rustfmt::skip]
        buf.extend_from_slice(&[
            3, // chunk type 0, chunk stream id 3
            0x00, 0x00, 0x00, // timestamp
            0x00, 0x01, 0x00, // message length (256) (max chunk size is set to 128)
            0x09, // message type id (video)
            0x00, 0x01, 0x00, 0x00, // message stream id
        ]);

        // First half of the csid-3 message (payload bytes are all 3).
        for _ in 0..128 {
            (&mut buf).writer().write_u8(3).unwrap();
        }

        #[rustfmt::skip]
        buf.extend_from_slice(&[
            4, // chunk type 0, chunk stream id 4 (different stream)
            0x00, 0x00, 0x00, // timestamp
            0x00, 0x01, 0x00, // message length (256) (max chunk size is set to 128)
            0x08, // message type id (audio)
            0x00, 0x03, 0x00, 0x00, // message stream id
        ]);

        // First half of the csid-4 message (payload bytes are all 4).
        for _ in 0..128 {
            (&mut buf).writer().write_u8(4).unwrap();
        }

        let mut unpacker = ChunkReader::default();

        // We wrote 2 chunks but neither of them are complete
        assert!(unpacker.read_chunk(&mut buf).expect("read chunk").is_none());

        #[rustfmt::skip]
        buf.extend_from_slice(&[
            (3 << 6) | 4, // chunk type 3, chunk stream id 4
        ]);

        // Second half of the csid-4 message.
        for _ in 0..128 {
            (&mut buf).writer().write_u8(3).unwrap();
        }

        // Even though we wrote chunk 3 first, chunk 4 should be read first since it's a
        // different stream and it completed first
        let chunk = unpacker.read_chunk(&mut buf).expect("read chunk").expect("chunk");

        assert_eq!(chunk.basic_header.chunk_stream_id, 4);
        assert_eq!(chunk.message_header.msg_type_id.0, 0x08);
        assert_eq!(chunk.message_header.timestamp, 0);
        assert_eq!(chunk.message_header.msg_length, 256);
        assert_eq!(chunk.message_header.msg_stream_id, 0x0300); // since it's little endian, it's 0x0300
        assert_eq!(chunk.payload.len(), 256);
        for i in 0..128 {
            assert_eq!(chunk.payload[i], 4);
        }

        // No chunk is ready yet
        assert!(unpacker.read_chunk(&mut buf).expect("read chunk").is_none());

        #[rustfmt::skip]
        buf.extend_from_slice(&[
            (3 << 6) | 3, // chunk type 3, chunk stream id 3
        ]);

        // Second half of the csid-3 message.
        for _ in 0..128 {
            (&mut buf).writer().write_u8(3).unwrap();
        }

        let chunk = unpacker.read_chunk(&mut buf).expect("read chunk").expect("chunk");

        assert_eq!(chunk.basic_header.chunk_stream_id, 3);
        assert_eq!(chunk.message_header.msg_type_id.0, 0x09);
        assert_eq!(chunk.message_header.timestamp, 0);
        assert_eq!(chunk.message_header.msg_length, 256);
        assert_eq!(chunk.message_header.msg_stream_id, 0x0100); // since it's little endian, it's 0x0100
        assert_eq!(chunk.payload.len(), 256);
        for i in 0..128 {
            assert_eq!(chunk.payload[i], 3);
        }
    }
702
    #[test]
    fn test_reader_extended_timestamp() {
        // A 0xFFFFFF timestamp field signals that the real 32-bit timestamp
        // follows as a 4-byte extended field. Type 1/2 headers carry deltas
        // that are accumulated onto the previous timestamp.
        let mut buf = BytesMut::new();

        #[rustfmt::skip]
        buf.extend_from_slice(&[
            3, // chunk type 0, chunk stream id 3
            0xFF, 0xFF, 0xFF, // timestamp (0xFFFFFF -> extended timestamp follows)
            0x00, 0x02, 0x00, // message length (512) (max chunk size is set to 128)
            0x09, // message type id (video)
            0x00, 0x01, 0x00, 0x00, // message stream id
            0x01, 0x00, 0x00, 0x00, // extended timestamp (0x01000000)
        ]);

        for i in 0..128 {
            (&mut buf).writer().write_u8(i as u8).unwrap();
        }

        let mut unpacker = ChunkReader::default();

        // We should not have enough data to read the chunk
        // But the chunk is valid, so we should not get an error
        assert!(unpacker.read_chunk(&mut buf).expect("read chunk").is_none());

        #[rustfmt::skip]
        buf.extend_from_slice(&[
            (1 << 6) | 3, // chunk type 1, chunk stream id 3
            0xFF, 0xFF, 0xFF, // timestamp delta (0xFFFFFF -> extended delta follows)
            0x00, 0x02, 0x00, // message length (512) (max chunk size is set to 128)
            0x09, // message type id (video)
            // message stream id is not present since it's the same as the previous chunk
            0x01, 0x00, 0x00, 0x00, // extended timestamp delta (0x01000000)
        ]);

        for i in 0..128 {
            (&mut buf).writer().write_u8(i as u8).unwrap();
        }

        #[rustfmt::skip]
        buf.extend_from_slice(&[
            (2 << 6) | 3, // chunk type 2, chunk stream id 3
            0x00, 0x00, 0x01, // timestamp delta (1, small enough to stay inline)
        ]);

        for i in 0..128 {
            (&mut buf).writer().write_u8(i as u8).unwrap();
        }

        #[rustfmt::skip]
        buf.extend_from_slice(&[
            (3 << 6) | 3, // chunk type 3, chunk stream id 3
        ]);

        for i in 0..128 {
            (&mut buf).writer().write_u8(i as u8).unwrap();
        }

        let chunk = unpacker.read_chunk(&mut buf).expect("read chunk").expect("chunk");

        assert_eq!(chunk.basic_header.chunk_stream_id, 3);
        assert_eq!(chunk.message_header.msg_type_id.0, 0x09);
        // 0x01000000 (type 0) + 0x01000000 (type 1 delta) + 1 (type 2 delta)
        assert_eq!(chunk.message_header.timestamp, 0x02000001);
        assert_eq!(chunk.message_header.msg_length, 512);
        assert_eq!(chunk.message_header.msg_stream_id, 0x0100); // since it's little endian, it's 0x0100
        assert_eq!(chunk.payload.len(), 512);
    }
769
    #[test]
    fn test_reader_extended_timestamp_ext() {
        // A type 3 chunk continuing a message whose header used an extended
        // timestamp still carries 4 extended-timestamp bytes on the wire; the
        // reader skips over them without using their value.
        let mut buf = BytesMut::new();

        #[rustfmt::skip]
        buf.extend_from_slice(&[
            3, // chunk type 0, chunk stream id 3
            0xFF, 0xFF, 0xFF, // timestamp (0xFFFFFF -> extended timestamp follows)
            0x00, 0x01, 0x00, // message length (256) (max chunk size is set to 128)
            0x09, // message type id (video)
            0x00, 0x01, 0x00, 0x00, // message stream id
            0x01, 0x00, 0x00, 0x00, // extended timestamp (0x01000000)
        ]);

        for i in 0..128 {
            (&mut buf).writer().write_u8(i as u8).unwrap();
        }

        let mut unpacker = ChunkReader::default();

        // We should not have enough data to read the chunk
        // But the chunk is valid, so we should not get an error
        assert!(unpacker.read_chunk(&mut buf).expect("read chunk").is_none());

        #[rustfmt::skip]
        buf.extend_from_slice(&[
            (3 << 6) | 3, // chunk type 3, chunk stream id 3
            0x00, 0x00, 0x00, 0x00, // extended timestamp (this value is ignored)
        ]);

        for i in 0..128 {
            (&mut buf).writer().write_u8(i as u8).unwrap();
        }

        // These extra 128 bytes exceed the 256-byte message length; they are
        // not consumed and simply remain in the buffer after the read.
        for i in 0..128 {
            (&mut buf).writer().write_u8(i as u8).unwrap();
        }

        let chunk = unpacker.read_chunk(&mut buf).expect("read chunk").expect("chunk");

        assert_eq!(chunk.basic_header.chunk_stream_id, 3);
        assert_eq!(chunk.message_header.msg_type_id.0, 0x09);
        // The timestamp comes solely from the type 0 header's extended field.
        assert_eq!(chunk.message_header.timestamp, 0x01000000);
        assert_eq!(chunk.message_header.msg_length, 256);
        assert_eq!(chunk.message_header.msg_stream_id, 0x0100); // since it's little endian, it's 0x0100
        assert_eq!(chunk.payload.len(), 256);
    }
817
818    #[test]
819    fn test_read_extended_csid() {
820        let mut buf = BytesMut::new();
821
822        #[rustfmt::skip]
823        buf.extend_from_slice(&[
824            (0 << 6), // chunk type 0, chunk stream id 0
825            10,       // extended chunk stream id
826            0x00, 0x00, 0x00, // timestamp
827            0x00, 0x00, 0x00, // message length (256) (max chunk size is set to 128)
828            0x09, // message type id (video)
829            0x00, 0x01, 0x00, 0x00, // message stream id
830        ]);
831
832        let mut unpacker = ChunkReader::default();
833        let chunk = unpacker.read_chunk(&mut buf).expect("read chunk").expect("chunk");
834
835        assert_eq!(chunk.basic_header.chunk_stream_id, 64 + 10);
836    }
837
838    #[test]
839    fn test_read_extended_csid_ext2() {
840        let mut buf = BytesMut::new();
841
842        #[rustfmt::skip]
843        buf.extend_from_slice(&[
844            1,  // chunk type 0, chunk stream id 0
845            10, // extended chunk stream id
846            13, // extended chunk stream id 2
847            0x00, 0x00, 0x00, // timestamp
848            0x00, 0x00, 0x00, // message length (256) (max chunk size is set to 128)
849            0x09, // message type id (video)
850            0x00, 0x01, 0x00, 0x00, // message stream id
851        ]);
852
853        let mut unpacker = ChunkReader::default();
854
855        let chunk = unpacker.read_chunk(&mut buf).expect("read chunk").expect("chunk");
856
857        assert_eq!(chunk.basic_header.chunk_stream_id, 64 + 10 + 256 * 13);
858    }
859
860    #[test]
861    fn test_reader_error_no_previous_chunk() {
862        let mut buf = BytesMut::new();
863
864        // Write a chunk with type 3 but no previous chunk
865        #[rustfmt::skip]
866        buf.extend_from_slice(&[
867            (3 << 6) | 3, // chunk type 0, chunk stream id 3
868        ]);
869
870        let mut unpacker = ChunkReader::default();
871        let err = unpacker.read_chunk(&mut buf).unwrap_err();
872        match err {
873            crate::error::RtmpError::ChunkRead(ChunkReadError::MissingPreviousChunkHeader(3)) => {}
874            _ => panic!("Unexpected error: {err:?}"),
875        }
876    }
877
878    #[test]
879    fn test_reader_error_partial_chunk_too_large() {
880        let mut buf = BytesMut::new();
881
882        // Write a chunk that has a message size that is too large
883        #[rustfmt::skip]
884        buf.extend_from_slice(&[
885            3, // chunk type 0, chunk stream id 3
886            0xFF, 0xFF, 0xFF, // timestamp
887            0xFF, 0xFF, 0xFF, // message length (max chunk size is set to 128)
888            0x09, // message type id (video)
889            0x00, 0x01, 0x00, 0x00, // message stream id
890            0x01, 0x00, 0x00, 0x00, // extended timestamp
891        ]);
892
893        let mut unpacker = ChunkReader::default();
894
895        let err = unpacker.read_chunk(&mut buf).unwrap_err();
896        match err {
897            crate::error::RtmpError::ChunkRead(ChunkReadError::PartialChunkTooLarge(16777215)) => {}
898            _ => panic!("Unexpected error: {err:?}"),
899        }
900    }
901
    #[test]
    fn test_reader_error_too_many_partial_chunks() {
        // Start MAX_PARTIAL_CHUNK_COUNT (4) incomplete messages on distinct
        // chunk streams; beginning a fifth must be rejected.
        let mut buf = BytesMut::new();

        let mut unpacker = ChunkReader::default();

        for i in 0..4 {
            // Write another chunk with a different chunk stream id
            #[rustfmt::skip]
            buf.extend_from_slice(&[
                (i + 2), // chunk type 0, chunk stream id i + 2
                0xFF, 0xFF, 0xFF, // timestamp (0xFFFFFF -> extended timestamp follows)
                0x00, 0x01, 0x00, // message length (256) (max chunk size is set to 128)
                0x09, // message type id (video)
                0x00, 0x01, 0x00, 0x00, // message stream id
                0x01, 0x00, 0x00, 0x00, // extended timestamp
            ]);

            for i in 0..128 {
                (&mut buf).writer().write_u8(i as u8).unwrap();
            }

            // Read the chunk (only 128 of the 256 bytes arrived, so it stays partial)
            assert!(
                unpacker
                    .read_chunk(&mut buf)
                    .unwrap_or_else(|_| panic!("chunk failed {i}"))
                    .is_none()
            );
        }

        // Write a fifth partial chunk on yet another chunk stream id
        #[rustfmt::skip]
        buf.extend_from_slice(&[
            12, // chunk type 0, chunk stream id 12
            0xFF, 0xFF, 0xFF, // timestamp (0xFFFFFF -> extended timestamp follows)
            0x00, 0x01, 0x00, // message length (256) (max chunk size is set to 128)
            0x09, // message type id (video)
            0x00, 0x01, 0x00, 0x00, // message stream id
            0x01, 0x00, 0x00, 0x00, // extended timestamp
        ]);

        for i in 0..128 {
            (&mut buf).writer().write_u8(i as u8).unwrap();
        }

        let err = unpacker.read_chunk(&mut buf).unwrap_err();
        match err {
            crate::error::RtmpError::ChunkRead(ChunkReadError::TooManyPartialChunks) => {}
            _ => panic!("Unexpected error: {err:?}"),
        }
    }
954
    #[test]
    fn test_reader_error_too_many_chunk_headers() {
        // Complete MAX_PREVIOUS_CHUNK_HEADERS (100) zero-length messages, each
        // on a distinct chunk stream; a 101st stream must be rejected.
        let mut buf = BytesMut::new();

        let mut unpacker = ChunkReader::default();

        for i in 0..100 {
            // Write another chunk with a different chunk stream id
            #[rustfmt::skip]
            buf.extend_from_slice(&[
                (0 << 6), // chunk type 0, csid field 0 (2-byte basic header)
                i,        // extended chunk stream id byte (csid = 64 + i)
                0xFF, 0xFF, 0xFF, // timestamp (0xFFFFFF -> extended timestamp follows)
                0x00, 0x00, 0x00, // message length (0 bytes)
                0x09, // message type id (video)
                0x00, 0x01, 0x00, 0x00, // message stream id
                0x01, 0x00, 0x00, 0x00, // extended timestamp
            ]);

            // Read the chunk (should be a full chunk since the message length is 0)
            assert!(
                unpacker
                    .read_chunk(&mut buf)
                    .unwrap_or_else(|_| panic!("chunk failed {i}"))
                    .is_some()
            );
        }

        // Write another chunk with a chunk stream id (12) not seen above
        #[rustfmt::skip]
        buf.extend_from_slice(&[
            12, // chunk type 0, chunk stream id 12
            0xFF, 0xFF, 0xFF, // timestamp (0xFFFFFF -> extended timestamp follows)
            0x00, 0x00, 0x00, // message length (0 bytes)
            0x09, // message type id (video)
            0x00, 0x01, 0x00, 0x00, // message stream id
            0x01, 0x00, 0x00, 0x00, // extended timestamp
        ]);

        let err = unpacker.read_chunk(&mut buf).unwrap_err();
        match err {
            crate::error::RtmpError::ChunkRead(ChunkReadError::TooManyPreviousChunkHeaders) => {}
            _ => panic!("Unexpected error: {err:?}"),
        }
    }
1000
    #[test]
    fn test_reader_larger_chunk_size() {
        // After raising the max chunk size to 4096, a 3840-byte message fits
        // in a single chunk and is read in one call.
        let mut buf = BytesMut::new();

        #[rustfmt::skip]
        buf.extend_from_slice(&[
            3, // chunk type 0, chunk stream id 3
            0x00, 0x00, 0xFF, // timestamp (255, small enough to stay inline)
            0x00, 0x0F, 0x00, // message length (3840)
            0x09, // message type id (video)
            0x01, 0x00, 0x00, 0x00, // message stream id (little endian -> 1)
        ]);

        for i in 0..3840 {
            (&mut buf).writer().write_u8(i as u8).unwrap();
        }

        let mut unpacker = ChunkReader::default();
        unpacker.update_max_chunk_size(4096);

        let chunk = unpacker.read_chunk(&mut buf).expect("failed").expect("chunk");
        assert_eq!(chunk.basic_header.chunk_stream_id, 3);
        assert_eq!(chunk.message_header.timestamp, 255);
        assert_eq!(chunk.message_header.msg_length, 3840);
        assert_eq!(chunk.message_header.msg_type_id.0, 0x09);
        assert_eq!(chunk.message_header.msg_stream_id, 1); // little endian
        assert_eq!(chunk.payload.len(), 3840);

        // The payload bytes wrap at 256 (i as u8), matching what was written.
        for i in 0..3840 {
            assert_eq!(chunk.payload[i], i as u8);
        }
    }
1034}