from typing import List, Optional, Tuple

from .rtp import RtpPacket
from .utils import uint16_add

# A packet whose sequence number lies behind the buffer origin by fewer than
# this many positions is treated as a late arrival and dropped; at or beyond
# this distance the buffer is reset onto the incoming packet instead.
MAX_MISORDER = 100


class JitterFrame:
    """A reassembled frame: the joined packet payloads and their timestamp."""

    def __init__(self, data: bytes, timestamp: int) -> None:
        self.data = data
        self.timestamp = timestamp


class JitterBuffer:
    """
    Fixed-size ring of RTP packets that reassembles them into frames.

    Packets are stored at index ``sequence_number % capacity``. ``_origin``
    holds the sequence number at which frame scanning starts; it is ``None``
    until the first packet is added.
    """

    def __init__(
        self, capacity: int, prefetch: int = 0, is_video: bool = False
    ) -> None:
        """
        :param capacity: Number of packet slots; must be a positive power of 2.
        :param prefetch: Number of complete frames that must be buffered
            before the first of them is handed out by :meth:`add`.
        :param is_video: When true, :meth:`add` raises its first return value
            whenever buffered packets had to be discarded.
        """
        # `0 & -1 == 0` would let a zero capacity through the power-of-two
        # test and later fail on `% capacity`, so reject it here as well.
        assert (
            capacity > 0 and capacity & (capacity - 1) == 0
        ), "capacity must be a power of 2"
        self._capacity = capacity
        self._origin: Optional[int] = None
        self._packets: List[Optional[RtpPacket]] = [None] * capacity
        self._prefetch = prefetch
        self._is_video = is_video

    @property
    def capacity(self) -> int:
        """Number of packet slots in the buffer."""
        return self._capacity

    def add(self, packet: RtpPacket) -> Tuple[bool, Optional[JitterFrame]]:
        """
        Store `packet` and return a frame if one has become available.

        :param packet: The packet to insert; its ``sequence_number`` selects
            the slot.
        :return: ``(pli_flag, frame)``. ``pli_flag`` is only ever true for
            video buffers, and only when buffered packets were discarded
            (reset after a large misorder, or eviction to make room) —
            presumably a hint to request a picture-loss indication; confirm
            against the caller. ``frame`` is the result of
            :meth:`_remove_frame`, or ``None`` when the packet was dropped.
        """
        pli_flag = False

        if self._origin is None:
            # first packet ever seen defines the origin
            self._origin = packet.sequence_number
            delta = 0
            misorder = 0
        else:
            # distance ahead of / behind the origin, as computed by
            # uint16_add (presumably 16-bit wrapping arithmetic — confirm
            # in .utils)
            delta = uint16_add(packet.sequence_number, -self._origin)
            misorder = uint16_add(self._origin, -packet.sequence_number)

        if misorder < delta:
            # the packet is nearer to the origin going backwards than forwards
            if misorder < MAX_MISORDER:
                # late packet: drop it
                return pli_flag, None

            # too far out of order: flush everything and restart from here
            self.remove(self.capacity)
            self._origin = packet.sequence_number
            delta = misorder = 0
            if self._is_video:
                pli_flag = True

        if delta >= self.capacity:
            # remove just enough frames to fit the received packets
            excess = delta - self.capacity + 1
            if self.smart_remove(excess):
                # the whole buffer was emptied; restart at this packet
                self._origin = packet.sequence_number
            if self._is_video:
                pli_flag = True

        self._packets[packet.sequence_number % self._capacity] = packet

        return pli_flag, self._remove_frame(packet.sequence_number)

    def _remove_frame(self, sequence_number: int) -> Optional[JitterFrame]:
        """
        Pop the first complete frame once enough frames are buffered.

        Starting at the origin, consecutive non-empty slots are scanned. A
        frame counts as complete only when a packet with a different
        timestamp follows it. When the number of complete frames seen
        reaches ``prefetch``, the packets of the first frame are removed
        from the buffer and that frame is returned.

        :param sequence_number: Unused; kept for interface compatibility.
        :return: The first complete frame, or ``None`` if a gap is reached
            or too few complete frames are buffered.
        """
        frame: Optional[JitterFrame] = None
        frames = 0
        packets: List[RtpPacket] = []
        remove = 0
        timestamp: Optional[int] = None

        for count in range(self.capacity):
            packet = self._packets[(self._origin + count) % self._capacity]
            if packet is None:
                # gap in the sequence: nothing further can be assembled yet
                break

            if timestamp is None:
                timestamp = packet.timestamp
            elif packet.timestamp != timestamp:
                # we now have a complete frame, only store the first one
                if frame is None:
                    frame = JitterFrame(
                        data=b"".join([x._data for x in packets]),
                        timestamp=timestamp,
                    )
                    # `count` packets precede this one: exactly the first frame
                    remove = count

                # check we have prefetched enough
                frames += 1
                if frames >= self._prefetch:
                    self.remove(remove)
                    return frame

                # start a new frame
                packets = []
                timestamp = packet.timestamp

            packets.append(packet)

        return None

    def remove(self, count: int) -> None:
        """
        Clear `count` slots starting at the origin and advance the origin.

        :param count: Number of slots to clear; must not exceed the capacity.
        """
        assert count <= self._capacity
        for _ in range(count):
            self._packets[self._origin % self._capacity] = None
            self._origin = uint16_add(self._origin, 1)

    def smart_remove(self, count: int) -> bool:
        """
        Remove at least `count` slots, extending the removal to a frame
        boundary.

        Makes sure that all packets belonging to the same frame are removed
        to prevent sending corrupted frames to the decoder.

        :param count: Minimum number of slots to clear from the origin.
        :return: ``True`` if every slot of the buffer was cleared, ``False``
            if removal stopped early at the start of a different frame.
        """
        timestamp = None
        for i in range(self._capacity):
            pos = self._origin % self._capacity
            packet = self._packets[pos]
            if packet is not None:
                # past the required count, stop at the first packet whose
                # timestamp differs from the last one removed
                if i >= count and timestamp != packet.timestamp:
                    break
                timestamp = packet.timestamp
            self._packets[pos] = None
            self._origin = uint16_add(self._origin, 1)
            if i == self._capacity - 1:
                return True
        return False