import asyncio
import hashlib
import logging
import socket
import struct
import time
from typing import (
    Any,
    Callable,
    Coroutine,
    Dict,
    List,
    Optional,
    Set,
    Text,
    Tuple,
    TypeVar,
    Union,
    cast,
)

from . import stun
from .utils import random_transaction_id

logger = logging.getLogger(__name__)

DEFAULT_CHANNEL_REFRESH_TIME = 500
DEFAULT_ALLOCATION_LIFETIME = 600
TCP_TRANSPORT = 0x06000000
UDP_TRANSPORT = 0x11000000
UDP_SOCKET_BUFFER_SIZE = 262144

_ProtocolT = TypeVar("_ProtocolT", bound=asyncio.BaseProtocol)


def is_channel_data(data: bytes) -> bool:
    """
    Return True if `data` starts with a ChannelData header.

    The check looks at the two most significant bits of the first byte
    (0b01). Channel numbers allocated by this module start at 0x4000, which
    sets exactly that bit pattern. `data` must be non-empty.
    """
    return (data[0] & 0xC0) == 0x40


def make_integrity_key(username: str, realm: str, password: str) -> bytes:
    """
    Return the long-term credential key MD5("username:realm:password").
    """
    return hashlib.md5(":".join([username, realm, password]).encode("utf8")).digest()


class TurnStreamMixin:
    """
    Split a byte stream into STUN messages and ChannelData frames.

    The class using this mixin must provide `datagram_received` and
    `transport`. Each complete frame is passed to `datagram_received` along
    with the stream's peer address.
    """

    datagram_received: Callable
    transport: asyncio.BaseTransport

    def data_received(self, data: bytes) -> None:
        if not hasattr(self, "buffer"):
            self.buffer = b""
        self.buffer += data

        while len(self.buffer) >= 4:
            _, length = struct.unpack("!HH", self.buffer[0:4])
            # frames on a stream are padded to a 4-byte boundary
            length += stun.padding_length(length)
            if is_channel_data(self.buffer):
                # 4-byte ChannelData header followed by the payload
                full_length = 4 + length
            else:
                # 20-byte STUN header followed by the attributes
                full_length = 20 + length
            if len(self.buffer) < full_length:
                # wait for the rest of the frame
                break

            addr = self.transport.get_extra_info("peername")
            self.datagram_received(self.buffer[0:full_length], addr)
            self.buffer = self.buffer[full_length:]

    def _padded(self, data: bytes) -> bytes:
        # TCP and TCP-over-TLS must pad messages to 4-byte boundaries.
        padding = stun.padding_length(len(data))
        if padding:
            data += bytes(padding)
        return data


