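# NOTE: the lines below are page chrome captured from the Hugging Face file
# viewer; they are not part of the module and are kept only as a comment.
# krishna-k's picture
# Upload folder using huggingface_hub
# 06555b5 verified
# raw
# history blame contribute delete
# 21.3 kB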
import math
from enum import Enum
from typing import Dict, List, Optional, Tuple
from aiortc.utils import uint32_add, uint32_gt
# Largest arrival-time gap (ms) for which a packet that arrived "faster than
# it was sent" is still merged into the current group
# (see InterArrival.belongs_to_burst).
BURST_DELTA_THRESHOLD_MS = 5

# overuse detector
# MIN_NUM_DELTAS caps the delta count that scales the offset in
# OveruseDetector.detect; MAX_ADAPT_OFFSET_MS is the margin above the current
# threshold beyond which OveruseDetector.update_threshold skips adaptation.
MIN_NUM_DELTAS = 60
MAX_ADAPT_OFFSET_MS = 15

# overuse estimator
# Upper bound for the estimator's delta counter, and number of timestamp
# deltas remembered when tracking the minimum frame period.
MIN_FRAME_PERIOD_HISTORY_LENGTH = 60
DELTA_COUNTER_MAX = 1000

# abs-send-time estimator
# Send timestamps handled by InterArrival tick 2**INTER_ARRIVAL_SHIFT times
# per second; TIMESTAMP_TO_MS converts such ticks to milliseconds.
TIMESTAMP_GROUP_LENGTH_MS = 5
INTER_ARRIVAL_SHIFT = 26
TIMESTAMP_TO_MS = 1000.0 / 2**INTER_ARRIVAL_SHIFT
class BandwidthUsage(Enum):
    """
    Hypothesis about how the link is being used.

    Produced by `OveruseDetector.detect` and consumed by
    `AimdRateControl.update`.
    """

    # Neither threshold is exceeded.
    NORMAL = 0
    # The scaled delay offset is below the negative threshold.
    UNDERUSING = 1
    # The scaled delay offset stayed above the threshold long enough.
    OVERUSING = 2
class RateControlState(Enum):
    """
    State of the `AimdRateControl` state machine.
    """

    # Keep the current bitrate.
    HOLD = 0
    # Raise the bitrate (additively or multiplicatively).
    INCREASE = 1
    # Drop the bitrate below the measured throughput.
    DECREASE = 2
class AimdRateControl:
    """
    Additive-increase / multiplicative-decrease controller for the target
    bitrate.

    `update()` is fed the detector's bandwidth-usage hypothesis together with
    the measured incoming throughput and moves between the HOLD, INCREASE and
    DECREASE states, adjusting `current_bitrate` accordingly.
    """

    def __init__(self) -> None:
        # Running mean and normalised variance (in kbps) of the throughput
        # seen at the moments the rate had to be decreased.
        self.avg_max_bitrate_kbps: Optional[float] = None
        self.var_max_bitrate_kbps = 0.4
        self.current_bitrate = 30000000
        self.current_bitrate_initialized = False
        self.first_estimated_throughput_time: Optional[int] = None
        self.last_change_ms: Optional[int] = None
        # True once a decrease happened, i.e. we assume we are close to the
        # maximum throughput and should only increase additively.
        self.near_max = False
        self.latest_estimated_throughput = 30000000
        self.rtt = 200
        self.state = RateControlState.HOLD

    def feedback_interval(self) -> int:
        """
        Return the interval between periodic estimate updates.

        `RemoteBitrateEstimator.add` compares it against a difference of
        arrival times expressed in milliseconds.
        """
        return 500

    def set_estimate(self, bitrate: int, now_ms: int) -> None:
        """
        For testing purposes.
        """
        self.current_bitrate = self._clamp_bitrate(bitrate, bitrate)
        self.current_bitrate_initialized = True
        self.last_change_ms = now_ms

    def update(
        self,
        bandwidth_usage: BandwidthUsage,
        estimated_throughput: Optional[int],
        now_ms: int,
    ) -> Optional[int]:
        """
        Advance the state machine and return the new target bitrate.

        :param bandwidth_usage: Current hypothesis of the overuse detector.
        :param estimated_throughput: Measured incoming throughput, or `None`
            if no measurement is available; the last known value is then used.
        :param now_ms: Current time in milliseconds.
        :return: The updated bitrate, or `None` while the controller is still
            waiting for its initial throughput measurement and no overuse has
            been signalled.
        """
        if not self.current_bitrate_initialized and estimated_throughput is not None:
            if self.first_estimated_throughput_time is None:
                self.first_estimated_throughput_time = now_ms
            elif now_ms - self.first_estimated_throughput_time > 3000:
                # throughput has been measurable for more than 3 seconds,
                # adopt it as the starting point
                self.current_bitrate = estimated_throughput
                self.current_bitrate_initialized = True

        # wait for initialisation or overuse
        if (
            not self.current_bitrate_initialized
            and bandwidth_usage != BandwidthUsage.OVERUSING
        ):
            return None

        # update state
        if (
            bandwidth_usage == BandwidthUsage.NORMAL
            and self.state == RateControlState.HOLD
        ):
            self.last_change_ms = now_ms
            self.state = RateControlState.INCREASE
        elif bandwidth_usage == BandwidthUsage.OVERUSING:
            self.state = RateControlState.DECREASE
        elif bandwidth_usage == BandwidthUsage.UNDERUSING:
            self.state = RateControlState.HOLD

        # helper variables
        new_bitrate = self.current_bitrate
        if estimated_throughput is not None:
            self.latest_estimated_throughput = estimated_throughput
        else:
            estimated_throughput = self.latest_estimated_throughput
        estimated_throughput_kbps = estimated_throughput / 1000

        # update bitrate
        if self.state == RateControlState.INCREASE:
            # if the estimated throughput increases significantly,
            # clear estimated max throughput
            if self.avg_max_bitrate_kbps is not None:
                sigma_kbps = math.sqrt(
                    self.var_max_bitrate_kbps * self.avg_max_bitrate_kbps
                )
                if (
                    estimated_throughput_kbps
                    >= self.avg_max_bitrate_kbps + 3 * sigma_kbps
                ):
                    self.near_max = False
                    self.avg_max_bitrate_kbps = None

            # we use additive or multiplicative rate increase depending on whether
            # we are close to the maximum throughput
            if self.near_max:
                new_bitrate += self._additive_rate_increase(self.last_change_ms, now_ms)
            else:
                new_bitrate += self._multiplicative_rate_increase(
                    new_bitrate, self.last_change_ms, now_ms
                )

            self.last_change_ms = now_ms
        elif self.state == RateControlState.DECREASE:
            # if the estimated throughput drops significantly,
            # clear estimated max throughput
            if self.avg_max_bitrate_kbps is not None:
                sigma_kbps = math.sqrt(
                    self.var_max_bitrate_kbps * self.avg_max_bitrate_kbps
                )
                if (
                    estimated_throughput_kbps
                    < self.avg_max_bitrate_kbps - 3 * sigma_kbps
                ):
                    self.avg_max_bitrate_kbps = None
            self._update_max_throughput_estimate(estimated_throughput_kbps)

            self.near_max = True
            new_bitrate = round(0.85 * estimated_throughput)

            self.last_change_ms = now_ms
            # a decrease is a one-shot action, go back to holding
            self.state = RateControlState.HOLD

        self.current_bitrate = self._clamp_bitrate(new_bitrate, estimated_throughput)
        return self.current_bitrate

    def _additive_rate_increase(self, last_ms: int, now_ms: int) -> int:
        """
        Return the additive increase for the time elapsed since `last_ms`.
        """
        return int((now_ms - last_ms) * self._near_max_rate_increase() / 1000)

    def _clamp_bitrate(self, new_bitrate: int, estimated_throughput: int) -> int:
        """
        Limit `new_bitrate` to 1.5 times the throughput plus 10000, but never
        force it below the bitrate we are already using.
        """
        max_bitrate = max(int(1.5 * estimated_throughput) + 10000, self.current_bitrate)
        return min(new_bitrate, max_bitrate)

    def _multiplicative_rate_increase(
        self, new_bitrate: int, last_ms: Optional[int], now_ms: int
    ) -> int:
        """
        Return the multiplicative increase: 8% per second, scaled down to the
        time elapsed since `last_ms` (capped at one second), and at least 1000.

        When `last_ms` is `None` a full second is assumed.
        """
        alpha = 1.08
        if last_ms is not None:
            elapsed_ms = min(now_ms - last_ms, 1000)
            alpha = pow(alpha, elapsed_ms / 1000)
        return int(max((alpha - 1) * new_bitrate, 1000))

    def _near_max_rate_increase(self) -> int:
        """
        Return the per-second increase used when close to the maximum
        throughput: roughly one average packet per response time, and at
        least 4000.
        """
        # NOTE(review): 30 and 1200 look like an assumed frame rate and
        # packet size in bytes -- confirm against the reference algorithm.
        bits_per_frame = self.current_bitrate / 30
        # A zero bitrate yields ceil(0.0) == 0 packets; clamp to one packet so
        # the division below cannot raise ZeroDivisionError.
        packets_per_frame = max(1, math.ceil(bits_per_frame / (8 * 1200)))
        avg_packet_size_bits = bits_per_frame / packets_per_frame

        response_time = self.rtt + 100
        return max(4000, int((avg_packet_size_bits * 1000) / response_time))

    def _update_max_throughput_estimate(
        self, estimated_throughput_kbps: float
    ) -> None:
        """
        Fold a throughput sample (in kbps) into the exponentially-weighted
        mean and normalised variance of the maximum throughput.
        """
        alpha = 0.05
        if self.avg_max_bitrate_kbps is None:
            self.avg_max_bitrate_kbps = estimated_throughput_kbps
        else:
            self.avg_max_bitrate_kbps = (
                1 - alpha
            ) * self.avg_max_bitrate_kbps + alpha * estimated_throughput_kbps

        # the variance is normalised by the mean, which is floored at 1
        norm = max(1, self.avg_max_bitrate_kbps)
        self.var_max_bitrate_kbps = (1 - alpha) * self.var_max_bitrate_kbps + alpha * (
            (self.avg_max_bitrate_kbps - estimated_throughput_kbps) ** 2
        ) / norm
        self.var_max_bitrate_kbps = max(0.4, min(self.var_max_bitrate_kbps, 2.5))
class TimestampGroup:
    """
    Running totals for the packets collected into one timestamp group.

    `InterArrival.compute_deltas` fills in `arrival_time`, `last_timestamp`
    and `size` as packets are added to the group.
    """

    def __init__(self, timestamp: Optional[int] = None) -> None:
        self.arrival_time: Optional[int] = None
        # a fresh group starts and ends at the same send timestamp
        self.first_timestamp = self.last_timestamp = timestamp
        self.size = 0
class InterArrivalDelta:
    """
    Differences in send timestamp, arrival time and size between two
    consecutive timestamp groups, as computed by
    `InterArrival.compute_deltas`.
    """

    def __init__(self, timestamp: int, arrival_time: int, size: int) -> None:
        self.timestamp, self.arrival_time, self.size = timestamp, arrival_time, size
class InterArrival:
    """
    Collects packets into timestamp groups and reports the differences
    between consecutive groups.

    Adapted from the webrtc.org codebase.
    """

    def __init__(self, group_length: int, timestamp_to_ms: float) -> None:
        # group_length is expressed in timestamp units, timestamp_to_ms
        # converts those units to milliseconds
        self.group_length = group_length
        self.timestamp_to_ms = timestamp_to_ms
        self.current_group: Optional[TimestampGroup] = None
        self.previous_group: Optional[TimestampGroup] = None

    def compute_deltas(
        self, timestamp: int, arrival_time: int, packet_size: int
    ) -> Optional[InterArrivalDelta]:
        """
        Account for a received packet.

        :return: The deltas between the two most recently completed groups
            when this packet opens a new group and a previous group exists,
            otherwise `None`.
        """
        result = None
        group = self.current_group

        if group is None:
            # very first packet
            self.current_group = TimestampGroup(timestamp)
        elif self.packet_out_of_order(timestamp):
            # reordered packets are dropped without touching the groups
            return result
        elif self.new_timestamp_group(timestamp, arrival_time):
            older = self.previous_group
            if older is not None:
                result = InterArrivalDelta(
                    timestamp=uint32_add(group.last_timestamp, -older.last_timestamp),
                    arrival_time=group.arrival_time - older.arrival_time,
                    size=group.size - older.size,
                )

            # the finished group becomes the reference, the packet starts a
            # new one
            self.previous_group = group
            self.current_group = TimestampGroup(timestamp=timestamp)
        elif uint32_gt(timestamp, group.last_timestamp):
            group.last_timestamp = timestamp

        # the packet always counts towards whichever group is now current
        self.current_group.size += packet_size
        self.current_group.arrival_time = arrival_time

        return result

    def belongs_to_burst(self, timestamp: int, arrival_time: int) -> bool:
        """
        Tell whether the packet is part of the same burst as the current
        group: either it was sent at the same time (rounded to a millisecond)
        or it arrived sooner than its send spacing and within
        `BURST_DELTA_THRESHOLD_MS` of the group's last arrival.
        """
        group = self.current_group
        send_delta = uint32_add(timestamp, -group.last_timestamp)
        send_delta_ms = round(self.timestamp_to_ms * send_delta)
        arrival_delta = arrival_time - group.arrival_time

        if send_delta_ms == 0:
            return True
        return (
            arrival_delta - send_delta_ms < 0
            and arrival_delta <= BURST_DELTA_THRESHOLD_MS
        )

    def new_timestamp_group(self, timestamp: int, arrival_time: int) -> bool:
        """
        Tell whether the packet starts a new group, i.e. it is not part of
        the current burst and was sent more than `group_length` after the
        first packet of the current group.
        """
        if self.belongs_to_burst(timestamp, arrival_time):
            return False

        since_first = uint32_add(timestamp, -self.current_group.first_timestamp)
        return since_first > self.group_length

    def packet_out_of_order(self, timestamp: int) -> bool:
        """
        Tell whether the packet was sent before the first packet of the
        current group.
        """
        # presumably uint32_add wraps modulo 2**32, so a difference in the
        # upper half of the range means "earlier" -- verify in aiortc.utils
        since_first = uint32_add(timestamp, -self.current_group.first_timestamp)
        return since_first >= 0x80000000
class OveruseDetector:
    """
    Turns the estimated delay offset into a bandwidth usage hypothesis by
    comparing it against an adaptive threshold.

    Adapted from the webrtc.org codebase.
    """

    def __init__(self) -> None:
        self.hypothesis = BandwidthUsage.NORMAL
        self.last_update_ms: Optional[int] = None
        # threshold adaptation gains: k_up is applied when the offset is at
        # or beyond the threshold, k_down when it is inside
        self.k_up = 0.0087
        self.k_down = 0.039
        self.overuse_counter = 0
        self.overuse_time: Optional[float] = None
        self.overuse_time_threshold = 10
        self.previous_offset = 0.0
        self.threshold = 12.5

    def detect(
        self, offset: float, timestamp_delta_ms: float, num_of_deltas: int, now_ms: int
    ) -> BandwidthUsage:
        """
        Update and return the hypothesis for a new offset estimate.

        With fewer than two deltas `NORMAL` is returned and no state is
        touched.
        """
        if num_of_deltas < 2:
            return BandwidthUsage.NORMAL

        # scale the offset by the number of deltas, capped at MIN_NUM_DELTAS
        T = min(num_of_deltas, MIN_NUM_DELTAS) * offset

        if T > self.threshold:
            # accumulate how long and how often we have been above threshold
            if self.overuse_time is None:
                self.overuse_time = timestamp_delta_ms / 2
            else:
                self.overuse_time += timestamp_delta_ms
            self.overuse_counter += 1

            sustained = (
                self.overuse_time > self.overuse_time_threshold
                and self.overuse_counter > 1
            )
            # only signal overuse while the offset is not shrinking
            if sustained and offset >= self.previous_offset:
                self.overuse_counter = 0
                self.overuse_time = 0
                self.hypothesis = BandwidthUsage.OVERUSING
        else:
            below = T < -self.threshold
            self.overuse_counter = 0
            self.overuse_time = None
            self.hypothesis = (
                BandwidthUsage.UNDERUSING if below else BandwidthUsage.NORMAL
            )

        self.previous_offset = offset
        self.update_threshold(T, now_ms)
        return self.hypothesis

    def state(self) -> BandwidthUsage:
        """
        Return the current hypothesis.
        """
        return self.hypothesis

    def update_threshold(self, modified_offset: float, now_ms: int) -> None:
        """
        Move the threshold towards the magnitude of `modified_offset`.

        Offsets exceeding the threshold by more than `MAX_ADAPT_OFFSET_MS`
        are ignored, apart from refreshing the update time.
        """
        if self.last_update_ms is None:
            self.last_update_ms = now_ms

        magnitude = abs(modified_offset)
        if magnitude > self.threshold + MAX_ADAPT_OFFSET_MS:
            self.last_update_ms = now_ms
            return

        gain = self.k_down if magnitude < self.threshold else self.k_up
        # a single step never accounts for more than 100 ms
        elapsed_ms = min(now_ms - self.last_update_ms, 100)
        self.threshold += gain * (magnitude - self.threshold) * elapsed_ms
        self.threshold = max(6, min(self.threshold, 600))
        self.last_update_ms = now_ms
class OveruseEstimator:
    """
    Kalman filter tracking the slope and offset of the relation between
    group size differences and inter-group delay variations.

    Adapted from the webrtc.org codebase.
    """

    def __init__(self) -> None:
        # 2x2 estimate covariance for the (slope, offset) state
        self.E = [[100.0, 0.0], [0.0, 0.1]]
        self._num_of_deltas = 0
        self._offset = 0.0
        self.previous_offset = 0.0
        self.slope = 1 / 64
        self.ts_delta_hist: List[float] = []
        self.avg_noise = 0.0
        self.var_noise = 50.0
        self.process_noise = [1e-13, 1e-3]

    def num_of_deltas(self) -> int:
        """
        Return the number of updates so far, capped at `DELTA_COUNTER_MAX`.
        """
        return self._num_of_deltas

    def offset(self) -> float:
        """
        Return the current offset estimate.
        """
        return self._offset

    def update(
        self,
        time_delta_ms: int,
        timestamp_delta_ms: float,
        size_delta: int,
        current_hypothesis: BandwidthUsage,
        now_ms: int,
    ):
        """
        Feed one set of inter-group deltas into the filter.

        :param time_delta_ms: Difference in arrival time between the groups.
        :param timestamp_delta_ms: Difference in send time between the groups.
        :param size_delta: Difference in size between the groups.
        :param current_hypothesis: Hypothesis currently held by the detector.
        :param now_ms: Unused.
        """
        min_frame_period = self.update_min_frame_period(timestamp_delta_ms)
        delay_variation = time_delta_ms - timestamp_delta_ms
        self._num_of_deltas = min(self._num_of_deltas + 1, DELTA_COUNTER_MAX)

        # update Kalman filter
        cov = self.E
        cov[0][0] += self.process_noise[0]
        cov[1][1] += self.process_noise[1]

        # the offset is moving against the hypothesis: let it adapt faster
        offset_disagrees = (
            current_hypothesis == BandwidthUsage.OVERUSING
            and self._offset < self.previous_offset
        ) or (
            current_hypothesis == BandwidthUsage.UNDERUSING
            and self._offset > self.previous_offset
        )
        if offset_disagrees:
            cov[1][1] += 10 * self.process_noise[1]

        # observation vector h = [size delta, 1] and the product E * h
        h0 = size_delta
        h1 = 1.0
        eh0 = cov[0][0] * h0 + cov[0][1] * h1
        eh1 = cov[1][0] * h0 + cov[1][1] * h1

        # update noise estimate
        residual = delay_variation - self.slope * h0 - self._offset
        if current_hypothesis == BandwidthUsage.NORMAL:
            # residuals are clipped to 3 standard deviations
            max_residual = 3.0 * math.sqrt(self.var_noise)
            if abs(residual) < max_residual:
                clipped = residual
            elif residual < 0:
                clipped = -max_residual
            else:
                clipped = max_residual
            self.update_noise_estimate(clipped, min_frame_period)

        # Kalman gain and (I - K * h)
        denom = self.var_noise + h0 * eh0 + h1 * eh1
        k0 = eh0 / denom
        k1 = eh1 / denom
        ikh00 = 1.0 - k0 * h0
        ikh01 = -k0 * h1
        ikh10 = -k1 * h0
        ikh11 = 1.0 - k1 * h1

        # update state; every new entry is computed from the old covariance
        e00 = cov[0][0]
        e01 = cov[0][1]
        e10 = cov[1][0]
        e11 = cov[1][1]
        cov[0][0] = e00 * ikh00 + e10 * ikh01
        cov[0][1] = e01 * ikh00 + e11 * ikh01
        cov[1][0] = e00 * ikh10 + e10 * ikh11
        cov[1][1] = e01 * ikh10 + e11 * ikh11

        self.previous_offset = self._offset
        self.slope += k0 * residual
        self._offset += k1 * residual

    def update_min_frame_period(self, ts_delta: float) -> float:
        """
        Record `ts_delta` and return the smallest delta among it and the
        retained history (at most `MIN_FRAME_PERIOD_HISTORY_LENGTH` entries).
        """
        history = self.ts_delta_hist
        # make room for the new sample before scanning
        if len(history) >= MIN_FRAME_PERIOD_HISTORY_LENGTH:
            history.pop(0)

        smallest = ts_delta
        for previous in history:
            smallest = min(previous, smallest)

        history.append(ts_delta)
        return smallest

    def update_noise_estimate(self, residual: float, ts_delta: float) -> None:
        """
        Update the exponentially-weighted mean and variance of the residual.

        The variance is never allowed to drop below 1.
        """
        # adapt more slowly once enough deltas have been seen
        alpha = 0.002 if self._num_of_deltas > 10 * 30 else 0.01

        # scale the smoothing factor by the duration the sample covers
        beta = pow(1 - alpha, ts_delta * 30.0 / 1000.0)
        self.avg_noise = beta * self.avg_noise + (1 - beta) * residual
        self.var_noise = (
            beta * self.var_noise + (1 - beta) * (self.avg_noise - residual) ** 2
        )
        self.var_noise = max(self.var_noise, 1)
class RateBucket:
    """
    Number of samples and their summed value for one slot of a `RateCounter`.
    """

    def __init__(self, count: int = 0, value: int = 0) -> None:
        self.count = count
        self.value = value

    def __eq__(self, other: object) -> bool:
        # Defer to the other operand for foreign types instead of failing
        # with AttributeError on e.g. `bucket == None`.
        if not isinstance(other, RateBucket):
            return NotImplemented
        return self.count == other.count and self.value == other.value

    # Buckets are mutated in place and compare by value, so they must not be
    # hashable (this is what defining __eq__ implies; made explicit here).
    __hash__ = None  # type: ignore[assignment]

    def __repr__(self) -> str:
        return f"RateBucket(count={self.count}, value={self.value})"
class RateCounter:
    """
    Rate counter, which stores the amount received in 1ms buckets.

    The buckets form a ring covering the last `window_size` milliseconds;
    `rate()` scales the windowed total by `scale`.
    """

    def __init__(self, window_size: int, scale: int = 8000) -> None:
        self._origin_index = 0
        self._origin_ms: Optional[int] = None
        self._scale = scale
        self._window_size = window_size
        self.reset()

    def add(self, value: int, now_ms: int) -> None:
        """
        Record `value` as received at time `now_ms`.
        """
        if self._origin_ms is None:
            self._origin_ms = now_ms
        else:
            self._erase_old(now_ms)

        index = (self._origin_index + now_ms - self._origin_ms) % self._window_size
        self._buckets[index].count += 1
        self._buckets[index].value += value
        self._total.count += 1
        self._total.value += value

    def rate(self, now_ms: int) -> Optional[int]:
        """
        Return the rate over the active window ending at `now_ms`.

        `None` is returned when nothing has been recorded, when the window
        holds no samples or when it spans a single millisecond.
        """
        if self._origin_ms is not None:
            self._erase_old(now_ms)
            active_window_size = now_ms - self._origin_ms + 1
            if self._total.count > 0 and active_window_size > 1:
                return round(self._scale * self._total.value / active_window_size)
        return None

    def reset(self) -> None:
        """
        Forget all recorded values.
        """
        self._buckets = [RateBucket() for _ in range(self._window_size)]
        self._origin_index = 0
        self._origin_ms = None
        self._total = RateBucket()

    def _erase_old(self, now_ms: int) -> None:
        """
        Drop the buckets which fell out of the window ending at `now_ms`.
        """
        new_origin_ms = now_ms - self._window_size + 1
        expired = new_origin_ms - self._origin_ms
        if expired <= 0:
            return

        # The ring only has `window_size` slots, so after one full lap every
        # bucket is empty; never walk further than that, however long the gap.
        for step in range(min(expired, self._window_size)):
            bucket = self._buckets[(self._origin_index + step) % self._window_size]
            self._total.count -= bucket.count
            self._total.value -= bucket.value
            bucket.count = 0
            bucket.value = 0

        self._origin_index = (self._origin_index + expired) % self._window_size
        self._origin_ms = new_origin_ms
class RemoteBitrateEstimator:
    """
    Receiver-side bitrate estimator.

    Each received packet is fed through the incoming rate counter, the
    inter-arrival filter, the overuse estimator and the overuse detector;
    the resulting hypothesis drives the AIMD rate control.
    """

    def __init__(self) -> None:
        # incoming rate over a 1000-slot window, scaled by 8000
        self.incoming_bitrate = RateCounter(1000, 8000)
        self.incoming_bitrate_initialized = True
        self.inter_arrival = InterArrival(
            (TIMESTAMP_GROUP_LENGTH_MS << INTER_ARRIVAL_SHIFT) // 1000, TIMESTAMP_TO_MS
        )
        self.estimator = OveruseEstimator()
        self.detector = OveruseDetector()
        self.rate_control = AimdRateControl()
        self.last_update_ms: Optional[int] = None
        # SSRC -> arrival time of the last packet seen for it
        self.ssrcs: Dict[int, int] = {}

    def add(
        self, arrival_time_ms: int, abs_send_time: int, payload_size: int, ssrc: int
    ) -> Optional[Tuple[int, List[int]]]:
        """
        Account for a received packet.

        :return: A `(target_bitrate, ssrcs)` tuple when the rate control
            produced a new estimate, otherwise `None`.
        """
        # presumably the 24-bit abs-send-time header extension value; the
        # shift brings it to the 2**INTER_ARRIVAL_SHIFT ticks-per-second scale
        timestamp = abs_send_time << 8

        # make note of SSRC
        self.ssrcs[ssrc] = arrival_time_ms

        # update incoming bitrate, starting over if the rate became
        # unavailable after having been available
        rate_counter = self.incoming_bitrate
        if rate_counter.rate(arrival_time_ms) is not None:
            self.incoming_bitrate_initialized = True
        elif self.incoming_bitrate_initialized:
            rate_counter.reset()
            self.incoming_bitrate_initialized = False
        rate_counter.add(payload_size, arrival_time_ms)

        # calculate inter-arrival deltas
        deltas = self.inter_arrival.compute_deltas(
            timestamp, arrival_time_ms, payload_size
        )
        if deltas is not None:
            send_delta_ms = deltas.timestamp * TIMESTAMP_TO_MS
            self.estimator.update(
                deltas.arrival_time,
                send_delta_ms,
                deltas.size,
                self.detector.state(),
                arrival_time_ms,
            )
            self.detector.detect(
                self.estimator.offset(),
                send_delta_ms,
                self.estimator.num_of_deltas(),
                arrival_time_ms,
            )

        # the estimate is refreshed periodically, or at once on overuse
        feedback_due = (
            self.last_update_ms is None
            or (arrival_time_ms - self.last_update_ms)
            > self.rate_control.feedback_interval()
        )
        if not feedback_due and self.detector.state() != BandwidthUsage.OVERUSING:
            return None

        target_bitrate = self.rate_control.update(
            self.detector.state(),
            rate_counter.rate(arrival_time_ms),
            arrival_time_ms,
        )
        if target_bitrate is None:
            return None

        self.last_update_ms = arrival_time_ms
        return target_bitrate, list(self.ssrcs)