krishna-k's picture
Upload folder using huggingface_hub
06555b5 verified
raw
history blame contribute delete
13.8 kB
import multiprocessing
import random
from struct import pack, unpack_from
from typing import List, Tuple, Type, TypeVar, cast
from av import VideoFrame
from av.frame import Frame
from av.packet import Packet
from ..jitterbuffer import JitterFrame
from ..mediastreams import VIDEO_CLOCK_RATE, VIDEO_TIME_BASE, convert_timebase
from ._vpx import ffi, lib
from .base import Decoder, Encoder
# Encoder bitrate bounds, in bits per second.
DEFAULT_BITRATE = 500_000  # 500 kbps
MIN_BITRATE = 250_000  # 250 kbps
MAX_BITRATE = 1_500_000  # 1.5 Mbps
# Frame rate used to derive the nominal per-frame duration handed to libvpx.
MAX_FRAME_RATE = 30
# Upper bound, in bytes, on one RTP payload (descriptor + VP8 data).
PACKET_MAX = 1300
# Lets VpxPayloadDescriptor.parse return the subclass it is invoked on.
DESCRIPTOR_T = TypeVar("DESCRIPTOR_T", bound="VpxPayloadDescriptor")
def number_of_threads(pixels: int, cpus: int) -> int:
    """
    Pick the libvpx encoder thread count for a frame size and CPU count.

    :param pixels: Number of pixels in a frame (width * height).
    :param cpus: Number of CPUs available on the machine.
    :return: 8, 3, 2 or 1 threads; larger frames on machines with more CPUs
        get more threads.
    """
    px_1080p = 1920 * 1080
    px_960p = 1280 * 960
    px_480p = 640 * 480

    if cpus > 8 and pixels >= px_1080p:
        return 8
    if cpus >= 6 and pixels > px_960p:
        return 3
    if cpus >= 3 and pixels > px_480p:
        return 2
    return 1
