|
import math
|
|
from enum import Enum
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
from aiortc.utils import uint32_add, uint32_gt
|
|
|
|
# Upper bound on the arrival-time delta for a packet to be considered part of
# the current burst (compared against in InterArrival.belongs_to_burst).
BURST_DELTA_THRESHOLD_MS = 5

# Margin above the detector threshold beyond which an offset is ignored for
# threshold adaptation (see OveruseDetector.update_threshold).
MAX_ADAPT_OFFSET_MS = 15
# Cap on the delta count used to scale the offset in OveruseDetector.detect.
MIN_NUM_DELTAS = 60

# Ceiling for the delta counter maintained by OveruseEstimator.update.
DELTA_COUNTER_MAX = 1000
# Maximum number of timestamp deltas kept by
# OveruseEstimator.update_min_frame_period.
MIN_FRAME_PERIOD_HISTORY_LENGTH = 60

# Number of fractional bits in the timestamps handed to InterArrival.
# NOTE(review): presumably the fractional bits of abs-send-time plus the 8-bit
# shift applied in RemoteBitrateEstimator.add - confirm.
INTER_ARRIVAL_SHIFT = 26
# Length of a timestamp group in milliseconds; RemoteBitrateEstimator converts
# it to timestamp units before passing it to InterArrival.
TIMESTAMP_GROUP_LENGTH_MS = 5
# Multiplier converting a timestamp delta (in 1 / 2**INTER_ARRIVAL_SHIFT
# second units) into milliseconds.
TIMESTAMP_TO_MS = 1000.0 / (1 << INTER_ARRIVAL_SHIFT)
|
|
|
|
|
|
class BandwidthUsage(Enum):
    """Bandwidth usage hypothesis, as produced by OveruseDetector."""

    # Neither overuse nor underuse is currently detected.
    NORMAL = 0
    # The scaled delay offset fell below the negative detection threshold.
    UNDERUSING = 1
    # The scaled delay offset stayed above the detection threshold.
    OVERUSING = 2
|
|
|
|
|
|
class RateControlState(Enum):
    """State of the AimdRateControl state machine."""

    # Keep the current bitrate unchanged.
    HOLD = 0
    # Raise the bitrate on each update.
    INCREASE = 1
    # Lower the bitrate on the next update, then return to HOLD.
    DECREASE = 2
