"""RTP and RTCP packet parsing / serialization helpers."""

import math
import os
import struct
from dataclasses import dataclass, field
from struct import pack, unpack, unpack_from
from typing import Any, List, Optional, Tuple, Union

from av import AudioFrame

from .rtcrtpparameters import RTCRtpParameters

# used for NACK and retransmission
RTP_HISTORY_SIZE = 128

# reserved to avoid confusion with RTCP
FORBIDDEN_PAYLOAD_TYPES = range(72, 77)
DYNAMIC_PAYLOAD_TYPES = range(96, 128)

RTP_HEADER_LENGTH = 12
RTCP_HEADER_LENGTH = 4

# bounds of the signed 24-bit "cumulative number of packets lost" field
PACKETS_LOST_MIN = -(1 << 23)
PACKETS_LOST_MAX = (1 << 23) - 1

# RTCP packet types
RTCP_SR = 200
RTCP_RR = 201
RTCP_SDES = 202
RTCP_BYE = 203
RTCP_RTPFB = 205
RTCP_PSFB = 206

# feedback message types (FMT) for RTCP_RTPFB
RTCP_RTPFB_NACK = 1

# feedback message types (FMT) for RTCP_PSFB
RTCP_PSFB_PLI = 1
RTCP_PSFB_SLI = 2
RTCP_PSFB_RPSI = 3
RTCP_PSFB_APP = 15


@dataclass
class HeaderExtensions:
    """
    Container for RTP header extension data.

    `HeaderExtensionsMap` uses it both for decoded extension values and,
    internally, for the negotiated extension IDs.
    """

    abs_send_time: Optional[int] = None
    audio_level: Any = None
    mid: Any = None
    repaired_rtp_stream_id: Any = None
    rtp_stream_id: Any = None
    transmission_offset: Optional[int] = None
    transport_sequence_number: Optional[int] = None