class VpxPayloadDescriptor:
    """
    VP8 RTP payload descriptor (RFC 7741, section 4.2).

    Models the S bit (``partition_start``), the partition index
    (``partition_id``) and the optional extension fields PictureID,
    TL0PICIDX, TID/Y and KEYIDX. The N (non-reference) bit is neither
    emitted nor exposed.
    """

    def __init__(
        self,
        partition_start,
        partition_id,
        picture_id=None,
        tl0picidx=None,
        tid=None,
        keyidx=None,
    ) -> None:
        # S bit (bit 4 of the first octet).
        self.partition_start = partition_start
        # PID: partition index (low bits of the first octet).
        self.partition_id = partition_id
        # PictureID: written as 7 bits when < 128, otherwise as 15 bits.
        self.picture_id = picture_id
        # TL0PICIDX octet, written verbatim when not None.
        self.tl0picidx = tl0picidx
        # (TID, Y) pair packed into the top three bits of the T/K octet.
        self.tid = tid
        # KEYIDX packed into the low five bits of the T/K octet.
        self.keyidx = keyidx

    def __bytes__(self) -> bytes:
        """Serialize the descriptor to its on-the-wire form."""
        octet = (self.partition_start << 4) | self.partition_id

        # Extension octet: I, L, T, K presence flags.
        ext_octet = 0
        if self.picture_id is not None:
            ext_octet |= 1 << 7
        if self.tl0picidx is not None:
            ext_octet |= 1 << 6
        if self.tid is not None:
            ext_octet |= 1 << 5
        if self.keyidx is not None:
            ext_octet |= 1 << 4

        if not ext_octet:
            return pack("!B", octet)

        # X bit set: an extension octet follows.
        data = pack("!BB", (1 << 7) | octet, ext_octet)
        if self.picture_id is not None:
            if self.picture_id < 128:
                data += pack("!B", self.picture_id)
            else:
                # M bit set: 15-bit PictureID. Mask so a value past the
                # 15-bit range wraps instead of overflowing the 16-bit field.
                data += pack("!H", (1 << 15) | (self.picture_id & 0x7FFF))
        if self.tl0picidx is not None:
            data += pack("!B", self.tl0picidx)
        if self.tid is not None or self.keyidx is not None:
            t_k = 0
            if self.tid is not None:
                t_k |= (self.tid[0] << 6) | (self.tid[1] << 5)
            if self.keyidx is not None:
                t_k |= self.keyidx
            data += pack("!B", t_k)
        return data

    def __repr__(self) -> str:
        return (
            f"VpxPayloadDescriptor(S={self.partition_start}, "
            f"PID={self.partition_id}, pic_id={self.picture_id})"
        )

    @classmethod
    def parse(cls: "Type[DESCRIPTOR_T]", data: bytes) -> "Tuple[DESCRIPTOR_T, bytes]":
        """
        Parse a descriptor from the start of an RTP payload.

        :param data: The RTP payload.
        :return: The parsed descriptor and the remaining (VP8) bytes.
        :raises ValueError: If the descriptor is truncated.
        """
        if len(data) < 1:
            raise ValueError("VPX descriptor is too short")

        # first byte: |X|R|N|S|R| PID |
        octet = data[0]
        extended = octet >> 7
        partition_start = (octet >> 4) & 1
        # PID is 3 bits; bit 3 is reserved and must be ignored by receivers.
        partition_id = octet & 0x07

        picture_id = None
        tl0picidx = None
        tid = None
        keyidx = None

        pos = 1
        # extended control bits: |I|L|T|K| RSV |
        if extended:
            if len(data) < pos + 1:
                raise ValueError("VPX descriptor has truncated extended bits")

            octet = data[pos]
            ext_I = (octet >> 7) & 1
            ext_L = (octet >> 6) & 1
            ext_T = (octet >> 5) & 1
            ext_K = (octet >> 4) & 1
            pos += 1

            # picture id: 7 bits, or 15 bits when the M bit is set
            if ext_I:
                if len(data) < pos + 1:
                    raise ValueError("VPX descriptor has truncated PictureID")

                if data[pos] & 0x80:
                    if len(data) < pos + 2:
                        raise ValueError("VPX descriptor has truncated long PictureID")

                    picture_id = unpack_from("!H", data, pos)[0] & 0x7FFF
                    pos += 2
                else:
                    picture_id = data[pos]
                    pos += 1

            # TL0PICIDX: parsed and stored, not otherwise interpreted here
            if ext_L:
                if len(data) < pos + 1:
                    raise ValueError("VPX descriptor has truncated TL0PICIDX")

                tl0picidx = data[pos]
                pos += 1

            # T/K octet: |TID|Y| KEYIDX |, present if either flag is set
            if ext_T or ext_K:
                if len(data) < pos + 1:
                    raise ValueError("VPX descriptor has truncated T/K")

                t_k = data[pos]
                if ext_T:
                    tid = ((t_k >> 6) & 3, (t_k >> 5) & 1)
                if ext_K:
                    keyidx = t_k & 0x1F
                pos += 1

        obj = cls(
            partition_start=partition_start,
            partition_id=partition_id,
            picture_id=picture_id,
            tl0picidx=tl0picidx,
            tid=tid,
            keyidx=keyidx,
        )
        return obj, data[pos:]
def _vpx_assert(err: int) -> None:
    """
    Raise an ``Exception`` carrying libvpx's error text unless ``err`` is OK.

    :param err: A ``vpx_codec_err_t`` status returned by a libvpx call.
    """
    if err == lib.VPX_CODEC_OK:
        return
    message = ffi.string(lib.vpx_codec_err_to_string(err)).decode("utf8")
    raise Exception("libvpx error: " + message)