|
|
|
|
|
|
class AimdRateControl:
    """
    Bitrate controller driven by bandwidth usage observations.

    The controller moves between the HOLD, INCREASE and DECREASE states of
    `RateControlState`. While increasing, the bitrate grows additively when
    the controller believes it is near the maximum throughput and
    multiplicatively otherwise. On overuse the bitrate drops to 85% of the
    estimated throughput and the controller returns to HOLD.
    """

    def __init__(self) -> None:
        # Running mean (kbps) of the throughput observed when decreasing,
        # or None when no such estimate is currently held.
        self.avg_max_bitrate_kbps: Optional[float] = None
        # Running normalised variance matching `avg_max_bitrate_kbps`.
        self.var_max_bitrate_kbps = 0.4
        self.current_bitrate = 30000000
        self.current_bitrate_initialized = False
        self.first_estimated_throughput_time: Optional[int] = None
        self.last_change_ms: Optional[int] = None
        self.near_max = False
        self.latest_estimated_throughput = 30000000
        # NOTE(review): presumably a round-trip time in milliseconds; it is
        # never updated inside this class.
        self.rtt = 200
        self.state = RateControlState.HOLD

    def feedback_interval(self) -> int:
        """
        Return the interval, in milliseconds, between periodic updates.
        """
        return 500

    def set_estimate(self, bitrate: int, now_ms: int) -> None:
        """
        For testing purposes.
        """
        self.current_bitrate = self._clamp_bitrate(bitrate, bitrate)
        self.current_bitrate_initialized = True
        self.last_change_ms = now_ms

    def update(
        self,
        bandwidth_usage: BandwidthUsage,
        estimated_throughput: Optional[int],
        now_ms: int,
    ) -> Optional[int]:
        """
        Feed a new observation and return the updated target bitrate.

        :param bandwidth_usage: The detector's current hypothesis.
        :param estimated_throughput: The measured incoming bitrate, or `None`
            to reuse the most recent measurement.
        :param now_ms: The current time in milliseconds.
        :return: The new bitrate, or `None` while the bitrate has not been
            initialised and no overuse is reported.
        """
        # Adopt the measured throughput as the starting bitrate once
        # measurements have been available for more than 3 seconds.
        if not self.current_bitrate_initialized and estimated_throughput is not None:
            if self.first_estimated_throughput_time is None:
                self.first_estimated_throughput_time = now_ms
            elif now_ms - self.first_estimated_throughput_time > 3000:
                self.current_bitrate = estimated_throughput
                self.current_bitrate_initialized = True

        # Before initialisation only overuse is acted upon.
        if (
            not self.current_bitrate_initialized
            and bandwidth_usage != BandwidthUsage.OVERUSING
        ):
            return None

        # State transitions.
        if (
            bandwidth_usage == BandwidthUsage.NORMAL
            and self.state == RateControlState.HOLD
        ):
            self.last_change_ms = now_ms
            self.state = RateControlState.INCREASE
        elif bandwidth_usage == BandwidthUsage.OVERUSING:
            self.state = RateControlState.DECREASE
        elif bandwidth_usage == BandwidthUsage.UNDERUSING:
            self.state = RateControlState.HOLD

        new_bitrate = self.current_bitrate
        if estimated_throughput is not None:
            self.latest_estimated_throughput = estimated_throughput
        else:
            estimated_throughput = self.latest_estimated_throughput
        estimated_throughput_kbps = estimated_throughput / 1000

        if self.state == RateControlState.INCREASE:
            # A throughput well above the remembered maximum invalidates it.
            if self.avg_max_bitrate_kbps is not None:
                sigma_kbps = math.sqrt(
                    self.var_max_bitrate_kbps * self.avg_max_bitrate_kbps
                )
                if (
                    estimated_throughput_kbps
                    >= self.avg_max_bitrate_kbps + 3 * sigma_kbps
                ):
                    self.near_max = False
                    self.avg_max_bitrate_kbps = None

            # Additive increase near the maximum, multiplicative otherwise.
            if self.near_max:
                new_bitrate += self._additive_rate_increase(self.last_change_ms, now_ms)
            else:
                new_bitrate += self._multiplicative_rate_increase(
                    new_bitrate, self.last_change_ms, now_ms
                )
            self.last_change_ms = now_ms
        elif self.state == RateControlState.DECREASE:
            # A throughput well below the remembered maximum invalidates it.
            if self.avg_max_bitrate_kbps is not None:
                sigma_kbps = math.sqrt(
                    self.var_max_bitrate_kbps * self.avg_max_bitrate_kbps
                )
                if (
                    estimated_throughput_kbps
                    < self.avg_max_bitrate_kbps - 3 * sigma_kbps
                ):
                    self.avg_max_bitrate_kbps = None
            self._update_max_throughput_estimate(estimated_throughput_kbps)

            self.near_max = True
            new_bitrate = round(0.85 * estimated_throughput)
            self.last_change_ms = now_ms
            self.state = RateControlState.HOLD

        self.current_bitrate = self._clamp_bitrate(new_bitrate, estimated_throughput)
        return self.current_bitrate

    def _additive_rate_increase(self, last_ms: int, now_ms: int) -> int:
        """
        Return the additive increase for the time elapsed since `last_ms`.
        """
        return int((now_ms - last_ms) * self._near_max_rate_increase() / 1000)

    def _clamp_bitrate(self, new_bitrate: int, estimated_throughput: int) -> int:
        """
        Cap `new_bitrate` at 1.5 times the estimated throughput plus 10000,
        but never below the current bitrate.
        """
        max_bitrate = max(int(1.5 * estimated_throughput) + 10000, self.current_bitrate)
        return min(new_bitrate, max_bitrate)

    def _multiplicative_rate_increase(
        self, new_bitrate: int, last_ms: Optional[int], now_ms: int
    ) -> int:
        """
        Return the multiplicative increase to apply to `new_bitrate`.

        The growth factor is 8% per second, scaled by the time elapsed since
        `last_ms` (capped at one second). When `last_ms` is `None` the full
        8% is used. The increase is never less than 1000.
        """
        alpha = 1.08
        if last_ms is not None:
            elapsed_ms = min(now_ms - last_ms, 1000)
            alpha = pow(alpha, elapsed_ms / 1000)
        return int(max((alpha - 1) * new_bitrate, 1000))

    def _near_max_rate_increase(self) -> int:
        """
        Return the per-second additive increase used near the maximum.

        The value is derived from the average packet size at the current
        bitrate (assuming 30 frames per second and packets of up to 1200
        bytes) divided by a response time of `rtt + 100`, with a floor of 4000.
        """
        bits_per_frame = self.current_bitrate / 30
        # At least one packet per frame: a zero bitrate would otherwise give
        # ceil(0) == 0 packets and a division by zero below.
        packets_per_frame = max(1, math.ceil(bits_per_frame / (8 * 1200)))
        avg_packet_size_bits = bits_per_frame / packets_per_frame

        response_time = self.rtt + 100
        return max(4000, int((avg_packet_size_bits * 1000) / response_time))

    def _update_max_throughput_estimate(self, estimated_throughput_kbps: float) -> None:
        """
        Fold a throughput sample (kbps) into the running mean and variance.
        """
        alpha = 0.05
        if self.avg_max_bitrate_kbps is None:
            self.avg_max_bitrate_kbps = estimated_throughput_kbps
        else:
            self.avg_max_bitrate_kbps = (
                1 - alpha
            ) * self.avg_max_bitrate_kbps + alpha * estimated_throughput_kbps

        # The variance is normalised by the mean and kept within [0.4, 2.5].
        norm = max(1, self.avg_max_bitrate_kbps)
        self.var_max_bitrate_kbps = (1 - alpha) * self.var_max_bitrate_kbps + alpha * (
            (self.avg_max_bitrate_kbps - estimated_throughput_kbps) ** 2
        ) / norm
        self.var_max_bitrate_kbps = max(0.4, min(self.var_max_bitrate_kbps, 2.5))