class HeaderExtensionsMap:
    """
    Map between negotiated header extension IDs and `HeaderExtensions` values.
    """

    def __init__(self) -> None:
        # Extension IDs are stored in a HeaderExtensions instance, one ID per
        # known extension; None means "not negotiated".
        self.__ids = HeaderExtensions()

    def configure(self, parameters: RTCRtpParameters) -> None:
        """
        Record the ID of every known extension in `parameters.headerExtensions`.

        Extensions with an unrecognised URI are ignored.
        """
        for ext in parameters.headerExtensions:
            if ext.uri == "urn:ietf:params:rtp-hdrext:sdes:mid":
                self.__ids.mid = ext.id
            elif ext.uri == "urn:ietf:params:rtp-hdrext:sdes:repaired-rtp-stream-id":
                self.__ids.repaired_rtp_stream_id = ext.id
            elif ext.uri == "urn:ietf:params:rtp-hdrext:sdes:rtp-stream-id":
                self.__ids.rtp_stream_id = ext.id
            elif (
                ext.uri == "http://www.webrtc.org/experiments/rtp-hdrext/abs-send-time"
            ):
                self.__ids.abs_send_time = ext.id
            elif ext.uri == "urn:ietf:params:rtp-hdrext:toffset":
                self.__ids.transmission_offset = ext.id
            elif ext.uri == "urn:ietf:params:rtp-hdrext:ssrc-audio-level":
                self.__ids.audio_level = ext.id
            elif (
                ext.uri
                == "http://www.ietf.org/id/draft-holmer-rmcat-transport-wide-cc-extensions-01"
            ):
                self.__ids.transport_sequence_number = ext.id

    def get(self, extension_profile: int, extension_value: bytes) -> HeaderExtensions:
        """
        Decode the raw extension block of an RTP packet.

        Elements whose ID does not match a configured extension are skipped.
        """
        values = HeaderExtensions()
        for x_id, x_value in unpack_header_extensions(
            extension_profile, extension_value
        ):
            if x_id == self.__ids.mid:
                values.mid = x_value.decode("utf8")
            elif x_id == self.__ids.repaired_rtp_stream_id:
                values.repaired_rtp_stream_id = x_value.decode("ascii")
            elif x_id == self.__ids.rtp_stream_id:
                values.rtp_stream_id = x_value.decode("ascii")
            elif x_id == self.__ids.abs_send_time:
                # 24-bit unsigned value: left-pad to 32 bits
                values.abs_send_time = unpack("!L", b"\x00" + x_value)[0]
            elif x_id == self.__ids.transmission_offset:
                # 24-bit signed value: right-pad to 32 bits, then shift back
                # so that the sign is preserved
                values.transmission_offset = unpack("!l", x_value + b"\x00")[0] >> 8
            elif x_id == self.__ids.audio_level:
                vad_level = unpack("!B", x_value)[0]
                # (voice activity flag, level)
                values.audio_level = (vad_level & 0x80 == 0x80, vad_level & 0x7F)
            elif x_id == self.__ids.transport_sequence_number:
                values.transport_sequence_number = unpack("!H", x_value)[0]
        return values

    def set(self, values: HeaderExtensions) -> Tuple[int, bytes]:
        """
        Encode `values` into an `(extension_profile, extension_value)` pair.

        A value is only emitted when it is not None and a non-zero ID has been
        configured for the corresponding extension.
        """
        extensions = []
        if values.mid is not None and self.__ids.mid:
            extensions.append((self.__ids.mid, values.mid.encode("utf8")))
        if (
            values.repaired_rtp_stream_id is not None
            and self.__ids.repaired_rtp_stream_id
        ):
            extensions.append(
                (
                    self.__ids.repaired_rtp_stream_id,
                    values.repaired_rtp_stream_id.encode("ascii"),
                )
            )
        if values.rtp_stream_id is not None and self.__ids.rtp_stream_id:
            extensions.append(
                (self.__ids.rtp_stream_id, values.rtp_stream_id.encode("ascii"))
            )
        if values.abs_send_time is not None and self.__ids.abs_send_time:
            # keep the low 24 bits
            extensions.append(
                (self.__ids.abs_send_time, pack("!L", values.abs_send_time)[1:])
            )
        if values.transmission_offset is not None and self.__ids.transmission_offset:
            # The offset is a 24-bit signed integer: shift it into the top
            # three bytes of a 32-bit word and keep exactly those three bytes,
            # which is what get() expects to read back.
            extensions.append(
                (
                    self.__ids.transmission_offset,
                    pack("!l", values.transmission_offset << 8)[0:3],
                )
            )
        if values.audio_level is not None and self.__ids.audio_level:
            extensions.append(
                (
                    self.__ids.audio_level,
                    pack(
                        "!B",
                        (0x80 if values.audio_level[0] else 0)
                        | (values.audio_level[1] & 0x7F),
                    ),
                )
            )
        if (
            values.transport_sequence_number is not None
            and self.__ids.transport_sequence_number
        ):
            extensions.append(
                (
                    self.__ids.transport_sequence_number,
                    pack("!H", values.transport_sequence_number),
                )
            )
        return pack_header_extensions(extensions)


def clamp_packets_lost(count: int) -> int:
    """
    Clamp `count` to the range of a signed 24-bit integer.
    """
    return max(PACKETS_LOST_MIN, min(count, PACKETS_LOST_MAX))


def pack_packets_lost(count: int) -> bytes:
    """
    Serialize `count` as a 3-byte big-endian signed integer.

    The high byte of the 32-bit representation is dropped, so `count` should
    already lie within [PACKETS_LOST_MIN, PACKETS_LOST_MAX].
    """
    return pack("!l", count)[1:]


def unpack_packets_lost(d: bytes) -> int:
    """
    Parse a 3-byte big-endian signed integer.
    """
    # sign-extend to 32 bits
    if d[0] & 0x80:
        d = b"\xff" + d
    else:
        d = b"\x00" + d
    return unpack("!l", d)[0]