class Vp8Decoder(Decoder):
    """VP8 video decoder backed by libvpx, producing I420 ``VideoFrame``s."""

    def __init__(self) -> None:
        self.codec = ffi.new("vpx_codec_ctx_t *")
        _vpx_assert(
            lib.vpx_codec_dec_init(self.codec, lib.vpx_codec_vp8_dx(), ffi.NULL, 0)
        )

        # enable libvpx post-processing (demacroblocking + deblocking)
        ppcfg = ffi.new("vp8_postproc_cfg_t *")
        ppcfg.post_proc_flag = lib.VP8_DEMACROBLOCK | lib.VP8_DEBLOCK
        ppcfg.deblocking_level = 3
        lib.vpx_codec_control_(self.codec, lib.VP8_SET_POSTPROC, ppcfg)

    def __del__(self) -> None:
        # __init__ may have raised before self.codec was assigned
        codec = getattr(self, "codec", None)
        if codec:
            lib.vpx_codec_destroy(codec)

    def decode(self, encoded_frame: JitterFrame) -> List[Frame]:
        """
        Decode one encoded VP8 frame.

        :param encoded_frame: The reassembled frame from the jitter buffer.
        :return: The decoded frames, stamped with ``encoded_frame.timestamp``
            in ``VIDEO_TIME_BASE``. If libvpx reports a decoding error the
            frame is dropped and an empty list is returned.
        """
        frames: List[Frame] = []
        result = lib.vpx_codec_decode(
            self.codec,
            encoded_frame.data,
            len(encoded_frame.data),
            ffi.NULL,
            lib.VPX_DL_REALTIME,
        )
        if result != lib.VPX_CODEC_OK:
            # best effort: an undecodable frame is silently dropped
            return frames

        it = ffi.new("vpx_codec_iter_t *")
        while True:
            img = lib.vpx_codec_get_frame(self.codec, it)
            if not img:
                break
            assert img.fmt == lib.VPX_IMG_FMT_I420

            frame = VideoFrame(width=img.d_w, height=img.d_h)
            frame.pts = encoded_frame.timestamp
            frame.time_base = VIDEO_TIME_BASE
            for p in range(3):
                self._copy_plane(img, p, frame)
            frames.append(frame)

        return frames

    @staticmethod
    def _copy_plane(img, p: int, frame: VideoFrame) -> None:
        """Copy plane ``p`` of a libvpx I420 image into the same plane of ``frame``."""
        # I420 chroma planes are halved vertically; round up so the last row
        # of an odd-height image is not dropped.
        rows = img.d_h if p == 0 else (img.d_h + 1) // 2

        i_stride = img.stride[p]
        i_buf = ffi.buffer(img.planes[p], i_stride * rows)
        o_stride = frame.planes[p].line_size
        o_buf = memoryview(cast(bytes, frame.planes[p]))

        # The two sides may pad rows differently; copy the common prefix so
        # neither side is read past its own row.
        width = min(i_stride, o_stride)
        for r in range(rows):
            i_pos = r * i_stride
            o_pos = r * o_stride
            o_buf[o_pos : o_pos + width] = i_buf[i_pos : i_pos + width]