|
|
|
|
|
|
class TimestampGroup:
    """
    Accumulated state for a run of packets handled as a single group.
    """

    def __init__(self, timestamp: Optional[int] = None) -> None:
        # A fresh group starts and ends on the same timestamp.
        self.first_timestamp = self.last_timestamp = timestamp
        # Arrival time of the most recent packet added, once known.
        self.arrival_time: Optional[int] = None
        # Sum of the sizes of the packets added so far.
        self.size = 0
|
|
|
|
|
|
class InterArrivalDelta:
    """
    Differences in timestamp, arrival time and size between two groups.
    """

    def __init__(self, timestamp: int, arrival_time: int, size: int) -> None:
        self.timestamp, self.arrival_time, self.size = timestamp, arrival_time, size
|
|
|
|
|
|
class InterArrival:
    """
    Inter-arrival time and size filter.

    Adapted from the webrtc.org codebase.
    """

    def __init__(self, group_length: int, timestamp_to_ms: float) -> None:
        self.timestamp_to_ms = timestamp_to_ms
        self.group_length = group_length
        self.previous_group: Optional[TimestampGroup] = None
        self.current_group: Optional[TimestampGroup] = None

    def compute_deltas(
        self, timestamp: int, arrival_time: int, packet_size: int
    ) -> Optional[InterArrivalDelta]:
        """
        Register a packet and return deltas when a timestamp group completes.

        :return: The differences between the two most recently completed
            groups, or `None` when the packet does not complete a group, is
            out of order, or there is no earlier group to compare against.
        """
        result = None
        group = self.current_group

        if group is None:
            # Very first packet: open the first group.
            group = self.current_group = TimestampGroup(timestamp)
        else:
            if self.packet_out_of_order(timestamp):
                # Out-of-order packets are dropped without being accounted for.
                return result

            if self.new_timestamp_group(timestamp, arrival_time):
                older = self.previous_group
                if older is not None:
                    result = InterArrivalDelta(
                        timestamp=uint32_add(
                            group.last_timestamp, -older.last_timestamp
                        ),
                        arrival_time=group.arrival_time - older.arrival_time,
                        size=group.size - older.size,
                    )

                # The finished group becomes the reference for the next one.
                self.previous_group = group
                group = self.current_group = TimestampGroup(timestamp=timestamp)
            elif uint32_gt(timestamp, group.last_timestamp):
                group.last_timestamp = timestamp

        # Account for the packet in whichever group is now current.
        group.size += packet_size
        group.arrival_time = arrival_time

        return result

    def belongs_to_burst(self, timestamp: int, arrival_time: int) -> bool:
        """
        Tell whether the packet is part of the same burst as the current group.
        """
        group = self.current_group
        send_delta = uint32_add(timestamp, -group.last_timestamp)
        send_delta_ms = round(self.timestamp_to_ms * send_delta)
        arrival_delta = arrival_time - group.arrival_time

        if send_delta_ms == 0:
            return True
        # Arrived faster than it was sent, and within the burst threshold.
        return (
            arrival_delta - send_delta_ms
        ) < 0 and arrival_delta <= BURST_DELTA_THRESHOLD_MS

    def new_timestamp_group(self, timestamp: int, arrival_time: int) -> bool:
        """
        Tell whether the packet starts a new timestamp group.
        """
        if self.belongs_to_burst(timestamp, arrival_time):
            return False

        since_group_start = uint32_add(timestamp, -self.current_group.first_timestamp)
        return since_group_start > self.group_length

    def packet_out_of_order(self, timestamp: int) -> bool:
        """
        Tell whether the timestamp precedes the start of the current group.
        """
        since_group_start = uint32_add(timestamp, -self.current_group.first_timestamp)
        # A difference in the upper half of the 32-bit range means "earlier".
        return since_group_start >= 0x80000000