class TurnClientMixin:
    """
    TURN client logic shared by the TCP and UDP protocols.

    Concrete protocols provide `_send`, which writes raw bytes to the server.
    """

    _send: Callable

    def __init__(
        self,
        server: Tuple[str, int],
        username: Optional[str],
        password: Optional[str],
        lifetime: int,
        channel_refresh_time: int,
    ) -> None:
        """
        :param server: address of the TURN server.
        :param username: long-term credential username, or None.
        :param password: long-term credential password, or None.
        :param lifetime: requested allocation lifetime, in seconds.
        :param channel_refresh_time: number of seconds after which a channel
            binding is refreshed on the next send to that peer.
        """
        self.channel_refresh_at: Dict[int, float] = {}
        self.channel_to_peer: Dict[int, Tuple[str, int]] = {}
        self.peer_connect_waiters: Dict[
            Tuple[str, int], List[asyncio.Future[None]]
        ] = {}
        self.peer_to_channel: Dict[Tuple[str, int], int] = {}

        self.channel_number = 0x4000
        self.channel_refresh_time = channel_refresh_time
        self.integrity_key: Optional[bytes] = None
        self.lifetime = lifetime
        self.nonce: Optional[bytes] = None
        self.password = password
        self.receiver = None
        self.realm: Optional[str] = None
        self.refresh_task: Optional[asyncio.Task] = None
        self.relayed_address: Optional[Tuple[str, int]] = None
        self.server = server
        self.transactions: Dict[bytes, stun.Transaction] = {}
        self.username = username

    async def channel_bind(self, channel_number: int, addr: Tuple[str, int]) -> None:
        """
        Bind `channel_number` to the peer `addr` with a CHANNEL-BIND request.
        """
        request = stun.Message(
            message_method=stun.Method.CHANNEL_BIND, message_class=stun.Class.REQUEST
        )
        request.attributes["CHANNEL-NUMBER"] = channel_number
        request.attributes["XOR-PEER-ADDRESS"] = addr
        await self.request_with_retry(request)
        logger.info("TURN channel bound %d %s", channel_number, addr)

    async def connect(self) -> Tuple[str, int]:
        """
        Create a TURN allocation.

        Starts a background task that refreshes the allocation, and returns
        the relayed address.
        """
        request = stun.Message(
            message_method=stun.Method.ALLOCATE, message_class=stun.Class.REQUEST
        )
        request.attributes["LIFETIME"] = self.lifetime
        request.attributes["REQUESTED-TRANSPORT"] = UDP_TRANSPORT

        response, _ = await self.request_with_retry(request)

        time_to_expiry = response.attributes["LIFETIME"]
        self.relayed_address = response.attributes["XOR-RELAYED-ADDRESS"]
        logger.info(
            "TURN allocation created %s (expires in %d seconds)",
            self.relayed_address,
            time_to_expiry,
        )

        # periodically refresh allocation
        self.refresh_task = asyncio.create_task(self.refresh(time_to_expiry))

        return self.relayed_address

    def connection_lost(self, exc: Optional[Exception]) -> None:
        # propagate the loss of the server connection to the relayed protocol
        logger.debug("%s connection_lost(%s)", self, exc)
        if self.receiver:
            self.receiver.connection_lost(exc)

    def connection_made(self, transport) -> None:
        logger.debug("%s connection_made(%s)", self, transport)
        self.transport = transport

    def datagram_received(self, data: Union[bytes, Text], addr) -> None:
        """
        Handle a packet received from the TURN server.

        ChannelData frames for a known channel are forwarded to the receiver,
        tagged with the peer address bound to that channel. STUN responses and
        error responses are handed to the pending transaction with the same
        transaction ID. Anything else is dropped.
        """
        data = cast(bytes, data)

        # demultiplex channel data
        if len(data) >= 4 and is_channel_data(data):
            channel, length = struct.unpack("!HH", data[0:4])

            if len(data) >= length + 4 and self.receiver:
                peer_address = self.channel_to_peer.get(channel)
                if peer_address:
                    payload = data[4 : 4 + length]
                    self.receiver.datagram_received(payload, peer_address)

            return

        try:
            message = stun.parse_message(data)
            logger.debug("%s < %s %s", self, addr, message)
        except ValueError:
            return

        if (
            message.message_class == stun.Class.RESPONSE
            or message.message_class == stun.Class.ERROR
        ) and message.transaction_id in self.transactions:
            transaction = self.transactions[message.transaction_id]
            transaction.response_received(message, addr)

    async def delete(self) -> None:
        """
        Delete the TURN allocation.

        Stops the refresh task and sends a REFRESH with LIFETIME 0. Transaction
        errors are ignored because we are shutting down anyway. Then closes
        the transport to the server.
        """
        if self.refresh_task:
            self.refresh_task.cancel()
            self.refresh_task = None

        request = stun.Message(
            message_method=stun.Method.REFRESH, message_class=stun.Class.REQUEST
        )
        request.attributes["LIFETIME"] = 0
        try:
            await self.request_with_retry(request)
        except stun.TransactionError:
            # we do not care, we need to shutdown
            pass

        logger.info("TURN allocation deleted %s", self.relayed_address)
        self.transport.close()

    async def refresh(self, time_to_expiry) -> None:
        """
        Periodically refresh the TURN allocation.

        Each REFRESH is sent once 5/6 of the current lifetime has elapsed,
        which leaves a margin before the allocation expires. The next wait
        uses the lifetime granted by the server.
        """
        while True:
            await asyncio.sleep(5 / 6 * time_to_expiry)

            request = stun.Message(
                message_method=stun.Method.REFRESH, message_class=stun.Class.REQUEST
            )
            request.attributes["LIFETIME"] = self.lifetime

            response, _ = await self.request_with_retry(request)
            time_to_expiry = response.attributes["LIFETIME"]
            logger.info(
                "TURN allocation refreshed %s (expires in %d seconds)",
                self.relayed_address,
                time_to_expiry,
            )

    async def request(
        self, request: stun.Message
    ) -> Tuple[stun.Message, Tuple[str, int]]:
        """
        Execute a STUN transaction and return the response.

        Authentication attributes are added once an integrity key is known.
        """
        assert request.transaction_id not in self.transactions

        if self.integrity_key:
            self.__add_authentication(request)

        transaction = stun.Transaction(request, self.server, self)
        self.transactions[request.transaction_id] = transaction
        try:
            return await transaction.run()
        finally:
            del self.transactions[request.transaction_id]

    async def request_with_retry(
        self, request: stun.Message
    ) -> Tuple[stun.Message, Tuple[str, int]]:
        """
        Execute a STUN transaction and return the response.

        On recoverable errors it will retry the request once. The recoverable
        errors are 401 with REALM and NONCE, and 438 (Stale Nonce) with NONCE
        when a realm is already known, in both cases only when credentials are
        set. The retry uses updated long-term credentials and a fresh
        transaction ID.
        """
        try:
            response, addr = await self.request(request)
        except stun.TransactionFailed as e:
            error_code = e.response.attributes["ERROR-CODE"][0]
            if (
                "NONCE" in e.response.attributes
                and self.username is not None
                and self.password is not None
                and (
                    (error_code == 401 and "REALM" in e.response.attributes)
                    or (error_code == 438 and self.realm is not None)
                )
            ):
                # update long-term credentials
                self.nonce = e.response.attributes["NONCE"]
                if error_code == 401:
                    self.realm = e.response.attributes["REALM"]
                self.integrity_key = make_integrity_key(
                    self.username, self.realm, self.password
                )

                # retry request with authentication
                request.transaction_id = random_transaction_id()
                response, addr = await self.request(request)
            else:
                raise

        return response, addr

    async def send_data(self, data: bytes, addr: Tuple[str, int]) -> None:
        """
        Send data to a remote host via the TURN server.

        A channel is bound to `addr` on first use. It is bound again once its
        refresh deadline has passed. Sends to a peer whose channel is still
        being bound wait for that binding to finish. If the binding fails, the
        waiting sends fail with the same error.

        :raises stun.TransactionError: if binding or refreshing the channel
            fails (presumably; whatever `request_with_retry` raises propagates).
        """
        # if a channel is being bound for the peer, wait
        if addr in self.peer_connect_waiters:
            loop = asyncio.get_running_loop()
            waiter = loop.create_future()
            self.peer_connect_waiters[addr].append(waiter)
            await waiter

        channel = self.peer_to_channel.get(addr)
        # monotonic clock: refresh deadlines must not follow wall-clock jumps
        now = time.monotonic()
        if channel is None:
            self.peer_connect_waiters[addr] = []
            channel = self.channel_number
            self.channel_number += 1

            # bind channel
            try:
                await self.channel_bind(channel, addr)
            except BaseException as exc:
                # Without this, the waiters would never be resolved, and every
                # later send to this peer would block on the stale entry.
                self._wake_peer_waiters(addr, exc)
                raise

            # update state
            self.channel_refresh_at[channel] = now + self.channel_refresh_time
            self.channel_to_peer[channel] = addr
            self.peer_to_channel[addr] = channel

            # notify waiters
            self._wake_peer_waiters(addr)
        elif now > self.channel_refresh_at[channel]:
            # refresh channel
            await self.channel_bind(channel, addr)

            # update state
            self.channel_refresh_at[channel] = now + self.channel_refresh_time

        header = struct.pack("!HH", channel, len(data))
        self._send(header + data)

    def _wake_peer_waiters(
        self, addr: Tuple[str, int], exc: Optional[BaseException] = None
    ) -> None:
        """
        Resolve the futures of the sends waiting for the channel to `addr`.

        The futures succeed when `exc` is None. If `exc` is an Exception, they
        fail with it. Otherwise (for example a cancellation) they are
        cancelled. Futures that are already done are skipped; this happens
        when their waiting task was cancelled.
        """
        for waiter in self.peer_connect_waiters.pop(addr, []):
            if waiter.done():
                continue
            if exc is None:
                waiter.set_result(None)
            elif isinstance(exc, Exception):
                waiter.set_exception(exc)
            else:
                waiter.cancel()

    def send_stun(self, message: stun.Message, addr: Tuple[str, int]) -> None:
        """
        Send a STUN message to the TURN server.
        """
        logger.debug("%s > %s %s", self, addr, message)
        self._send(bytes(message))

    def __add_authentication(self, request: stun.Message) -> None:
        # long-term credential attributes, then MESSAGE-INTEGRITY
        request.attributes["USERNAME"] = self.username
        request.attributes["NONCE"] = self.nonce
        request.attributes["REALM"] = self.realm
        request.add_message_integrity(self.integrity_key)