class Vp8Encoder(Encoder):
    """
    VP8 video encoder backed by libvpx.

    Frames are encoded in real-time CBR mode and split into RTP payloads,
    each prefixed with a VP8 payload descriptor. The libvpx context is
    created lazily on the first frame and recreated when the frame size
    changes.
    """

    def __init__(self) -> None:
        self.cx = lib.vpx_codec_vp8_cx()
        self.cfg = ffi.new("vpx_codec_enc_cfg_t *")
        lib.vpx_codec_enc_config_default(self.cx, self.cfg, 0)

        # reusable output buffer for compressed data, grown on demand
        self.buffer = bytearray(8000)
        # libvpx encoder context, created lazily by encode()
        self.codec = None
        # 15-bit PictureID written into every payload descriptor
        self.picture_id = random.randint(0, (1 << 15) - 1)
        # nominal frame duration passed to libvpx, in 1/VIDEO_CLOCK_RATE units
        self.timestamp_increment = VIDEO_CLOCK_RATE // MAX_FRAME_RATE

        self.__target_bitrate = DEFAULT_BITRATE
        self.__update_config_needed = False

    def __del__(self) -> None:
        # __init__ may have raised before self.codec was assigned
        codec = getattr(self, "codec", None)
        if codec:
            lib.vpx_codec_destroy(codec)

    def encode(
        self, frame: Frame, force_keyframe: bool = False
    ) -> Tuple[List[bytes], int]:
        """
        Encode a video frame and split the result into RTP payloads.

        :param frame: The frame to encode; converted to yuv420p if needed.
        :param force_keyframe: Ask libvpx to produce a keyframe.
        :return: The payloads and the frame timestamp in ``VIDEO_TIME_BASE``.
        """
        assert isinstance(frame, VideoFrame)
        if frame.format.name != "yuv420p":
            frame = frame.reformat(format="yuv420p")

        # the libvpx context is configured for one size; recreate on change
        if self.codec and (
            frame.width != self.cfg.g_w or frame.height != self.cfg.g_h
        ):
            lib.vpx_codec_destroy(self.codec)
            self.codec = None

        if not self.codec:
            self._create_codec(frame)
        elif self.__update_config_needed:
            self.__update_config()
            _vpx_assert(lib.vpx_codec_enc_config_set(self.codec, self.cfg))

        # point the image at the frame's planes; no pixel data is copied
        for p in range(3):
            self.image.planes[p] = ffi.cast("void*", frame.planes[p].buffer_ptr)
            self.image.stride[p] = frame.planes[p].line_size

        # cfg.g_timebase is 1/VIDEO_CLOCK_RATE, so libvpx must receive the pts
        # in that unit rather than in the frame's own time base.
        # NOTE: assumes VIDEO_TIME_BASE == 1/VIDEO_CLOCK_RATE (mediastreams).
        timestamp = convert_timebase(frame.pts, frame.time_base, VIDEO_TIME_BASE)

        flags = 0
        if force_keyframe:
            flags |= lib.VPX_EFLAG_FORCE_KF
        _vpx_assert(
            lib.vpx_codec_encode(
                self.codec,
                self.image,
                timestamp,
                self.timestamp_increment,
                flags,
                lib.VPX_DL_REALTIME,
            )
        )

        payloads = self._packetize(self._collect_output(), self.picture_id)
        self.picture_id = (self.picture_id + 1) % (1 << 15)
        return payloads, timestamp

    def _create_codec(self, frame: VideoFrame) -> None:
        """Configure and initialise the libvpx encoder for ``frame``'s size."""
        self.cfg.g_timebase.num = 1
        self.cfg.g_timebase.den = VIDEO_CLOCK_RATE
        self.cfg.g_lag_in_frames = 0
        self.cfg.g_threads = number_of_threads(
            frame.width * frame.height, multiprocessing.cpu_count()
        )
        self.cfg.g_w = frame.width
        self.cfg.g_h = frame.height
        self.cfg.rc_resize_allowed = 0
        self.cfg.rc_end_usage = lib.VPX_CBR
        self.cfg.rc_min_quantizer = 2
        self.cfg.rc_max_quantizer = 56
        self.cfg.rc_undershoot_pct = 100
        self.cfg.rc_overshoot_pct = 15
        self.cfg.rc_buf_initial_sz = 500
        self.cfg.rc_buf_optimal_sz = 600
        self.cfg.rc_buf_sz = 1000
        self.cfg.kf_mode = lib.VPX_KF_AUTO
        self.cfg.kf_max_dist = 3000
        self.__update_config()

        codec = ffi.new("vpx_codec_ctx_t *")
        _vpx_assert(lib.vpx_codec_enc_init(codec, self.cx, self.cfg, 0))
        # Adopt the context only once it is initialised, so a failed init is
        # retried on the next frame instead of reusing a broken context.
        self.codec = codec

        for control, value in (
            (lib.VP8E_SET_NOISE_SENSITIVITY, 4),
            (lib.VP8E_SET_STATIC_THRESHOLD, 1),
            (lib.VP8E_SET_CPUUSED, -6),
            (lib.VP8E_SET_TOKEN_PARTITIONS, lib.VP8_ONE_TOKENPARTITION),
        ):
            # vpx_codec_control_ is variadic, so the argument is cast explicitly
            lib.vpx_codec_control_(self.codec, control, ffi.cast("int", value))

        # create image on a dummy buffer, we will fill the pointers during encoding
        self.image = ffi.new("vpx_image_t *")
        lib.vpx_img_wrap(
            self.image,
            lib.VPX_IMG_FMT_I420,
            frame.width,
            frame.height,
            1,
            ffi.cast("void*", 1),
        )

    def _collect_output(self) -> bytearray:
        """Concatenate the compressed-frame packets produced by the last encode."""
        it = ffi.new("vpx_codec_iter_t *")
        length = 0
        while True:
            pkt = lib.vpx_codec_get_cx_data(self.codec, it)
            if not pkt:
                break
            if pkt.kind != lib.VPX_CODEC_CX_FRAME_PKT:
                continue

            size = pkt.data.frame.sz
            if length + size > len(self.buffer):
                # grow the reusable buffer, keeping the data already collected
                new_buffer = bytearray(length + size)
                new_buffer[:length] = self.buffer[:length]
                self.buffer = new_buffer

            self.buffer[length : length + size] = ffi.buffer(
                pkt.data.frame.buf, size
            )
            length += size
        return self.buffer[:length]

    def pack(self, packet: Packet) -> Tuple[List[bytes], int]:
        """
        Packetize an already-encoded VP8 packet without re-encoding it.

        :return: The payloads and the packet timestamp in ``VIDEO_TIME_BASE``.
        """
        payloads = self._packetize(bytes(packet), self.picture_id)
        timestamp = convert_timebase(packet.pts, packet.time_base, VIDEO_TIME_BASE)
        self.picture_id = (self.picture_id + 1) % (1 << 15)
        return payloads, timestamp

    @property
    def target_bitrate(self) -> int:
        """
        Target bitrate in bits per second.
        """
        return self.__target_bitrate

    @target_bitrate.setter
    def target_bitrate(self, bitrate: int) -> None:
        # clamp to [MIN_BITRATE, MAX_BITRATE]; the encoder picks it up lazily
        bitrate = max(MIN_BITRATE, min(bitrate, MAX_BITRATE))
        if bitrate != self.__target_bitrate:
            self.__target_bitrate = bitrate
            self.__update_config_needed = True

    @classmethod
    def _packetize(cls, buffer: bytes, picture_id: int) -> List[bytes]:
        """
        Split ``buffer`` into payloads of at most PACKET_MAX bytes each.

        Every payload starts with a descriptor carrying ``picture_id``; only
        the first one has the S (partition start) bit set. An empty buffer
        yields no payloads.
        """
        payloads = []
        descr = VpxPayloadDescriptor(
            partition_start=1, partition_id=0, picture_id=picture_id
        )
        length = len(buffer)
        pos = 0
        while pos < length:
            descr_bytes = bytes(descr)
            size = min(length - pos, PACKET_MAX - len(descr_bytes))
            payloads.append(descr_bytes + buffer[pos : pos + size])
            descr.partition_start = 0
            pos += size
        return payloads

    def __update_config(self) -> None:
        # libvpx expects rc_target_bitrate in kilobits per second
        self.cfg.rc_target_bitrate = self.__target_bitrate // 1000
        self.__update_config_needed = False
def vp8_depayload(payload: bytes) -> bytes:
    """
    Strip the VP8 payload descriptor from an RTP payload.

    :param payload: The RTP payload, starting with the descriptor.
    :return: The VP8 bytes that follow the descriptor.
    :raises ValueError: If the descriptor is truncated.
    """
    _descriptor, vp8_data = VpxPayloadDescriptor.parse(payload)
    return vp8_data