|
|
|
|
|
|
class OveruseDetector:
    """
    Bandwidth overuse detector.

    Adapted from the webrtc.org codebase.
    """

    def __init__(self) -> None:
        self.hypothesis = BandwidthUsage.NORMAL
        self.previous_offset = 0.0

        # Overuse bookkeeping.
        self.overuse_counter = 0
        self.overuse_time: Optional[float] = None
        self.overuse_time_threshold = 10

        # Adaptive threshold and its gains.
        self.threshold = 12.5
        self.k_up = 0.0087
        self.k_down = 0.039
        self.last_update_ms: Optional[int] = None

    def detect(
        self, offset: float, timestamp_delta_ms: float, num_of_deltas: int, now_ms: int
    ) -> BandwidthUsage:
        """
        Update the hypothesis from a new offset estimate and return it.

        With fewer than two deltas, `BandwidthUsage.NORMAL` is returned and
        no state is modified.
        """
        if num_of_deltas < 2:
            return BandwidthUsage.NORMAL

        scaled_offset = min(num_of_deltas, MIN_NUM_DELTAS) * offset

        if scaled_offset > self.threshold:
            # Accumulate how long, and how many times, the threshold was exceeded.
            if self.overuse_time is None:
                self.overuse_time = timestamp_delta_ms / 2
            else:
                self.overuse_time += timestamp_delta_ms
            self.overuse_counter += 1

            sustained = (
                self.overuse_time > self.overuse_time_threshold
                and self.overuse_counter > 1
                and offset >= self.previous_offset
            )
            if sustained:
                self.overuse_counter = 0
                self.overuse_time = 0
                self.hypothesis = BandwidthUsage.OVERUSING
        else:
            # Not above the threshold: clear the overuse bookkeeping.
            self.overuse_counter = 0
            self.overuse_time = None
            if scaled_offset < -self.threshold:
                self.hypothesis = BandwidthUsage.UNDERUSING
            else:
                self.hypothesis = BandwidthUsage.NORMAL

        self.previous_offset = offset
        self.update_threshold(scaled_offset, now_ms)
        return self.hypothesis

    def state(self) -> BandwidthUsage:
        """
        Return the current hypothesis.
        """
        return self.hypothesis

    def update_threshold(self, modified_offset: float, now_ms: int) -> None:
        """
        Adapt the detection threshold towards the magnitude of the offset.
        """
        if self.last_update_ms is None:
            self.last_update_ms = now_ms

        magnitude = abs(modified_offset)

        # Offsets far above the threshold do not adapt it.
        if magnitude > self.threshold + MAX_ADAPT_OFFSET_MS:
            self.last_update_ms = now_ms
            return

        k = self.k_down if magnitude < self.threshold else self.k_up
        time_delta_ms = min(now_ms - self.last_update_ms, 100)
        self.threshold += k * (magnitude - self.threshold) * time_delta_ms
        self.threshold = max(6, min(self.threshold, 600))
        self.last_update_ms = now_ms