class TurnClientTcpProtocol(TurnClientMixin, TurnStreamMixin, asyncio.Protocol):
    """
    Protocol for handling TURN over TCP.
    """

    def _send(self, data: bytes) -> None:
        self.transport.write(self._padded(data))

    def __repr__(self) -> str:
        return "turn/tcp"


class TurnClientUdpProtocol(TurnClientMixin, asyncio.DatagramProtocol):
    """
    Protocol for handling TURN over UDP.
    """

    def _send(self, data: bytes) -> None:
        self.transport.sendto(data)

    def __repr__(self) -> str:
        return "turn/udp"


class TurnTransport:
    """
    Behaves like a Datagram transport, but uses a TURN allocation.
    """

    def __init__(self, protocol, inner_protocol) -> None:
        self.protocol = protocol
        self.__inner_protocol = inner_protocol
        self.__inner_protocol.receiver = protocol
        self.__relayed_address = None
        # The event loop only keeps weak references to tasks, so background
        # tasks are referenced here until they finish.
        self.__tasks: Set[asyncio.Task] = set()

    def __spawn(self, coro: Coroutine[Any, Any, None]) -> None:
        # Run `coro` in the background while keeping a strong reference to it.
        task = asyncio.create_task(coro)
        self.__tasks.add(task)
        task.add_done_callback(self.__tasks.discard)

    def close(self) -> None:
        """
        Close the transport.

        After the TURN allocation has been deleted, the protocol's
        `connection_lost()` method will be called with None as its argument.
        """
        self.__spawn(self.__inner_protocol.delete())

    def get_extra_info(self, name: str, default: Any = None) -> Any:
        """
        Return optional transport information.

        - `'related_address'`: the related address
        - `'sockname'`: the relayed address
        """
        if name == "related_address":
            return self.__inner_protocol.transport.get_extra_info("sockname")
        elif name == "sockname":
            return self.__relayed_address
        return default

    def sendto(self, data: bytes, addr: Tuple[str, int]) -> None:
        """
        Sends the `data` bytes to the remote peer given `addr`.

        This will bind a TURN channel as necessary.
        """
        self.__spawn(self.__inner_protocol.send_data(data, addr))

    async def _connect(self) -> None:
        # create the allocation, then report the transport as connected
        self.__relayed_address = await self.__inner_protocol.connect()
        self.protocol.connection_made(self)