def pack_rtcp_packet(packet_type: int, count: int, payload: bytes) -> bytes:
    """
    Prepend a 4-byte RTCP header (version 2, no padding) to `payload`.

    `payload` must be a multiple of 4 bytes long; the header's length field is
    expressed in 32-bit words.
    """
    assert len(payload) % 4 == 0
    return pack("!BBH", (2 << 6) | count, packet_type, len(payload) // 4) + payload


def pack_remb_fci(bitrate: int, ssrcs: List[int]) -> bytes:
    """
    Pack the FCI for a Receiver Estimated Maximum Bitrate report.

    https://tools.ietf.org/html/draft-alvestrand-rmcat-remb-03
    """
    # The bitrate is encoded as an 18-bit mantissa and a 6-bit exponent.
    exponent = 0
    mantissa = bitrate
    while mantissa > 0x3FFFF:
        mantissa >>= 1
        exponent += 1
    header = pack(
        "!BBH", len(ssrcs), (exponent << 2) | (mantissa >> 16), (mantissa & 0xFFFF)
    )
    return b"REMB" + header + b"".join(pack("!L", ssrc) for ssrc in ssrcs)


def unpack_remb_fci(data: bytes) -> Tuple[int, List[int]]:
    """
    Unpack the FCI for a Receiver Estimated Maximum Bitrate report.

    https://tools.ietf.org/html/draft-alvestrand-rmcat-remb-03

    Returns a `(bitrate, ssrcs)` tuple.

    Raises `ValueError` if the data is too short, does not start with the
    "REMB" identifier, or announces more SSRCs than it contains.
    """
    if len(data) < 8 or data[0:4] != b"REMB":
        raise ValueError("Invalid REMB prefix")

    num_ssrcs = data[4]
    # Validate up front so malformed input raises ValueError rather than
    # struct.error.
    if len(data) < 8 + 4 * num_ssrcs:
        raise ValueError("REMB SSRC list is truncated")

    exponent = (data[5] & 0xFC) >> 2
    mantissa = ((data[5] & 0x03) << 16) | (data[6] << 8) | data[7]
    bitrate = mantissa << exponent
    ssrcs = [
        unpack_from("!L", data, pos)[0] for pos in range(8, 8 + 4 * num_ssrcs, 4)
    ]
    return (bitrate, ssrcs)


def is_rtcp(msg: bytes) -> bool:
    """
    Return True if the second byte of `msg` lies in the range 192 to 208.
    """
    return len(msg) >= 2 and msg[1] >= 192 and msg[1] <= 208


def padl(length: int) -> int:
    """
    Return amount of padding needed for a 4-byte multiple.
    """
    return 4 * ((length + 3) // 4) - length


def unpack_header_extensions(
    extension_profile: int, extension_value: bytes
) -> List[Tuple[int, bytes]]:
    """
    Parse header extensions according to RFC 5285.

    Returns a list of `(id, value)` tuples. Profiles other than the one-byte
    (0xBEDE) and two-byte (0x1000) headers yield an empty list.

    Raises `ValueError` if an extension element is truncated.
    """
    extensions = []
    pos = 0

    if extension_profile == 0xBEDE:
        # One-Byte Header
        while pos < len(extension_value):
            # skip padding byte
            if extension_value[pos] == 0:
                pos += 1
                continue

            x_id = (extension_value[pos] & 0xF0) >> 4
            # the length field stores "length - 1"
            x_length = (extension_value[pos] & 0x0F) + 1
            pos += 1

            if len(extension_value) < pos + x_length:
                raise ValueError("RTP one-byte header extension value is truncated")
            x_value = extension_value[pos : pos + x_length]
            extensions.append((x_id, x_value))
            pos += x_length
    elif extension_profile == 0x1000:
        # Two-Byte Header
        while pos < len(extension_value):
            # skip padding byte
            if extension_value[pos] == 0:
                pos += 1
                continue

            if len(extension_value) < pos + 2:
                raise ValueError("RTP two-byte header extension is truncated")
            x_id, x_length = unpack_from("!BB", extension_value, pos)
            pos += 2

            if len(extension_value) < pos + x_length:
                raise ValueError("RTP two-byte header extension value is truncated")
            x_value = extension_value[pos : pos + x_length]
            extensions.append((x_id, x_value))
            pos += x_length

    return extensions


def pack_header_extensions(extensions: List[Tuple[int, bytes]]) -> Tuple[int, bytes]:
    """
    Serialize header extensions according to RFC 5285.

    Returns an `(extension_profile, extension_value)` tuple, which is
    `(0, b"")` when `extensions` is empty. The one-byte header form is used
    when every element fits it, otherwise the two-byte form. The value is
    zero-padded to a multiple of 4 bytes.
    """
    extension_profile = 0
    extension_value = b""

    if not extensions:
        return extension_profile, extension_value

    one_byte = True
    for x_id, x_value in extensions:
        x_length = len(x_value)
        assert x_id > 0 and x_id < 256
        assert x_length >= 0 and x_length < 256
        if x_id > 14 or x_length == 0 or x_length > 16:
            one_byte = False

    if one_byte:
        # One-Byte Header
        extension_profile = 0xBEDE
        extension_value = b"".join(
            pack("!B", (x_id << 4) | (len(x_value) - 1)) + x_value
            for x_id, x_value in extensions
        )
    else:
        # Two-Byte Header
        extension_profile = 0x1000
        extension_value = b"".join(
            pack("!BB", x_id, len(x_value)) + x_value for x_id, x_value in extensions
        )

    extension_value += b"\x00" * padl(len(extension_value))
    return extension_profile, extension_value


def compute_audio_level_dbov(frame: AudioFrame) -> int:
    """
    Compute the energy level as spelled out in RFC 6465, Appendix A.

    The first plane of `frame` is read as native-endian signed 16-bit
    samples. The result is an integer between -127 and 0; a frame without
    samples yields -127.
    """
    MAX_SAMPLE_VALUE = 32767
    MAX_AUDIO_LEVEL = 0
    MIN_AUDIO_LEVEL = -127

    # An empty frame would otherwise divide by zero below.
    if not frame.samples:
        return MIN_AUDIO_LEVEL

    rms = 0.0
    buf = bytes(frame.planes[0])
    s = struct.Struct("h")
    for unpacked in s.iter_unpack(buf):
        sample = unpacked[0]
        rms += sample * sample
    rms = math.sqrt(rms / (frame.samples * MAX_SAMPLE_VALUE * MAX_SAMPLE_VALUE))
    if rms > 0:
        db = 20 * math.log10(rms)
        db = max(db, MIN_AUDIO_LEVEL)
        db = min(db, MAX_AUDIO_LEVEL)
    else:
        db = MIN_AUDIO_LEVEL
    return round(db)


@dataclass
class RtcpReceiverInfo:
    """
    A 24-byte reception report block carried in SR and RR packets.
    """

    ssrc: int
    fraction_lost: int
    packets_lost: int
    highest_sequence: int
    jitter: int
    lsr: int
    dlsr: int

    def __bytes__(self) -> bytes:
        data = pack("!LB", self.ssrc, self.fraction_lost)
        data += pack_packets_lost(self.packets_lost)
        data += pack("!LLLL", self.highest_sequence, self.jitter, self.lsr, self.dlsr)
        return data

    @classmethod
    def parse(cls, data: bytes):
        """
        Parse a report block; `data` must be exactly 24 bytes long.
        """
        ssrc, fraction_lost = unpack("!LB", data[0:5])
        packets_lost = unpack_packets_lost(data[5:8])
        highest_sequence, jitter, lsr, dlsr = unpack("!LLLL", data[8:])
        return cls(
            ssrc=ssrc,
            fraction_lost=fraction_lost,
            packets_lost=packets_lost,
            highest_sequence=highest_sequence,
            jitter=jitter,
            lsr=lsr,
            dlsr=dlsr,
        )


@dataclass
class RtcpSenderInfo:
    """
    The 20-byte sender information section of an SR packet.
    """

    ntp_timestamp: int
    rtp_timestamp: int
    packet_count: int
    octet_count: int

    def __bytes__(self) -> bytes:
        return pack(
            "!QLLL",
            self.ntp_timestamp,
            self.rtp_timestamp,
            self.packet_count,
            self.octet_count,
        )

    @classmethod
    def parse(cls, data: bytes):
        """
        Parse sender information; `data` must be exactly 20 bytes long.
        """
        ntp_timestamp, rtp_timestamp, packet_count, octet_count = unpack("!QLLL", data)
        return cls(
            ntp_timestamp=ntp_timestamp,
            rtp_timestamp=rtp_timestamp,
            packet_count=packet_count,
            octet_count=octet_count,
        )


@dataclass
class RtcpSourceInfo:
    """
    An SDES chunk: an SSRC and its list of `(type, value)` items.
    """

    ssrc: int
    items: List[Tuple[Any, bytes]]


@dataclass
class RtcpByePacket:
    """
    RTCP BYE packet listing the sources which are leaving.
    """

    sources: List[int]

    def __bytes__(self) -> bytes:
        payload = b"".join([pack("!L", ssrc) for ssrc in self.sources])
        return pack_rtcp_packet(RTCP_BYE, len(self.sources), payload)

    @classmethod
    def parse(cls, data: bytes, count: int):
        """
        Parse `count` SSRCs from `data`; any trailing bytes are ignored.

        Raises `ValueError` if `data` is too short.
        """
        if len(data) < 4 * count:
            raise ValueError("RTCP bye length is invalid")
        if count > 0:
            sources = list(unpack_from("!" + ("L" * count), data, 0))
        else:
            sources = []
        return cls(sources=sources)


@dataclass
class RtcpPsfbPacket:
    """
    Payload-Specific Feedback Message (RFC 4585).
    """

    fmt: int
    ssrc: int
    media_ssrc: int
    fci: bytes = b""

    def __bytes__(self) -> bytes:
        payload = pack("!LL", self.ssrc, self.media_ssrc) + self.fci
        return pack_rtcp_packet(RTCP_PSFB, self.fmt, payload)

    @classmethod
    def parse(cls, data: bytes, fmt: int):
        """
        Parse the packet body; bytes after the two SSRCs are kept as raw FCI.

        Raises `ValueError` if `data` is shorter than 8 bytes.
        """
        if len(data) < 8:
            raise ValueError("RTCP payload-specific feedback length is invalid")
        ssrc, media_ssrc = unpack("!LL", data[0:8])
        fci = data[8:]
        return cls(fmt=fmt, ssrc=ssrc, media_ssrc=media_ssrc, fci=fci)


@dataclass
class RtcpRrPacket:
    """
    RTCP receiver report.
    """

    ssrc: int
    reports: List[RtcpReceiverInfo] = field(default_factory=list)

    def __bytes__(self) -> bytes:
        payload = pack("!L", self.ssrc)
        for report in self.reports:
            payload += bytes(report)
        return pack_rtcp_packet(RTCP_RR, len(self.reports), payload)

    @classmethod
    def parse(cls, data: bytes, count: int):
        """
        Parse a receiver report containing exactly `count` report blocks.

        Raises `ValueError` if the length of `data` does not match `count`.
        """
        if len(data) != 4 + 24 * count:
            raise ValueError("RTCP receiver report length is invalid")

        ssrc = unpack("!L", data[0:4])[0]
        pos = 4
        reports = []
        for _ in range(count):
            reports.append(RtcpReceiverInfo.parse(data[pos : pos + 24]))
            pos += 24
        return cls(ssrc=ssrc, reports=reports)


@dataclass
class RtcpRtpfbPacket:
    """
    Generic RTP Feedback Message (RFC 4585).
    """

    fmt: int
    ssrc: int
    media_ssrc: int

    # generic NACK: sequence numbers of the lost packets
    lost: List[int] = field(default_factory=list)

    def __bytes__(self) -> bytes:
        payload = pack("!LL", self.ssrc, self.media_ssrc)
        if self.lost:
            # Each FCI entry holds a packet ID (PID) and a 16-bit bitmask
            # (BLP) flagging which of the 16 following packets are also lost.
            pid = self.lost[0]
            blp = 0
            for p in self.lost[1:]:
                # Sequence numbers are 16 bits wide, so the distance is taken
                # modulo 2**16. This keeps it non-negative when the sequence
                # wraps from 65535 to 0; anything not within the next 16
                # packets starts a new entry.
                d = (p - pid - 1) & 0xFFFF
                if d < 16:
                    blp |= 1 << d
                else:
                    payload += pack("!HH", pid, blp)
                    pid = p
                    blp = 0
            payload += pack("!HH", pid, blp)
        return pack_rtcp_packet(RTCP_RTPFB, self.fmt, payload)

    @classmethod
    def parse(cls, data: bytes, fmt: int):
        """
        Parse the packet body, expanding every (PID, BLP) entry into `lost`.

        Raises `ValueError` if `data` is shorter than 8 bytes or not a
        multiple of 4 bytes long.
        """
        if len(data) < 8 or len(data) % 4:
            raise ValueError("RTCP RTP feedback length is invalid")

        ssrc, media_ssrc = unpack("!LL", data[0:8])
        lost = []
        for pos in range(8, len(data), 4):
            pid, blp = unpack("!HH", data[pos : pos + 4])
            lost.append(pid)
            for d in range(0, 16):
                if (blp >> d) & 1:
                    # wrap around at 16 bits, like RTP sequence numbers
                    lost.append((pid + d + 1) & 0xFFFF)
        return cls(fmt=fmt, ssrc=ssrc, media_ssrc=media_ssrc, lost=lost)


@dataclass
class RtcpSdesPacket:
    """
    RTCP source description packet.

    NOTE(review): the serializer terminates each chunk with two null bytes
    and pads the payload once at the end, and the parser mirrors this.
    RFC 3550 section 6.5 pads every chunk to a 32-bit boundary, so packets
    with several chunks may not interoperate -- confirm before relying on
    multi-chunk SDES.
    """

    chunks: List[RtcpSourceInfo] = field(default_factory=list)

    def __bytes__(self) -> bytes:
        payload = b""
        for chunk in self.chunks:
            payload += pack("!L", chunk.ssrc)
            for d_type, d_value in chunk.items:
                payload += pack("!BB", d_type, len(d_value)) + d_value
            # end of the item list
            payload += b"\x00\x00"
        while len(payload) % 4:
            payload += b"\x00"
        return pack_rtcp_packet(RTCP_SDES, len(self.chunks), payload)

    @classmethod
    def parse(cls, data: bytes, count: int):
        """
        Parse `count` chunks from `data`.

        Raises `ValueError` if a chunk's SSRC or one of its items is
        truncated.
        """
        pos = 0
        chunks = []
        for _ in range(count):
            if len(data) < pos + 4:
                raise ValueError("RTCP SDES source is truncated")
            ssrc = unpack_from("!L", data, pos)[0]
            pos += 4

            items = []
            while pos < len(data) - 1:
                d_type, d_length = unpack_from("!BB", data, pos)
                pos += 2

                if len(data) < pos + d_length:
                    raise ValueError("RTCP SDES item is truncated")
                d_value = data[pos : pos + d_length]
                pos += d_length
                # an item of type 0 ends the chunk's item list
                if d_type == 0:
                    break
                else:
                    items.append((d_type, d_value))
            chunks.append(RtcpSourceInfo(ssrc=ssrc, items=items))
        return cls(chunks=chunks)


@dataclass
class RtcpSrPacket:
    """
    RTCP sender report.
    """

    ssrc: int
    sender_info: RtcpSenderInfo
    reports: List[RtcpReceiverInfo] = field(default_factory=list)

    def __bytes__(self) -> bytes:
        payload = pack("!L", self.ssrc)
        payload += bytes(self.sender_info)
        for report in self.reports:
            payload += bytes(report)
        return pack_rtcp_packet(RTCP_SR, len(self.reports), payload)

    @classmethod
    def parse(cls, data: bytes, count: int):
        """
        Parse a sender report containing exactly `count` report blocks.

        Raises `ValueError` if the length of `data` does not match `count`.
        """
        if len(data) != 24 + 24 * count:
            raise ValueError("RTCP sender report length is invalid")

        ssrc = unpack_from("!L", data)[0]
        sender_info = RtcpSenderInfo.parse(data[4:24])
        pos = 24
        reports = []
        for _ in range(count):
            reports.append(RtcpReceiverInfo.parse(data[pos : pos + 24]))
            pos += 24
        return cls(ssrc=ssrc, sender_info=sender_info, reports=reports)


AnyRtcpPacket = Union[
    RtcpByePacket,
    RtcpPsfbPacket,
    RtcpRrPacket,
    RtcpRtpfbPacket,
    RtcpSdesPacket,
    RtcpSrPacket,
]


class RtcpPacket:
    """
    Entry point for parsing (possibly compound) RTCP packets.
    """

    @classmethod
    def parse(cls, data: bytes) -> List[AnyRtcpPacket]:
        """
        Split `data` into individual RTCP packets and parse each of them.

        Packets of an unsupported type are skipped.

        Raises `ValueError` if a packet is malformed.
        """
        pos = 0
        packets = []

        while pos < len(data):
            if len(data) < pos + RTCP_HEADER_LENGTH:
                raise ValueError(
                    f"RTCP packet length is less than {RTCP_HEADER_LENGTH} bytes"
                )

            v_p_count, packet_type, length = unpack("!BBH", data[pos : pos + 4])
            version = v_p_count >> 6
            padding = (v_p_count >> 5) & 1
            count = v_p_count & 0x1F
            if version != 2:
                raise ValueError("RTCP packet has invalid version")
            pos += 4

            # the length field is expressed in 32-bit words
            end = pos + length * 4
            if len(data) < end:
                raise ValueError("RTCP packet is truncated")
            payload = data[pos:end]
            pos = end

            if padding:
                # the last byte holds the padding size, itself included
                if not payload or not payload[-1] or payload[-1] > len(payload):
                    raise ValueError("RTCP packet padding length is invalid")
                payload = payload[0 : -payload[-1]]

            if packet_type == RTCP_BYE:
                packets.append(RtcpByePacket.parse(payload, count))
            elif packet_type == RTCP_SDES:
                packets.append(RtcpSdesPacket.parse(payload, count))
            elif packet_type == RTCP_SR:
                packets.append(RtcpSrPacket.parse(payload, count))
            elif packet_type == RTCP_RR:
                packets.append(RtcpRrPacket.parse(payload, count))
            elif packet_type == RTCP_RTPFB:
                # for feedback packets the "count" bits carry the FMT
                packets.append(RtcpRtpfbPacket.parse(payload, count))
            elif packet_type == RTCP_PSFB:
                packets.append(RtcpPsfbPacket.parse(payload, count))

        return packets


class RtpPacket:
    """
    An RTP packet.
    """

    def __init__(
        self,
        payload_type: int = 0,
        marker: int = 0,
        sequence_number: int = 0,
        timestamp: int = 0,
        ssrc: int = 0,
        payload: bytes = b"",
    ) -> None:
        self.version = 2
        self.marker = marker
        self.payload_type = payload_type
        self.sequence_number = sequence_number
        self.timestamp = timestamp
        self.ssrc = ssrc
        self.csrc: List[int] = []
        self.extensions = HeaderExtensions()
        self.payload = payload
        self.padding_size = 0

    def __repr__(self) -> str:
        return (
            f"RtpPacket(seq={self.sequence_number}, ts={self.timestamp}, "
            f"marker={self.marker}, payload={self.payload_type}, "
            f"{len(self.payload)} bytes)"
        )

    @classmethod
    def parse(cls, data: bytes, extensions_map=HeaderExtensionsMap()):
        """
        Parse an RTP packet from `data`.

        Header extensions are decoded using `extensions_map`; the default is
        a map on which no extension has been configured.

        Raises `ValueError` if the packet is malformed.
        """
        if len(data) < RTP_HEADER_LENGTH:
            raise ValueError(
                f"RTP packet length is less than {RTP_HEADER_LENGTH} bytes"
            )

        v_p_x_cc, m_pt, sequence_number, timestamp, ssrc = unpack("!BBHLL", data[0:12])
        version = v_p_x_cc >> 6
        padding = (v_p_x_cc >> 5) & 1
        extension = (v_p_x_cc >> 4) & 1
        cc = v_p_x_cc & 0x0F
        if version != 2:
            raise ValueError("RTP packet has invalid version")
        if len(data) < RTP_HEADER_LENGTH + 4 * cc:
            raise ValueError("RTP packet has truncated CSRC")

        packet = cls(
            marker=(m_pt >> 7),
            payload_type=(m_pt & 0x7F),
            sequence_number=sequence_number,
            timestamp=timestamp,
            ssrc=ssrc,
        )

        pos = RTP_HEADER_LENGTH
        for _ in range(0, cc):
            packet.csrc.append(unpack_from("!L", data, pos)[0])
            pos += 4

        if extension:
            if len(data) < pos + 4:
                raise ValueError("RTP packet has truncated extension profile / length")
            extension_profile, extension_length = unpack_from("!HH", data, pos)
            # the length field is expressed in 32-bit words
            extension_length *= 4
            pos += 4

            if len(data) < pos + extension_length:
                raise ValueError("RTP packet has truncated extension value")
            extension_value = data[pos : pos + extension_length]
            pos += extension_length
            packet.extensions = extensions_map.get(extension_profile, extension_value)

        if padding:
            # the last byte holds the padding size, itself included
            padding_len = data[-1]
            if not padding_len or padding_len > len(data) - pos:
                raise ValueError("RTP packet padding length is invalid")
            packet.padding_size = padding_len
            packet.payload = data[pos:-padding_len]
        else:
            packet.payload = data[pos:]

        return packet

    def serialize(self, extensions_map=HeaderExtensionsMap()) -> bytes:
        """
        Serialize the packet to bytes.

        Header extensions are encoded using `extensions_map`; the default is
        a map on which no extension has been configured, so none are emitted.
        If `padding_size` is positive, random padding is appended, followed
        by a final byte holding the padding size.
        """
        extension_profile, extension_value = extensions_map.set(self.extensions)
        has_extension = bool(extension_value)

        padding = self.padding_size > 0
        data = pack(
            "!BBHLL",
            (self.version << 6)
            | (padding << 5)
            | (has_extension << 4)
            | len(self.csrc),
            (self.marker << 7) | self.payload_type,
            self.sequence_number,
            self.timestamp,
            self.ssrc,
        )
        for csrc in self.csrc:
            data += pack("!L", csrc)
        if has_extension:
            data += pack("!HH", extension_profile, len(extension_value) >> 2)
            data += extension_value
        data += self.payload
        if padding:
            data += os.urandom(self.padding_size - 1)
            data += bytes([self.padding_size])
        return data


def unwrap_rtx(rtx: RtpPacket, payload_type: int, ssrc: int) -> RtpPacket:
    """
    Recover initial packet from a retransmission packet.

    The first two bytes of the retransmission payload hold the original
    sequence number; the remainder is the original payload.
    """
    packet = RtpPacket(
        payload_type=payload_type,
        marker=rtx.marker,
        sequence_number=unpack("!H", rtx.payload[0:2])[0],
        timestamp=rtx.timestamp,
        ssrc=ssrc,
        payload=rtx.payload[2:],
    )
    packet.csrc = rtx.csrc
    packet.extensions = rtx.extensions
    return packet


def wrap_rtx(
    packet: RtpPacket, payload_type: int, sequence_number: int, ssrc: int
) -> RtpPacket:
    """
    Create a retransmission packet from a lost packet.

    The original sequence number is prepended to the payload as two bytes.
    """
    rtx = RtpPacket(
        payload_type=payload_type,
        marker=packet.marker,
        sequence_number=sequence_number,
        timestamp=packet.timestamp,
        ssrc=ssrc,
        payload=pack("!H", packet.sequence_number) + packet.payload,
    )
    rtx.csrc = packet.csrc
    rtx.extensions = packet.extensions
    return rtx