|
|
|
|
|
|
class OveruseEstimator:
    """
    Bandwidth overuse estimator.

    Adapted from the webrtc.org codebase.
    """

    def __init__(self) -> None:
        # Filter state: slope, offset and their 2x2 covariance.
        self.slope = 1 / 64
        self._offset = 0.0
        self.previous_offset = 0.0
        self.E = [[100.0, 0.0], [0.0, 0.1]]
        self.process_noise = [1e-13, 1e-3]

        # Measurement noise statistics.
        self.avg_noise = 0.0
        self.var_noise = 50.0

        self._num_of_deltas = 0
        self.ts_delta_hist: List[float] = []

    def num_of_deltas(self) -> int:
        """
        Return the number of deltas processed, capped at DELTA_COUNTER_MAX.
        """
        return self._num_of_deltas

    def offset(self) -> float:
        """
        Return the current offset estimate.
        """
        return self._offset

    def update(
        self,
        time_delta_ms: int,
        timestamp_delta_ms: float,
        size_delta: int,
        current_hypothesis: BandwidthUsage,
        now_ms: int,
    ):
        """
        Run one filter step on a new set of inter-arrival deltas.

        :param time_delta_ms: Difference in arrival time.
        :param timestamp_delta_ms: Difference in send timestamp.
        :param size_delta: Difference in group size.
        :param current_hypothesis: The detector's hypothesis before this step.
        :param now_ms: The current time (not used by the computation).
        """
        min_frame_period = self.update_min_frame_period(timestamp_delta_ms)
        t_ts_delta = time_delta_ms - timestamp_delta_ms

        self._num_of_deltas = min(self._num_of_deltas + 1, DELTA_COUNTER_MAX)

        # Grow the state covariance by the process noise.
        cov = self.E
        noise = self.process_noise
        cov[0][0] += noise[0]
        cov[1][1] += noise[1]

        overuse_receding = (
            current_hypothesis == BandwidthUsage.OVERUSING
            and self._offset < self.previous_offset
        )
        underuse_receding = (
            current_hypothesis == BandwidthUsage.UNDERUSING
            and self._offset > self.previous_offset
        )
        if overuse_receding or underuse_receding:
            cov[1][1] += 10 * noise[1]

        # Observation vector h = [size_delta, 1] and the product E * h.
        h0, h1 = size_delta, 1.0
        eh0 = cov[0][0] * h0 + cov[0][1] * h1
        eh1 = cov[1][0] * h0 + cov[1][1] * h1

        # The noise estimate is only refreshed while usage is normal, using
        # the residual limited to three standard deviations.
        residual = t_ts_delta - self.slope * h0 - self._offset
        if current_hypothesis == BandwidthUsage.NORMAL:
            max_residual = 3.0 * math.sqrt(self.var_noise)
            if abs(residual) < max_residual:
                limited = residual
            elif residual < 0:
                limited = -max_residual
            else:
                limited = max_residual
            self.update_noise_estimate(limited, min_frame_period)

        # Gain.
        denom = self.var_noise + h0 * eh0 + h1 * eh1
        k0 = eh0 / denom
        k1 = eh1 / denom

        # Covariance update with (I - K * h), computed from a snapshot of E.
        ikh00 = 1.0 - k0 * h0
        ikh01 = -k0 * h1
        ikh10 = -k1 * h0
        ikh11 = 1.0 - k1 * h1
        e00, e01 = cov[0]
        e10, e11 = cov[1]
        cov[0][0] = e00 * ikh00 + e10 * ikh01
        cov[0][1] = e01 * ikh00 + e11 * ikh01
        cov[1][0] = e00 * ikh10 + e10 * ikh11
        cov[1][1] = e01 * ikh10 + e11 * ikh11

        # State update.
        self.previous_offset = self._offset
        self.slope += k0 * residual
        self._offset += k1 * residual

    def update_min_frame_period(self, ts_delta: float) -> float:
        """
        Record `ts_delta` and return the smallest delta in the history,
        `ts_delta` included.
        """
        history = self.ts_delta_hist
        if len(history) >= MIN_FRAME_PERIOD_HISTORY_LENGTH:
            # Discard the oldest entry to make room.
            del history[0]

        smallest = ts_delta
        for past_delta in history:
            smallest = min(past_delta, smallest)

        history.append(ts_delta)
        return smallest

    def update_noise_estimate(self, residual: float, ts_delta: float) -> None:
        """
        Fold a residual into the running noise mean and variance.
        """
        # A slower filter is used once more than 300 deltas have been seen.
        alpha = 0.002 if self._num_of_deltas > 10 * 30 else 0.01

        beta = pow(1 - alpha, ts_delta * 30.0 / 1000.0)
        self.avg_noise = beta * self.avg_noise + (1 - beta) * residual
        variance = (
            beta * self.var_noise + (1 - beta) * (self.avg_noise - residual) ** 2
        )

        # The variance is never allowed to drop below 1.
        self.var_noise = max(variance, 1)