async def create_turn_endpoint(
    protocol_factory: Callable[[], _ProtocolT],
    server_addr: Tuple[str, int],
    username: Optional[str],
    password: Optional[str],
    lifetime: int = DEFAULT_ALLOCATION_LIFETIME,
    channel_refresh_time: int = DEFAULT_CHANNEL_REFRESH_TIME,
    ssl: bool = False,
    transport: str = "udp",
) -> Tuple[TurnTransport, _ProtocolT]:
    """
    Create datagram connection relayed over TURN.

    :param protocol_factory: builds the protocol receiving relayed datagrams.
    :param server_addr: address of the TURN server.
    :param username: long-term credential username, or None.
    :param password: long-term credential password, or None.
    :param lifetime: requested allocation lifetime, in seconds.
    :param channel_refresh_time: seconds before a channel binding is refreshed.
    :param ssl: passed to `loop.create_connection` (TCP only).
    :param transport: "tcp" to reach the server over TCP; any other value
        uses UDP.
    :returns: the TURN transport and the protocol instance.
    """
    loop = asyncio.get_running_loop()
    inner_protocol: asyncio.BaseProtocol
    inner_transport: asyncio.BaseTransport
    if transport == "tcp":
        inner_transport, inner_protocol = await loop.create_connection(
            lambda: TurnClientTcpProtocol(
                server_addr,
                username=username,
                password=password,
                lifetime=lifetime,
                channel_refresh_time=channel_refresh_time,
            ),
            host=server_addr[0],
            port=server_addr[1],
            ssl=ssl,
        )
    else:
        inner_transport, inner_protocol = await loop.create_datagram_endpoint(
            lambda: TurnClientUdpProtocol(
                server_addr,
                username=username,
                password=password,
                lifetime=lifetime,
                channel_refresh_time=channel_refresh_time,
            ),
            remote_addr=server_addr,
        )
        # NOTE: this applies to the UDP socket only; it sits inside the
        # else-branch, so the TCP connection keeps its default receive buffer.
        sock = inner_transport.get_extra_info("socket")
        if sock is not None:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, UDP_SOCKET_BUFFER_SIZE)

    try:
        protocol = protocol_factory()
        turn_transport = TurnTransport(protocol, inner_protocol)
        await turn_transport._connect()
    except Exception:
        # do not leak the connection to the server if the allocation fails
        inner_transport.close()
        raise

    return turn_transport, protocol