import asyncio import enum import hmac import logging import math import os import time from collections import deque from dataclasses import dataclass, field from struct import pack, unpack_from from typing import ( Any, Callable, Deque, Dict, Iterator, List, Optional, Set, Tuple, cast, no_type_check, ) from google_crc32c import value as crc32c from pyee.asyncio import AsyncIOEventEmitter from .exceptions import InvalidStateError from .rtcdatachannel import RTCDataChannel, RTCDataChannelParameters from .rtcdtlstransport import RTCDtlsTransport from .utils import random32, uint16_add, uint16_gt, uint32_gt, uint32_gte logger = logging.getLogger(__name__) # local constants COOKIE_LENGTH = 24 COOKIE_LIFETIME = 60 MAX_STREAMS = 65535 USERDATA_MAX_LENGTH = 1200 # protocol constants SCTP_CAUSE_INVALID_STREAM = 0x0001 SCTP_CAUSE_STALE_COOKIE = 0x0003 SCTP_DATA_LAST_FRAG = 0x01 SCTP_DATA_FIRST_FRAG = 0x02 SCTP_DATA_UNORDERED = 0x04 SCTP_MAX_ASSOCIATION_RETRANS = 10 SCTP_MAX_BURST = 4 SCTP_MAX_INIT_RETRANS = 8 SCTP_RTO_ALPHA = 1 / 8 SCTP_RTO_BETA = 1 / 4 SCTP_RTO_INITIAL = 3.0 SCTP_RTO_MIN = 1 SCTP_RTO_MAX = 60 SCTP_TSN_MODULO = 2**32 RECONFIG_MAX_STREAMS = 135 # parameters SCTP_STATE_COOKIE = 0x0007 SCTP_STR_RESET_OUT_REQUEST = 0x000D SCTP_STR_RESET_RESPONSE = 0x0010 SCTP_STR_RESET_ADD_OUT_STREAMS = 0x0011 SCTP_SUPPORTED_CHUNK_EXT = 0x8008 SCTP_PRSCTP_SUPPORTED = 0xC000 # data channel constants DATA_CHANNEL_ACK = 2 DATA_CHANNEL_OPEN = 3 DATA_CHANNEL_RELIABLE = 0x00 DATA_CHANNEL_PARTIAL_RELIABLE_REXMIT = 0x01 DATA_CHANNEL_PARTIAL_RELIABLE_TIMED = 0x02 DATA_CHANNEL_RELIABLE_UNORDERED = 0x80 DATA_CHANNEL_PARTIAL_RELIABLE_REXMIT_UNORDERED = 0x81 DATA_CHANNEL_PARTIAL_RELIABLE_TIMED_UNORDERED = 0x82 WEBRTC_DCEP = 50 WEBRTC_STRING = 51 WEBRTC_BINARY = 53 WEBRTC_STRING_EMPTY = 56 WEBRTC_BINARY_EMPTY = 57 def chunk_type(chunk) -> str: return chunk.__class__.__name__ def decode_params(body: bytes) -> List[Tuple[int, bytes]]: params = [] pos = 0 while pos <= len(body) - 4: param_type, param_length = unpack_from("!HH", body, pos) params.append((param_type, body[pos + 4 : pos + param_length])) pos += param_length + padl(param_length) return params def encode_params(params: List[Tuple[int, bytes]]) -> bytes: body = b"" padding = b"" for param_type, param_value in params: param_length = len(param_value) + 4 body += padding body += pack("!HH", param_type, param_length) + param_value padding = b"\x00" * padl(param_length) return body def padl(length: int) -> int: m = length % 4 return 4 - m if m else 0 def tsn_minus_one(a: int) -> int: return (a - 1) % SCTP_TSN_MODULO def tsn_plus_one(a: int) -> int: return (a + 1) % SCTP_TSN_MODULO class Chunk: type = -1 def __init__(self, flags: int = 0, body: bytes = b"") -> None: self.flags = flags self.body = body def __bytes__(self) -> bytes: body = self.body data = pack("!BBH", self.type, self.flags, len(body) + 4) + body data += b"\x00" * padl(len(body)) return data def __repr__(self) -> str: return f"{chunk_type(self)}(flags={self.flags})" class BaseParamsChunk(Chunk): def __init__(self, flags: int = 0, body: Optional[bytes] = None) -> None: self.flags = flags if body: self.params = decode_params(body) else: self.params = [] @property def body(self) -> bytes: # type: ignore return encode_params(self.params) class AbortChunk(BaseParamsChunk): type = 6 class CookieAckChunk(Chunk): type = 11 class CookieEchoChunk(Chunk): type = 10 class DataChunk(Chunk): type = 0 def __init__(self, flags: int = 0, body: Optional[bytes] = None) -> None: self.flags = flags if body: (self.tsn, self.stream_id, self.stream_seq, self.protocol) = unpack_from( "!LHHL", body ) self.user_data = body[12:] else: self.tsn = 0 self.stream_id = 0 self.stream_seq = 0 self.protocol = 0 self.user_data = b"" def __bytes__(self) -> bytes: length = 16 + len(self.user_data) data = ( pack( "!BBHLHHL", self.type, self.flags, length, self.tsn, self.stream_id, self.stream_seq, self.protocol, ) + self.user_data ) if length % 4: data += b"\x00" * padl(length) return data def __repr__(self) -> str: return ( f"DataChunk(flags={self.flags}, tsn={self.tsn}, " f"stream_id={self.stream_id}, stream_seq={self.stream_seq})" ) class ErrorChunk(BaseParamsChunk): type = 9 class ForwardTsnChunk(Chunk): type = 192 def __init__(self, flags: int = 0, body: Optional[bytes] = None) -> None: self.flags = flags self.streams: List[Tuple[int, int]] = [] if body: self.cumulative_tsn = unpack_from("!L", body, 0)[0] pos = 4 while pos < len(body): self.streams.append( cast(Tuple[int, int], unpack_from("!HH", body, pos)) ) pos += 4 else: self.cumulative_tsn = 0 @property def body(self) -> bytes: # type: ignore body = pack("!L", self.cumulative_tsn) for stream_id, stream_seq in self.streams: body += pack("!HH", stream_id, stream_seq) return body def __repr__(self) -> str: return ( f"ForwardTsnChunk(cumulative_tsn={self.cumulative_tsn}, " f"streams={self.streams})" ) class HeartbeatChunk(BaseParamsChunk): type = 4 class HeartbeatAckChunk(BaseParamsChunk): type = 5 class BaseInitChunk(Chunk): def __init__(self, flags: int = 0, body: Optional[bytes] = None) -> None: self.flags = flags if body: ( self.initiate_tag, self.advertised_rwnd, self.outbound_streams, self.inbound_streams, self.initial_tsn, ) = unpack_from("!LLHHL", body) self.params = decode_params(body[16:]) else: self.initiate_tag = 0 self.advertised_rwnd = 0 self.outbound_streams = 0 self.inbound_streams = 0 self.initial_tsn = 0 self.params = [] @property def body(self) -> bytes: # type: ignore body = pack( "!LLHHL", self.initiate_tag, self.advertised_rwnd, self.outbound_streams, self.inbound_streams, self.initial_tsn, ) body += encode_params(self.params) return body class InitChunk(BaseInitChunk): type = 1 class InitAckChunk(BaseInitChunk): type = 2 class ReconfigChunk(BaseParamsChunk): type = 130 class SackChunk(Chunk): type = 3 def __init__(self, flags=0, body=None): self.flags = flags self.gaps = [] self.duplicates = [] if body: ( self.cumulative_tsn, self.advertised_rwnd, nb_gaps, nb_duplicates, ) = unpack_from("!LLHH", body) pos = 12 for i in range(nb_gaps): self.gaps.append(unpack_from("!HH", body, pos)) pos += 4 for i in range(nb_duplicates): self.duplicates.append(unpack_from("!L", body, pos)[0]) pos += 4 else: self.cumulative_tsn = 0 self.advertised_rwnd = 0 def __bytes__(self) -> bytes: length = 16 + 4 * (len(self.gaps) + len(self.duplicates)) data = pack( "!BBHLLHH", self.type, self.flags, length, self.cumulative_tsn, self.advertised_rwnd, len(self.gaps), len(self.duplicates), ) for gap in self.gaps: data += pack("!HH", *gap) for tsn in self.duplicates: data += pack("!L", tsn) return data def __repr__(self) -> str: return ( f"SackChunk(flags={self.flags}, advertised_rwnd={self.advertised_rwnd}, " f"cumulative_tsn={self.cumulative_tsn}, gaps={self.gaps})" ) class ShutdownChunk(Chunk): type = 7 def __init__(self, flags=0, body=None): self.flags = flags if body: self.cumulative_tsn = unpack_from("!L", body)[0] else: self.cumulative_tsn = 0 @property def body(self) -> bytes: # type: ignore return pack("!L", self.cumulative_tsn) def __repr__(self) -> str: return ( f"ShutdownChunk(flags={self.flags}, cumulative_tsn={self.cumulative_tsn})" ) class ShutdownAckChunk(Chunk): type = 8 class ShutdownCompleteChunk(Chunk): type = 14 CHUNK_CLASSES = [ DataChunk, InitChunk, InitAckChunk, SackChunk, HeartbeatChunk, HeartbeatAckChunk, AbortChunk, ShutdownChunk, ShutdownAckChunk, ErrorChunk, CookieEchoChunk, CookieAckChunk, ShutdownCompleteChunk, ReconfigChunk, ForwardTsnChunk, ] CHUNK_TYPES = dict((cls.type, cls) for cls in CHUNK_CLASSES) def parse_packet(data: bytes) -> Tuple[int, int, int, List[Any]]: length = len(data) if length < 12: raise ValueError("SCTP packet length is less than 12 bytes") source_port, destination_port, verification_tag = unpack_from("!HHL", data) # verify checksum checksum = unpack_from(" bytes: header = pack("!HHL", source_port, destination_port, verification_tag) data = bytes(chunk) checksum = crc32c(header + b"\x00\x00\x00\x00" + data) return header + pack(" bytes: data = pack( "!LLL", self.request_sequence, self.response_sequence, self.last_tsn ) for stream in self.streams: data += pack("!H", stream) return data @classmethod def parse(cls, data): request_sequence, response_sequence, last_tsn = unpack_from("!LLL", data) streams = [] for pos in range(12, len(data), 2): streams.append(unpack_from("!H", data, pos)[0]) return cls( request_sequence=request_sequence, response_sequence=response_sequence, last_tsn=last_tsn, streams=streams, ) @dataclass class StreamAddOutgoingParam: request_sequence: int new_streams: int def __bytes__(self) -> bytes: data = pack("!LHH", self.request_sequence, self.new_streams, 0) return data @classmethod def parse(cls, data): request_sequence, new_streams, reserved = unpack_from("!LHH", data) return cls(request_sequence=request_sequence, new_streams=new_streams) @dataclass class StreamResetResponseParam: response_sequence: int result: int def __bytes__(self) -> bytes: return pack("!LL", self.response_sequence, self.result) @classmethod def parse(cls, data): response_sequence, result = unpack_from("!LL", data) return cls(response_sequence=response_sequence, result=result) RECONFIG_PARAM_TYPES = { 13: StreamResetOutgoingParam, 16: StreamResetResponseParam, 17: StreamAddOutgoingParam, } class InboundStream: def __init__(self) -> None: self.reassembly: List[DataChunk] = [] self.sequence_number = 0 def add_chunk(self, chunk: DataChunk) -> None: if not self.reassembly or uint32_gt(chunk.tsn, self.reassembly[-1].tsn): self.reassembly.append(chunk) return for i, rchunk in enumerate(self.reassembly): # should never happen, the chunk should have been eliminated # as a duplicate when _mark_received() is called assert rchunk.tsn != chunk.tsn, "duplicate chunk in reassembly" if uint32_gt(rchunk.tsn, chunk.tsn): self.reassembly.insert(i, chunk) break def pop_messages(self) -> Iterator[Tuple[int, int, bytes]]: pos = 0 start_pos = None while pos < len(self.reassembly): chunk = self.reassembly[pos] if start_pos is None: ordered = not (chunk.flags & SCTP_DATA_UNORDERED) if not (chunk.flags & SCTP_DATA_FIRST_FRAG): if ordered: break else: pos += 1 continue if ordered and uint16_gt(chunk.stream_seq, self.sequence_number): break expected_tsn = chunk.tsn start_pos = pos elif chunk.tsn != expected_tsn: if ordered: break else: start_pos = None pos += 1 continue if chunk.flags & SCTP_DATA_LAST_FRAG: user_data = b"".join( [c.user_data for c in self.reassembly[start_pos : pos + 1]] ) self.reassembly = ( self.reassembly[:start_pos] + self.reassembly[pos + 1 :] ) if ordered and chunk.stream_seq == self.sequence_number: self.sequence_number = uint16_add(self.sequence_number, 1) pos = start_pos yield (chunk.stream_id, chunk.protocol, user_data) else: pos += 1 expected_tsn = tsn_plus_one(expected_tsn) def prune_chunks(self, tsn: int) -> int: """ Prune chunks up to the given TSN. """ pos = -1 size = 0 for i, chunk in enumerate(self.reassembly): if uint32_gte(tsn, chunk.tsn): pos = i size += len(chunk.user_data) else: break self.reassembly = self.reassembly[pos + 1 :] return size @dataclass class RTCSctpCapabilities: """ The :class:`RTCSctpCapabilities` dictionary provides information about the capabilities of the :class:`RTCSctpTransport`. """ maxMessageSize: int """ The maximum size of data that the implementation can send or 0 if the implementation can handle messages of any size. """ class RTCSctpTransport(AsyncIOEventEmitter): """ The :class:`RTCSctpTransport` interface includes information relating to Stream Control Transmission Protocol (SCTP) transport. :param transport: An :class:`RTCDtlsTransport`. """ def __init__(self, transport: RTCDtlsTransport, port: int = 5000) -> None: if transport.state == "closed": raise InvalidStateError super().__init__() self._association_state = self.State.CLOSED self.__log_debug: Callable[..., None] = lambda *args: None self.__started = False self.__state = "new" self.__transport = transport self._loop = asyncio.get_event_loop() self._hmac_key = os.urandom(16) self._local_partial_reliability = True self._local_port = port self._local_verification_tag = random32() self._remote_extensions: List[int] = [] self._remote_partial_reliability = False self._remote_port: Optional[int] = None self._remote_verification_tag = 0 # inbound self._advertised_rwnd = 1024 * 1024 self._inbound_streams: Dict[int, InboundStream] = {} self._inbound_streams_count = 0 self._inbound_streams_max = MAX_STREAMS self._last_received_tsn: Optional[int] = None self._sack_duplicates: List[int] = [] self._sack_misordered: Set[int] = set() self._sack_needed = False # outbound self._cwnd = 3 * USERDATA_MAX_LENGTH self._fast_recovery_exit = None self._fast_recovery_transmit = False self._forward_tsn_chunk: Optional[ForwardTsnChunk] = None self._flight_size = 0 self._local_tsn = random32() self._last_sacked_tsn = tsn_minus_one(self._local_tsn) self._advanced_peer_ack_tsn = tsn_minus_one(self._local_tsn) self._outbound_queue: Deque[DataChunk] = deque() self._outbound_stream_seq: Dict[int, int] = {} self._outbound_streams_count = MAX_STREAMS self._partial_bytes_acked = 0 self._sent_queue: Deque[DataChunk] = deque() # reconfiguration self._reconfig_queue: List[int] = [] self._reconfig_request = None self._reconfig_request_seq = self._local_tsn self._reconfig_response_seq = 0 # rtt calculation self._srtt: Optional[float] = None self._rttvar: Optional[float] = None # timers self._rto = SCTP_RTO_INITIAL self._t1_chunk: Optional[Chunk] = None self._t1_failures = 0 self._t1_handle: Optional[asyncio.TimerHandle] = None self._t2_chunk: Optional[Chunk] = None self._t2_failures = 0 self._t2_handle: Optional[asyncio.TimerHandle] = None self._t3_handle: Optional[asyncio.TimerHandle] = None # data channels self._data_channel_id: Optional[int] = None self._data_channel_queue: Deque[Tuple[RTCDataChannel, int, bytes]] = deque() self._data_channels: Dict[int, RTCDataChannel] = {} # FIXME: this is only used by RTCPeerConnection self._bundled = False self.mid: Optional[str] = None @property def is_server(self) -> bool: return self.transport.transport.role != "controlling" @property def maxChannels(self) -> Optional[int]: """ The maximum number of :class:`RTCDataChannel` that can be used simultaneously. """ if self._inbound_streams_count: return min(self._inbound_streams_count, self._outbound_streams_count) return None @property def port(self) -> int: """ The local SCTP port number used for data channels. """ return self._local_port @property def state(self) -> str: """ The current state of the SCTP transport. """ return self.__state @property def transport(self): """ The :class:`RTCDtlsTransport` over which SCTP data is transmitted. """ return self.__transport @classmethod def getCapabilities(cls) -> RTCSctpCapabilities: """ Retrieve the capabilities of the transport. :rtype: RTCSctpCapabilities """ return RTCSctpCapabilities(maxMessageSize=65536) def setTransport(self, transport) -> None: self.__transport = transport async def start(self, remoteCaps: RTCSctpCapabilities, remotePort: int) -> None: """ Start the transport. """ if not self.__started: self.__started = True self.__state = "connecting" self._remote_port = remotePort # configure logging if logger.isEnabledFor(logging.DEBUG): prefix = "RTCSctpTransport(%s) " % ( self.is_server and "server" or "client" ) self.__log_debug = lambda msg, *args: logger.debug(prefix + msg, *args) # initialise local channel ID counter # one side should be using even IDs, the other odd IDs if self.is_server: self._data_channel_id = 0 else: self._data_channel_id = 1 self.__transport._register_data_receiver(self) if not self.is_server: await self._init() async def stop(self) -> None: """ Stop the transport. """ if self._association_state != self.State.CLOSED: await self._abort() self.__transport._unregister_data_receiver(self) self._set_state(self.State.CLOSED) async def _abort(self) -> None: """ Abort the association. """ chunk = AbortChunk() try: await self._send_chunk(chunk) except ConnectionError: pass async def _init(self) -> None: """ Initialize the association. """ chunk = InitChunk() chunk.initiate_tag = self._local_verification_tag chunk.advertised_rwnd = self._advertised_rwnd chunk.outbound_streams = self._outbound_streams_count chunk.inbound_streams = self._inbound_streams_max chunk.initial_tsn = self._local_tsn self._set_extensions(chunk.params) await self._send_chunk(chunk) # start T1 timer and enter COOKIE-WAIT state self._t1_start(chunk) self._set_state(self.State.COOKIE_WAIT) def _flight_size_decrease(self, chunk: DataChunk) -> None: self._flight_size = max(0, self._flight_size - chunk._book_size) # type: ignore def _flight_size_increase(self, chunk: DataChunk) -> None: self._flight_size += chunk._book_size # type: ignore def _get_extensions(self, params: List[Tuple[int, bytes]]) -> None: """ Gets what extensions are supported by the remote party. """ for k, v in params: if k == SCTP_PRSCTP_SUPPORTED: self._remote_partial_reliability = True elif k == SCTP_SUPPORTED_CHUNK_EXT: self._remote_extensions = list(v) def _set_extensions(self, params: List[Tuple[int, bytes]]) -> None: """ Sets what extensions are supported by the local party. """ extensions = [] if self._local_partial_reliability: params.append((SCTP_PRSCTP_SUPPORTED, b"")) extensions.append(ForwardTsnChunk.type) extensions.append(ReconfigChunk.type) params.append((SCTP_SUPPORTED_CHUNK_EXT, bytes(extensions))) def _get_inbound_stream(self, stream_id: int) -> InboundStream: """ Get or create the inbound stream with the specified ID. """ if stream_id not in self._inbound_streams: self._inbound_streams[stream_id] = InboundStream() return self._inbound_streams[stream_id] def _get_timestamp(self) -> int: return int(time.time()) async def _handle_data(self, data): """ Handle data received from the network. """ try: _, _, verification_tag, chunks = parse_packet(data) except ValueError: return # is this an init? init_chunk = len([x for x in chunks if isinstance(x, InitChunk)]) if init_chunk: assert len(chunks) == 1 expected_tag = 0 else: expected_tag = self._local_verification_tag # verify tag if verification_tag != expected_tag: self.__log_debug( "Bad verification tag %d vs %d", verification_tag, expected_tag ) return # handle chunks for chunk in chunks: await self._receive_chunk(chunk) # send SACK if needed if self._sack_needed: await self._send_sack() @no_type_check def _maybe_abandon(self, chunk: DataChunk) -> bool: """ Determine if a chunk needs to be marked as abandoned. If it does, it marks the chunk and any other chunk belong to the same message as abandoned. """ if chunk._abandoned: return True abandon = ( chunk._max_retransmits is not None and chunk._sent_count > chunk._max_retransmits ) or (chunk._expiry is not None and chunk._expiry < time.time()) if not abandon: return False chunk_pos = self._sent_queue.index(chunk) for pos in range(chunk_pos, -1, -1): ochunk = self._sent_queue[pos] ochunk._abandoned = True ochunk._retransmit = False if ochunk.flags & SCTP_DATA_FIRST_FRAG: break for pos in range(chunk_pos, len(self._sent_queue)): ochunk = self._sent_queue[pos] ochunk._abandoned = True ochunk._retransmit = False if ochunk.flags & SCTP_DATA_LAST_FRAG: break return True def _mark_received(self, tsn: int) -> bool: """ Mark an incoming data TSN as received. """ # it's a duplicate if uint32_gte(self._last_received_tsn, tsn) or tsn in self._sack_misordered: self._sack_duplicates.append(tsn) return True # consolidate misordered entries self._sack_misordered.add(tsn) for tsn in sorted(self._sack_misordered): if tsn == tsn_plus_one(self._last_received_tsn): self._last_received_tsn = tsn else: break # filter out obsolete entries def is_obsolete(x): return uint32_gt(x, self._last_received_tsn) self._sack_duplicates = list(filter(is_obsolete, self._sack_duplicates)) self._sack_misordered = set(filter(is_obsolete, self._sack_misordered)) return False async def _receive(self, stream_id: int, pp_id: int, data: bytes) -> None: """ Receive data stream -> ULP. """ await self._data_channel_receive(stream_id, pp_id, data) async def _receive_chunk(self, chunk): """ Handle an incoming chunk. """ self.__log_debug("< %s", chunk) # common if isinstance(chunk, DataChunk): await self._receive_data_chunk(chunk) elif isinstance(chunk, SackChunk): await self._receive_sack_chunk(chunk) elif isinstance(chunk, ForwardTsnChunk): await self._receive_forward_tsn_chunk(chunk) elif isinstance(chunk, HeartbeatChunk): ack = HeartbeatAckChunk() ack.params = chunk.params await self._send_chunk(ack) elif isinstance(chunk, AbortChunk): self.__log_debug("x Association was aborted by remote party") self._set_state(self.State.CLOSED) elif isinstance(chunk, ShutdownChunk): self._t2_cancel() self._set_state(self.State.SHUTDOWN_RECEIVED) ack = ShutdownAckChunk() await self._send_chunk(ack) self._t2_start(ack) self._set_state(self.State.SHUTDOWN_ACK_SENT) elif ( isinstance(chunk, ShutdownCompleteChunk) and self._association_state == self.State.SHUTDOWN_ACK_SENT ): self._t2_cancel() self._set_state(self.State.CLOSED) elif ( isinstance(chunk, ReconfigChunk) and self._association_state == self.State.ESTABLISHED ): for param in chunk.params: cls = RECONFIG_PARAM_TYPES.get(param[0]) if cls: await self._receive_reconfig_param(cls.parse(param[1])) # server elif isinstance(chunk, InitChunk) and self.is_server: self._last_received_tsn = tsn_minus_one(chunk.initial_tsn) self._reconfig_response_seq = tsn_minus_one(chunk.initial_tsn) self._remote_verification_tag = chunk.initiate_tag self._ssthresh = chunk.advertised_rwnd self._get_extensions(chunk.params) self.__log_debug( "- Peer supports %d outbound streams, %d max inbound streams", chunk.outbound_streams, chunk.inbound_streams, ) self._inbound_streams_count = min( chunk.outbound_streams, self._inbound_streams_max ) self._outbound_streams_count = min( self._outbound_streams_count, chunk.inbound_streams ) ack = InitAckChunk() ack.initiate_tag = self._local_verification_tag ack.advertised_rwnd = self._advertised_rwnd ack.outbound_streams = self._outbound_streams_count ack.inbound_streams = self._inbound_streams_max ack.initial_tsn = self._local_tsn self._set_extensions(ack.params) # generate state cookie cookie = pack("!L", self._get_timestamp()) cookie += hmac.new(self._hmac_key, cookie, "sha1").digest() ack.params.append((SCTP_STATE_COOKIE, cookie)) await self._send_chunk(ack) elif isinstance(chunk, CookieEchoChunk) and self.is_server: # check state cookie MAC cookie = chunk.body if ( len(cookie) != COOKIE_LENGTH or hmac.new(self._hmac_key, cookie[0:4], "sha1").digest() != cookie[4:] ): self.__log_debug("x State cookie is invalid") return # check state cookie lifetime now = self._get_timestamp() stamp = unpack_from("!L", cookie)[0] if stamp < now - COOKIE_LIFETIME or stamp > now: self.__log_debug("x State cookie has expired") error = ErrorChunk() error.params.append((SCTP_CAUSE_STALE_COOKIE, b"\x00" * 8)) await self._send_chunk(error) return ack = CookieAckChunk() await self._send_chunk(ack) self._set_state(self.State.ESTABLISHED) # client elif ( isinstance(chunk, InitAckChunk) and self._association_state == self.State.COOKIE_WAIT ): # cancel T1 timer and process chunk self._t1_cancel() self._last_received_tsn = tsn_minus_one(chunk.initial_tsn) self._reconfig_response_seq = tsn_minus_one(chunk.initial_tsn) self._remote_verification_tag = chunk.initiate_tag self._ssthresh = chunk.advertised_rwnd self._get_extensions(chunk.params) self.__log_debug( "- Peer supports %d outbound streams, %d max inbound streams", chunk.outbound_streams, chunk.inbound_streams, ) self._inbound_streams_count = min( chunk.outbound_streams, self._inbound_streams_max ) self._outbound_streams_count = min( self._outbound_streams_count, chunk.inbound_streams ) echo = CookieEchoChunk() for k, v in chunk.params: if k == SCTP_STATE_COOKIE: echo.body = v break await self._send_chunk(echo) # start T1 timer and enter COOKIE-ECHOED state self._t1_start(echo) self._set_state(self.State.COOKIE_ECHOED) elif ( isinstance(chunk, CookieAckChunk) and self._association_state == self.State.COOKIE_ECHOED ): # cancel T1 timer and enter ESTABLISHED state self._t1_cancel() self._set_state(self.State.ESTABLISHED) elif isinstance(chunk, ErrorChunk) and self._association_state in [ self.State.COOKIE_WAIT, self.State.COOKIE_ECHOED, ]: self._t1_cancel() self._set_state(self.State.CLOSED) self.__log_debug("x Could not establish association") return async def _receive_data_chunk(self, chunk: DataChunk) -> None: """ Handle a DATA chunk. """ self._sack_needed = True # mark as received if self._mark_received(chunk.tsn): return # find stream inbound_stream = self._get_inbound_stream(chunk.stream_id) # defragment data inbound_stream.add_chunk(chunk) self._advertised_rwnd -= len(chunk.user_data) for message in inbound_stream.pop_messages(): self._advertised_rwnd += len(message[2]) await self._receive(*message) async def _receive_forward_tsn_chunk(self, chunk: ForwardTsnChunk) -> None: """ Handle a FORWARD TSN chunk. """ self._sack_needed = True # it's a duplicate if uint32_gte(self._last_received_tsn, chunk.cumulative_tsn): return def is_obsolete(x): return uint32_gt(x, self._last_received_tsn) # advance cumulative TSN self._last_received_tsn = chunk.cumulative_tsn self._sack_misordered = set(filter(is_obsolete, self._sack_misordered)) for tsn in sorted(self._sack_misordered): if tsn == tsn_plus_one(self._last_received_tsn): self._last_received_tsn = tsn else: break # filter out obsolete entries self._sack_duplicates = list(filter(is_obsolete, self._sack_duplicates)) self._sack_misordered = set(filter(is_obsolete, self._sack_misordered)) # update reassembly for stream_id, stream_seq in chunk.streams: inbound_stream = self._get_inbound_stream(stream_id) # advance sequence number and perform delivery inbound_stream.sequence_number = uint16_add(stream_seq, 1) for message in inbound_stream.pop_messages(): self._advertised_rwnd += len(message[2]) await self._receive(*message) # prune obsolete chunks for stream_id, inbound_stream in self._inbound_streams.items(): self._advertised_rwnd += inbound_stream.prune_chunks( self._last_received_tsn ) @no_type_check async def _receive_sack_chunk(self, chunk: SackChunk) -> None: """ Handle a SACK chunk. """ if uint32_gt(self._last_sacked_tsn, chunk.cumulative_tsn): return received_time = time.time() self._last_sacked_tsn = chunk.cumulative_tsn cwnd_fully_utilized = self._flight_size >= self._cwnd done = 0 done_bytes = 0 # handle acknowledged data while self._sent_queue and uint32_gte( self._last_sacked_tsn, self._sent_queue[0].tsn ): schunk = self._sent_queue.popleft() done += 1 if not schunk._acked: done_bytes += schunk._book_size self._flight_size_decrease(schunk) # update RTO estimate if done == 1 and schunk._sent_count == 1: self._update_rto(received_time - schunk._sent_time) # handle gap blocks loss = False if chunk.gaps: seen = set() for gap in chunk.gaps: for pos in range(gap[0], gap[1] + 1): highest_seen_tsn = (chunk.cumulative_tsn + pos) % SCTP_TSN_MODULO seen.add(highest_seen_tsn) # determined Highest TSN Newly Acked (HTNA) highest_newly_acked = chunk.cumulative_tsn for schunk in self._sent_queue: if uint32_gt(schunk.tsn, highest_seen_tsn): break if schunk.tsn in seen and not schunk._acked: done_bytes += schunk._book_size schunk._acked = True self._flight_size_decrease(schunk) highest_newly_acked = schunk.tsn # strike missing chunks prior to HTNA for schunk in self._sent_queue: if uint32_gt(schunk.tsn, highest_newly_acked): break if schunk.tsn not in seen: schunk._misses += 1 if schunk._misses == 3: schunk._misses = 0 if not self._maybe_abandon(schunk): schunk._retransmit = True schunk._acked = False self._flight_size_decrease(schunk) loss = True # adjust congestion window if self._fast_recovery_exit is None: if done and cwnd_fully_utilized: if self._cwnd <= self._ssthresh: # slow start self._cwnd += min(done_bytes, USERDATA_MAX_LENGTH) else: # congestion avoidance self._partial_bytes_acked += done_bytes if self._partial_bytes_acked >= self._cwnd: self._partial_bytes_acked -= self._cwnd self._cwnd += USERDATA_MAX_LENGTH if loss: self._ssthresh = max(self._cwnd // 2, 4 * USERDATA_MAX_LENGTH) self._cwnd = self._ssthresh self._partial_bytes_acked = 0 self._fast_recovery_exit = self._sent_queue[-1].tsn self._fast_recovery_transmit = True elif uint32_gte(chunk.cumulative_tsn, self._fast_recovery_exit): self._fast_recovery_exit = None if not self._sent_queue: # there is no outstanding data, stop T3 self._t3_cancel() elif done: # the earliest outstanding chunk was acknowledged, restart T3 self._t3_restart() self._update_advanced_peer_ack_point() await self._data_channel_flush() await self._transmit() async def _receive_reconfig_param(self, param): """ Handle a RE-CONFIG parameter. """ self.__log_debug("<< %s", param) if isinstance(param, StreamResetOutgoingParam): # mark closed inbound streams for stream_id in param.streams: self._inbound_streams.pop(stream_id, None) # close data channel channel = self._data_channels.get(stream_id) if channel: self._data_channel_close(channel) # send response response_param = StreamResetResponseParam( response_sequence=param.request_sequence, result=1 ) self._reconfig_response_seq = param.request_sequence await self._send_reconfig_param(response_param) elif isinstance(param, StreamAddOutgoingParam): # increase inbound streams self._inbound_streams_count += param.new_streams # send response response_param = StreamResetResponseParam( response_sequence=param.request_sequence, result=1 ) self._reconfig_response_seq = param.request_sequence await self._send_reconfig_param(response_param) elif isinstance(param, StreamResetResponseParam): if ( self._reconfig_request and param.response_sequence == self._reconfig_request.request_sequence ): # mark closed streams for stream_id in self._reconfig_request.streams: self._outbound_stream_seq.pop(stream_id, None) self._data_channel_closed(stream_id) self._reconfig_request = None await self._transmit_reconfig() @no_type_check async def _send( self, stream_id: int, pp_id: int, user_data: bytes, expiry: Optional[float] = None, max_retransmits: Optional[int] = None, ordered: bool = True, ) -> None: """ Send data ULP -> stream. """ if ordered: stream_seq = self._outbound_stream_seq.get(stream_id, 0) else: stream_seq = 0 fragments = math.ceil(len(user_data) / USERDATA_MAX_LENGTH) pos = 0 for fragment in range(0, fragments): chunk = DataChunk() chunk.flags = 0 if not ordered: chunk.flags = SCTP_DATA_UNORDERED if fragment == 0: chunk.flags |= SCTP_DATA_FIRST_FRAG if fragment == fragments - 1: chunk.flags |= SCTP_DATA_LAST_FRAG chunk.tsn = self._local_tsn chunk.stream_id = stream_id chunk.stream_seq = stream_seq chunk.protocol = pp_id chunk.user_data = user_data[pos : pos + USERDATA_MAX_LENGTH] # FIXME: dynamically added attributes, mypy can't handle them # initialize counters chunk._abandoned = False chunk._acked = False chunk._book_size = len(chunk.user_data) chunk._expiry = expiry chunk._max_retransmits = max_retransmits chunk._misses = 0 chunk._retransmit = False chunk._sent_count = 0 chunk._sent_time = None pos += USERDATA_MAX_LENGTH self._local_tsn = tsn_plus_one(self._local_tsn) self._outbound_queue.append(chunk) if ordered: self._outbound_stream_seq[stream_id] = uint16_add(stream_seq, 1) # transmit outbound data await self._transmit() async def _send_chunk(self, chunk: Chunk) -> None: """ Transmit a chunk (no bundling for now). """ self.__log_debug("> %s", chunk) await self.__transport._send_data( serialize_packet( self._local_port, self._remote_port, self._remote_verification_tag, chunk, ) ) async def _send_reconfig_param(self, param): chunk = ReconfigChunk() for k, cls in RECONFIG_PARAM_TYPES.items(): if isinstance(param, cls): param_type = k break chunk.params.append((param_type, bytes(param))) self.__log_debug(">> %s", param) await self._send_chunk(chunk) async def _send_sack(self): """ Build and send a selective acknowledgement (SACK) chunk. """ gaps = [] gap_next = None for tsn in sorted(self._sack_misordered): pos = (tsn - self._last_received_tsn) % SCTP_TSN_MODULO if tsn == gap_next: gaps[-1][1] = pos else: gaps.append([pos, pos]) gap_next = tsn_plus_one(tsn) sack = SackChunk() sack.cumulative_tsn = self._last_received_tsn sack.advertised_rwnd = max(0, self._advertised_rwnd) sack.duplicates = self._sack_duplicates[:] sack.gaps = [tuple(x) for x in gaps] await self._send_chunk(sack) self._sack_duplicates.clear() self._sack_needed = False def _set_state(self, state) -> None: """ Transition the SCTP association to a new state. """ if state != self._association_state: self.__log_debug("- %s -> %s", self._association_state, state) self._association_state = state if state == self.State.ESTABLISHED: self.__state = "connected" for channel in list(self._data_channels.values()): if channel.negotiated and channel.readyState != "open": channel._setReadyState("open") asyncio.ensure_future(self._data_channel_flush()) elif state == self.State.CLOSED: self._t1_cancel() self._t2_cancel() self._t3_cancel() self.__state = "closed" # close data channels for stream_id in list(self._data_channels.keys()): self._data_channel_closed(stream_id) # no more events will be emitted, so remove all event listeners # to facilitate garbage collection. self.remove_all_listeners() # timers def _t1_cancel(self) -> None: if self._t1_handle is not None: self.__log_debug("- T1(%s) cancel", chunk_type(self._t1_chunk)) self._t1_handle.cancel() self._t1_handle = None self._t1_chunk = None def _t1_expired(self) -> None: self._t1_failures += 1 self._t1_handle = None self.__log_debug( "x T1(%s) expired %d", chunk_type(self._t1_chunk), self._t1_failures ) if self._t1_failures > SCTP_MAX_INIT_RETRANS: self._set_state(self.State.CLOSED) else: asyncio.ensure_future(self._send_chunk(self._t1_chunk)) self._t1_handle = self._loop.call_later(self._rto, self._t1_expired) def _t1_start(self, chunk: Chunk) -> None: assert self._t1_handle is None self._t1_chunk = chunk self._t1_failures = 0 self.__log_debug("- T1(%s) start", chunk_type(self._t1_chunk)) self._t1_handle = self._loop.call_later(self._rto, self._t1_expired) def _t2_cancel(self) -> None: if self._t2_handle is not None: self.__log_debug("- T2(%s) cancel", chunk_type(self._t2_chunk)) self._t2_handle.cancel() self._t2_handle = None self._t2_chunk = None def _t2_expired(self) -> None: self._t2_failures += 1 self._t2_handle = None self.__log_debug( "x T2(%s) expired %d", chunk_type(self._t2_chunk), self._t2_failures ) if self._t2_failures > SCTP_MAX_ASSOCIATION_RETRANS: self._set_state(self.State.CLOSED) else: asyncio.ensure_future(self._send_chunk(self._t2_chunk)) self._t2_handle = self._loop.call_later(self._rto, self._t2_expired) def _t2_start(self, chunk) -> None: assert self._t2_handle is None self._t2_chunk = chunk self._t2_failures = 0 self.__log_debug("- T2(%s) start", chunk_type(self._t2_chunk)) self._t2_handle = self._loop.call_later(self._rto, self._t2_expired) @no_type_check def _t3_expired(self) -> None: self._t3_handle = None self.__log_debug("x T3 expired") # mark retransmit or abandoned chunks for chunk in self._sent_queue: if not self._maybe_abandon(chunk): chunk._retransmit = True self._update_advanced_peer_ack_point() # adjust congestion window self._fast_recovery_exit = None self._flight_size = 0 self._partial_bytes_acked = 0 self._ssthresh = max(self._cwnd // 2, 4 * USERDATA_MAX_LENGTH) self._cwnd = USERDATA_MAX_LENGTH asyncio.ensure_future(self._transmit()) def _t3_restart(self) -> None: self.__log_debug("- T3 restart") if self._t3_handle is not None: self._t3_handle.cancel() self._t3_handle = None self._t3_handle = self._loop.call_later(self._rto, self._t3_expired) def _t3_start(self) -> None: assert self._t3_handle is None self.__log_debug("- T3 start") self._t3_handle = self._loop.call_later(self._rto, self._t3_expired) def _t3_cancel(self) -> None: if self._t3_handle is not None: self.__log_debug("- T3 cancel") self._t3_handle.cancel() self._t3_handle = None @no_type_check async def _transmit(self) -> None: """ Transmit outbound data. """ # send FORWARD TSN if self._forward_tsn_chunk is not None: await self._send_chunk(self._forward_tsn_chunk) self._forward_tsn_chunk = None # ensure T3 is running if not self._t3_handle: self._t3_start() # limit burst size if self._fast_recovery_exit is not None: burst_size = 2 * USERDATA_MAX_LENGTH else: burst_size = 4 * USERDATA_MAX_LENGTH cwnd = min(self._flight_size + burst_size, self._cwnd) # retransmit retransmit_earliest = True for chunk in self._sent_queue: if chunk._retransmit: if self._fast_recovery_transmit: self._fast_recovery_transmit = False elif self._flight_size >= cwnd: return self._flight_size_increase(chunk) chunk._misses = 0 chunk._retransmit = False chunk._sent_count += 1 await self._send_chunk(chunk) if retransmit_earliest: # restart the T3 timer as the earliest outstanding TSN # is being retransmitted self._t3_restart() retransmit_earliest = False while self._outbound_queue and self._flight_size < cwnd: chunk = self._outbound_queue.popleft() self._sent_queue.append(chunk) self._flight_size_increase(chunk) # update counters chunk._sent_count += 1 chunk._sent_time = time.time() await self._send_chunk(chunk) if not self._t3_handle: self._t3_start() async def _transmit_reconfig(self): if ( self._association_state == self.State.ESTABLISHED and self._reconfig_queue and not self._reconfig_request ): streams = self._reconfig_queue[0:RECONFIG_MAX_STREAMS] self._reconfig_queue = self._reconfig_queue[RECONFIG_MAX_STREAMS:] param = StreamResetOutgoingParam( request_sequence=self._reconfig_request_seq, response_sequence=self._reconfig_response_seq, last_tsn=tsn_minus_one(self._local_tsn), streams=streams, ) self._reconfig_request = param self._reconfig_request_seq = tsn_plus_one(self._reconfig_request_seq) await self._send_reconfig_param(param) @no_type_check def _update_advanced_peer_ack_point(self) -> None: """ Try to advance "Advanced.Peer.Ack.Point" according to RFC 3758. """ if uint32_gt(self._last_sacked_tsn, self._advanced_peer_ack_tsn): self._advanced_peer_ack_tsn = self._last_sacked_tsn done = 0 streams = {} while self._sent_queue and self._sent_queue[0]._abandoned: chunk = self._sent_queue.popleft() self._advanced_peer_ack_tsn = chunk.tsn done += 1 if not (chunk.flags & SCTP_DATA_UNORDERED): streams[chunk.stream_id] = chunk.stream_seq if done: # build FORWARD TSN self._forward_tsn_chunk = ForwardTsnChunk() self._forward_tsn_chunk.cumulative_tsn = self._advanced_peer_ack_tsn self._forward_tsn_chunk.streams = list(streams.items()) def _update_rto(self, R: float) -> None: """ Update RTO given a new roundtrip measurement R. """ if self._srtt is None: self._rttvar = R / 2 self._srtt = R else: self._rttvar = (1 - SCTP_RTO_BETA) * self._rttvar + SCTP_RTO_BETA * abs( self._srtt - R ) self._srtt = (1 - SCTP_RTO_ALPHA) * self._srtt + SCTP_RTO_ALPHA * R self._rto = max(SCTP_RTO_MIN, min(self._srtt + 4 * self._rttvar, SCTP_RTO_MAX)) def _data_channel_close(self, channel, transmit=True): """ Request closing the datachannel by sending an Outgoing Stream Reset Request. """ if channel.readyState not in ["closing", "closed"]: channel._setReadyState("closing") if self._association_state == self.State.ESTABLISHED: # queue a stream reset self._reconfig_queue.append(channel.id) if len(self._reconfig_queue) == 1: asyncio.ensure_future(self._transmit_reconfig()) else: # remove any queued messages for the datachannel new_queue = deque() for queue_item in self._data_channel_queue: if queue_item[0] != channel: new_queue.append(queue_item) self._data_channel_queue = new_queue # mark the datachannel as closed if channel.id is not None: self._data_channels.pop(channel.id) channel._setReadyState("closed") def _data_channel_closed(self, stream_id: int) -> None: channel = self._data_channels.pop(stream_id) channel._setReadyState("closed") async def _data_channel_flush(self) -> None: """ Try to flush buffered data to the SCTP layer. We wait until the association is established, as we need to know whether we are a client or a server to correctly assign an odd/even ID to the data channels. """ if self._association_state != self.State.ESTABLISHED: return while self._data_channel_queue and not self._outbound_queue: channel, protocol, user_data = self._data_channel_queue.popleft() # register channel if necessary stream_id = channel.id if stream_id is None: stream_id = self._data_channel_id while stream_id in self._data_channels: stream_id += 2 self._data_channels[stream_id] = channel channel._setId(stream_id) # send data if protocol == WEBRTC_DCEP: await self._send(stream_id, protocol, user_data) else: if channel.maxPacketLifeTime: expiry = time.time() + (channel.maxPacketLifeTime / 1000) else: expiry = None await self._send( stream_id, protocol, user_data, expiry=expiry, max_retransmits=channel.maxRetransmits, ordered=channel.ordered, ) channel._addBufferedAmount(-len(user_data)) def _data_channel_add_negotiated(self, channel: RTCDataChannel) -> None: if channel.id in self._data_channels: raise ValueError(f"Data channel with ID {channel.id} already registered") self._data_channels[channel.id] = channel if self._association_state == self.State.ESTABLISHED: channel._setReadyState("open") def _data_channel_open(self, channel: RTCDataChannel) -> None: if channel.id is not None: if channel.id in self._data_channels: raise ValueError( f"Data channel with ID {channel.id} already registered" ) else: self._data_channels[channel.id] = channel channel_type = DATA_CHANNEL_RELIABLE priority = 0 reliability = 0 if not channel.ordered: channel_type |= 0x80 if channel.maxRetransmits is not None: channel_type |= 1 reliability = channel.maxRetransmits elif channel.maxPacketLifeTime is not None: channel_type |= 2 reliability = channel.maxPacketLifeTime data = pack( "!BBHLHH", DATA_CHANNEL_OPEN, channel_type, priority, reliability, len(channel.label), len(channel.protocol), ) data += channel.label.encode("utf8") data += channel.protocol.encode("utf8") self._data_channel_queue.append((channel, WEBRTC_DCEP, data)) asyncio.ensure_future(self._data_channel_flush()) async def _data_channel_receive( self, stream_id: int, pp_id: int, data: bytes ) -> None: if pp_id == WEBRTC_DCEP and len(data): msg_type = data[0] if msg_type == DATA_CHANNEL_OPEN and len(data) >= 12: # we should not receive an open for an existing channel assert stream_id not in self._data_channels ( msg_type, channel_type, priority, reliability, label_length, protocol_length, ) = unpack_from("!BBHLHH", data) pos = 12 label = data[pos : pos + label_length].decode("utf8") pos += label_length protocol = data[pos : pos + protocol_length].decode("utf8") # check channel type maxPacketLifeTime = None maxRetransmits = None if (channel_type & 0x03) == 1: maxRetransmits = reliability elif (channel_type & 0x03) == 2: maxPacketLifeTime = reliability # register channel parameters = RTCDataChannelParameters( label=label, ordered=(channel_type & 0x80) == 0, maxPacketLifeTime=maxPacketLifeTime, maxRetransmits=maxRetransmits, protocol=protocol, id=stream_id, ) channel = RTCDataChannel(self, parameters, False) channel._setReadyState("open") self._data_channels[stream_id] = channel # send ack self._data_channel_queue.append( (channel, WEBRTC_DCEP, pack("!B", DATA_CHANNEL_ACK)) ) await self._data_channel_flush() # emit channel self.emit("datachannel", channel) elif msg_type == DATA_CHANNEL_ACK: assert stream_id in self._data_channels channel = self._data_channels[stream_id] channel._setReadyState("open") elif pp_id == WEBRTC_STRING and stream_id in self._data_channels: # emit message self._data_channels[stream_id].emit("message", data.decode("utf8")) elif pp_id == WEBRTC_STRING_EMPTY and stream_id in self._data_channels: # emit message self._data_channels[stream_id].emit("message", "") elif pp_id == WEBRTC_BINARY and stream_id in self._data_channels: # emit message self._data_channels[stream_id].emit("message", data) elif pp_id == WEBRTC_BINARY_EMPTY and stream_id in self._data_channels: # emit message self._data_channels[stream_id].emit("message", b"") def _data_channel_send(self, channel: RTCDataChannel, data: bytes) -> None: if data == "": pp_id, user_data = WEBRTC_STRING_EMPTY, b"\x00" elif isinstance(data, str): pp_id, user_data = WEBRTC_STRING, data.encode("utf8") elif data == b"": pp_id, user_data = WEBRTC_BINARY_EMPTY, b"\x00" else: pp_id, user_data = WEBRTC_BINARY, data channel._addBufferedAmount(len(user_data)) self._data_channel_queue.append((channel, pp_id, user_data)) asyncio.ensure_future(self._data_channel_flush()) class State(enum.Enum): CLOSED = 1 COOKIE_WAIT = 2 COOKIE_ECHOED = 3 ESTABLISHED = 4 SHUTDOWN_PENDING = 5 SHUTDOWN_SENT = 6 SHUTDOWN_RECEIVED = 7 SHUTDOWN_ACK_SENT = 8