|
|
|
|
|
|
class RateBucket:
    """
    A pair of running totals: a sample count and the sum of sample values.
    """

    def __init__(self, count: int = 0, value: int = 0) -> None:
        self.count = count
        self.value = value

    def __eq__(self, other: object) -> bool:
        """
        Compare by `count` and `value`.

        Objects that are not `RateBucket` instances are deferred to Python's
        default handling, so comparing against e.g. `None` yields `False`
        instead of raising `AttributeError`.
        """
        if not isinstance(other, RateBucket):
            return NotImplemented
        return self.count == other.count and self.value == other.value

    # Buckets are mutable, so they stay unhashable. Defining `__eq__` already
    # implies this; it is spelled out here to make it explicit.
    __hash__ = None  # type: ignore[assignment]

    def __repr__(self) -> str:
        return f"{type(self).__name__}(count={self.count!r}, value={self.value!r})"
|
|
|
|
|
|
class RateCounter:
    """
    Rate counter, which stores the amount received in 1ms buckets.
    """

    def __init__(self, window_size: int, scale: int = 8000) -> None:
        """
        :param window_size: The number of 1ms buckets in the sliding window.
        :param scale: The factor applied to the per-millisecond average when
            computing the rate.
        """
        self._origin_index = 0
        self._origin_ms: Optional[int] = None
        self._scale = scale
        self._window_size = window_size
        self.reset()

    def add(self, value: int, now_ms: int) -> None:
        """
        Record `value` in the bucket for `now_ms`.
        """
        if self._origin_ms is None:
            self._origin_ms = now_ms
        else:
            self._erase_old(now_ms)

        # Buckets form a ring; the origin bucket holds the oldest millisecond.
        index = (self._origin_index + now_ms - self._origin_ms) % self._window_size
        self._buckets[index].count += 1
        self._buckets[index].value += value
        self._total.count += 1
        self._total.value += value

    def rate(self, now_ms: int) -> Optional[int]:
        """
        Return the scaled rate over the active window.

        :return: The rate, or `None` when nothing has been added yet, the
            window holds no samples, or the window spans a single millisecond.
        """
        if self._origin_ms is not None:
            self._erase_old(now_ms)
            active_window_size = now_ms - self._origin_ms + 1
            if self._total.count > 0 and active_window_size > 1:
                return round(self._scale * self._total.value / active_window_size)
        return None

    def reset(self) -> None:
        """
        Drop all recorded samples.
        """
        self._buckets = [RateBucket() for i in range(self._window_size)]
        self._origin_index = 0
        self._origin_ms = None
        self._total = RateBucket()

    def _erase_old(self, now_ms: int) -> None:
        """
        Expire the buckets which have fallen out of the window ending at
        `now_ms` and advance the origin accordingly.
        """
        new_origin_ms = now_ms - self._window_size + 1
        elapsed_ms = new_origin_ms - self._origin_ms
        if elapsed_ms <= 0:
            return

        if elapsed_ms >= self._window_size:
            # Every bucket has expired. Clear them in one pass rather than
            # stepping one millisecond at a time, which would cost time
            # proportional to the gap since the previous call.
            for bucket in self._buckets:
                bucket.count = 0
                bucket.value = 0
            self._total.count = 0
            self._total.value = 0
            self._origin_index = (self._origin_index + elapsed_ms) % self._window_size
            self._origin_ms = new_origin_ms
            return

        while self._origin_ms < new_origin_ms:
            bucket = self._buckets[self._origin_index]
            self._total.count -= bucket.count
            self._total.value -= bucket.value
            bucket.count = 0
            bucket.value = 0

            self._origin_index = (self._origin_index + 1) % self._window_size
            self._origin_ms += 1
|
|
|
|
|
|
class RemoteBitrateEstimator:
    """
    Receiver-side bitrate estimator.

    Combines an incoming rate counter, the inter-arrival filter, the overuse
    estimator and detector, and the rate controller.
    """

    def __init__(self) -> None:
        # Incoming bitrate measured over a 1000-bucket window.
        self.incoming_bitrate = RateCounter(1000, 8000)
        self.incoming_bitrate_initialized = True

        # The group length is expressed in timestamp units.
        group_length = (TIMESTAMP_GROUP_LENGTH_MS << INTER_ARRIVAL_SHIFT) // 1000
        self.inter_arrival = InterArrival(group_length, TIMESTAMP_TO_MS)
        self.estimator = OveruseEstimator()
        self.detector = OveruseDetector()
        self.rate_control = AimdRateControl()

        self.last_update_ms: Optional[int] = None
        # Maps each SSRC to the arrival time of its latest packet.
        self.ssrcs: Dict[int, int] = {}

    def add(
        self, arrival_time_ms: int, abs_send_time: int, payload_size: int, ssrc: int
    ) -> Optional[Tuple[int, List[int]]]:
        """
        Process one received packet.

        :return: A `(target_bitrate, ssrcs)` tuple when the rate controller
            produced a new estimate, otherwise `None`.
        """
        timestamp = abs_send_time << 8

        # Remember when this SSRC was last seen.
        self.ssrcs[ssrc] = arrival_time_ms

        # Maintain the incoming bitrate; the counter is reset once when the
        # rate becomes unavailable after having been available.
        counter = self.incoming_bitrate
        if counter.rate(arrival_time_ms) is not None:
            self.incoming_bitrate_initialized = True
        elif self.incoming_bitrate_initialized:
            counter.reset()
            self.incoming_bitrate_initialized = False
        counter.add(payload_size, arrival_time_ms)

        # Feed the estimator and detector whenever a timestamp group completes.
        deltas = self.inter_arrival.compute_deltas(
            timestamp, arrival_time_ms, payload_size
        )
        if deltas is not None:
            timestamp_delta_ms = deltas.timestamp * TIMESTAMP_TO_MS
            self.estimator.update(
                deltas.arrival_time,
                timestamp_delta_ms,
                deltas.size,
                self.detector.state(),
                arrival_time_ms,
            )
            self.detector.detect(
                self.estimator.offset(),
                timestamp_delta_ms,
                self.estimator.num_of_deltas(),
                arrival_time_ms,
            )

        # An estimate is due on the first packet, once the feedback interval
        # has elapsed, or as soon as overuse is detected.
        estimate_due = (
            self.last_update_ms is None
            or (arrival_time_ms - self.last_update_ms)
            > self.rate_control.feedback_interval()
            or self.detector.state() == BandwidthUsage.OVERUSING
        )
        if not estimate_due:
            return None

        target_bitrate = self.rate_control.update(
            self.detector.state(),
            counter.rate(arrival_time_ms),
            arrival_time_ms,
        )
        if target_bitrate is None:
            return None

        self.last_update_ms = arrival_time_ms
        return target_bitrate, list(self.ssrcs)
|
|
|