import copy
import glob
import os
import os.path as osp
import random
from functools import lru_cache

import cv2
import decord
import numpy as np
import skvideo.io
import torch
import torchvision
from decord import VideoReader, cpu, gpu
from tqdm import tqdm

random.seed(42)

decord.bridge.set_bridge("torch")
def get_spatial_fragments(
    video,
    fragments_h=7,
    fragments_w=7,
    fsize_h=32,
    fsize_w=32,
    aligned=32,
    nfrags=1,
    random=False,
    random_upsample=False,
    fallback_type="upsample",
    upsample=-1,
    **kwargs,
):
    if upsample > 0:
        old_h, old_w = video.shape[-2], video.shape[-1]
        if old_h >= old_w:
            w = upsample
            h = int(upsample * old_h / old_w)
        else:
            h = upsample
            w = int(upsample * old_w / old_h)
        video = get_resized_video(video, h, w)

    size_h = fragments_h * fsize_h
    size_w = fragments_w * fsize_w
    ## video: [C,T,H,W]
    ## situation for images
    if video.shape[1] == 1:
        aligned = 1

    dur_t, res_h, res_w = video.shape[-3:]
    ratio = min(res_h / size_h, res_w / size_w)
    if fallback_type == "upsample" and ratio < 1:
        ovideo = video
        video = torch.nn.functional.interpolate(
            video / 255.0, scale_factor=1 / ratio, mode="bilinear"
        )
        video = (video * 255.0).type_as(ovideo)
    if random_upsample:
        ovideo = video
        # use numpy's RNG here: the boolean `random` argument above shadows the stdlib module
        randratio = np.random.random() * 0.5 + 1
        video = torch.nn.functional.interpolate(
            video / 255.0, scale_factor=randratio, mode="bilinear"
        )
        video = (video * 255.0).type_as(ovideo)

    assert dur_t % aligned == 0, "Please provide a clip whose frame count is divisible by `aligned`."
    size = size_h, size_w

    ## make sure that sampling will not run out of the picture
    hgrids = torch.LongTensor(
        [min(res_h // fragments_h * i, res_h - fsize_h) for i in range(fragments_h)]
    )
    wgrids = torch.LongTensor(
        [min(res_w // fragments_w * i, res_w - fsize_w) for i in range(fragments_w)]
    )
    hlength, wlength = res_h // fragments_h, res_w // fragments_w

    if random:
print("This part is deprecated. Please remind that.") | |
        if res_h > fsize_h:
            rnd_h = torch.randint(
                res_h - fsize_h, (len(hgrids), len(wgrids), dur_t // aligned)
            )
        else:
            rnd_h = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int()
        if res_w > fsize_w:
            rnd_w = torch.randint(
                res_w - fsize_w, (len(hgrids), len(wgrids), dur_t // aligned)
            )
        else:
            rnd_w = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int()
    else:
        if hlength > fsize_h:
            rnd_h = torch.randint(
                hlength - fsize_h, (len(hgrids), len(wgrids), dur_t // aligned)
            )
        else:
            rnd_h = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int()
        if wlength > fsize_w:
            rnd_w = torch.randint(
                wlength - fsize_w, (len(hgrids), len(wgrids), dur_t // aligned)
            )
        else:
            rnd_w = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int()

    target_video = torch.zeros(video.shape[:-2] + size).to(video.device)
    # target_videos = []

    for i, hs in enumerate(hgrids):
        for j, ws in enumerate(wgrids):
            for t in range(dur_t // aligned):
                t_s, t_e = t * aligned, (t + 1) * aligned
                h_s, h_e = i * fsize_h, (i + 1) * fsize_h
                w_s, w_e = j * fsize_w, (j + 1) * fsize_w
                if random:
                    h_so, h_eo = rnd_h[i][j][t], rnd_h[i][j][t] + fsize_h
                    w_so, w_eo = rnd_w[i][j][t], rnd_w[i][j][t] + fsize_w
                else:
                    h_so, h_eo = hs + rnd_h[i][j][t], hs + rnd_h[i][j][t] + fsize_h
                    w_so, w_eo = ws + rnd_w[i][j][t], ws + rnd_w[i][j][t] + fsize_w
                target_video[:, t_s:t_e, h_s:h_e, w_s:w_e] = video[
                    :, t_s:t_e, h_so:h_eo, w_so:w_eo
                ]
    # target_videos.append(video[:,t_s:t_e,h_so:h_eo,w_so:w_eo])
    # target_video = torch.stack(target_videos, 0).reshape((dur_t // aligned, fragments, fragments,) + target_videos[0].shape).permute(3,0,4,1,5,2,6)
    # target_video = target_video.reshape((-1, dur_t,) + size)  ## Splicing Fragments
    return target_video
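
# Hedged usage sketch (synthetic tensor, not from the original repo): with the
# defaults above, a [C, T, H, W] clip is cut into a 7x7 grid of 32x32 patches
# and spliced into a 224x224 mosaic per frame, e.g.:
#
#   frames = torch.randint(0, 256, (3, 32, 720, 1280), dtype=torch.uint8)
#   frags = get_spatial_fragments(frames)  # expected shape: [3, 32, 224, 224]
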
def get_resize_function(size_h, size_w, target_ratio=1, random_crop=False):
    if random_crop:
        return torchvision.transforms.RandomResizedCrop(
            (size_h, size_w), scale=(0.40, 1.0)
        )
    if target_ratio > 1:
        size_h = int(target_ratio * size_w)
        assert size_h > size_w
    elif target_ratio < 1:
        size_w = int(size_h / target_ratio)
        assert size_w > size_h
    return torchvision.transforms.Resize((size_h, size_w))
def get_resized_video(
    video, size_h=224, size_w=224, random_crop=False, arp=False, **kwargs,
):
    video = video.permute(1, 0, 2, 3)
    resize_opt = get_resize_function(
        size_h, size_w, video.shape[-2] / video.shape[-1] if arp else 1, random_crop
    )
    video = resize_opt(video).permute(1, 0, 2, 3)
    return video
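
# Hedged note (illustrative values): the permutes move the clip to [T, C, H, W]
# so the torchvision transform treats the temporal axis as a batch dimension,
# then restore [C, T, H, W], e.g.:
#
#   clip = torch.randint(0, 256, (3, 16, 540, 960), dtype=torch.uint8)
#   resized = get_resized_video(clip, 224, 224)  # expected shape: [3, 16, 224, 224]
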
def get_arp_resized_video(
    video, short_edge=224, train=False, **kwargs,
):
    if train:  ## during training, randomly crop into a square and then resize
        res_h, res_w = video.shape[-2:]
        ori_short_edge = min(video.shape[-2:])
        if res_h > ori_short_edge:
            rnd_h = random.randrange(res_h - ori_short_edge)
            video = video[..., rnd_h : rnd_h + ori_short_edge, :]
        elif res_w > ori_short_edge:
            rnd_w = random.randrange(res_w - ori_short_edge)
            video = video[..., :, rnd_w : rnd_w + ori_short_edge]
    ori_short_edge = min(video.shape[-2:])
    scale_factor = short_edge / ori_short_edge
    ovideo = video
    video = torch.nn.functional.interpolate(
        video / 255.0, scale_factor=scale_factor, mode="bilinear"
    )
    video = (video * 255.0).type_as(ovideo)
    return video
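
# Hedged example (synthetic tensor): the resize preserves aspect ratio, scaling
# the shorter spatial edge to roughly `short_edge`:
#
#   clip = torch.randint(0, 256, (3, 16, 540, 960), dtype=torch.uint8)
#   out = get_arp_resized_video(clip, short_edge=224)  # shorter edge becomes ~224
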
def get_arp_fragment_video(
    video, short_fragments=7, fsize=32, train=False, **kwargs,
):
    if train:  ## during training, randomly crop into a square and then get fragments
        res_h, res_w = video.shape[-2:]
        ori_short_edge = min(video.shape[-2:])
        if res_h > ori_short_edge:
            rnd_h = random.randrange(res_h - ori_short_edge)
            video = video[..., rnd_h : rnd_h + ori_short_edge, :]
        elif res_w > ori_short_edge:
            rnd_w = random.randrange(res_w - ori_short_edge)
            video = video[..., :, rnd_w : rnd_w + ori_short_edge]
    kwargs["fsize_h"], kwargs["fsize_w"] = fsize, fsize
    res_h, res_w = video.shape[-2:]
    if res_h > res_w:
        kwargs["fragments_w"] = short_fragments
        kwargs["fragments_h"] = int(short_fragments * res_h / res_w)
    else:
        kwargs["fragments_h"] = short_fragments
        kwargs["fragments_w"] = int(short_fragments * res_w / res_h)
    return get_spatial_fragments(video, **kwargs)
def get_cropped_video(
    video, size_h=224, size_w=224, **kwargs,
):
    kwargs["fragments_h"], kwargs["fragments_w"] = 1, 1
    kwargs["fsize_h"], kwargs["fsize_w"] = size_h, size_w
    return get_spatial_fragments(video, **kwargs)
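
# Hedged note: a single 1x1 "fragment" of size (size_h, size_w) makes the
# fragment routine behave as one random crop, e.g.:
#
#   clip = torch.randint(0, 256, (3, 32, 720, 1280), dtype=torch.uint8)
#   crop = get_cropped_video(clip, 224, 224)  # expected shape: [3, 32, 224, 224]
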
def get_single_view(
    video, sample_type="aesthetic", **kwargs,
):
    if sample_type.startswith("aesthetic"):
        video = get_resized_video(video, **kwargs)
    elif sample_type.startswith("technical"):
        video = get_spatial_fragments(video, **kwargs)
    elif sample_type.startswith("semantic"):
        video = get_resized_video(video, **kwargs)
    elif sample_type == "original":
        return video

    return video
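
# Hedged summary of the dispatch above: "aesthetic" and "semantic" views are
# plain resizes, "technical" views are fragment mosaics, and "original" returns
# the clip untouched, e.g. (illustrative call):
#
#   tech_view = get_single_view(clip, "technical", fragments_h=7, fragments_w=7)
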
def spatial_temporal_view_decomposition(
    video_path, sample_types, samplers, is_train=False, augment=False,
):
    video = {}
    if torch.is_tensor(video_path):
        all_frame_inds = []
        frame_inds = {}
        for stype in samplers:
            frame_inds[stype] = samplers[stype](video_path.shape[0], is_train)
            all_frame_inds.append(frame_inds[stype])

        ### Each frame is only decoded one time!!!
        all_frame_inds = np.concatenate(all_frame_inds, 0)
        frame_dict = {
            idx: video_path[idx].permute(1, 2, 0) for idx in np.unique(all_frame_inds)
        }

        for stype in samplers:
            imgs = [frame_dict[idx] for idx in frame_inds[stype]]
            video[stype] = torch.stack(imgs, 0).permute(3, 0, 1, 2)
    else:
        if video_path.endswith(".yuv"):
            print("This part will be deprecated due to large memory cost.")
            ## This is only an adaptation to LIVE-Qualcomm
            ovideo = skvideo.io.vread(
                video_path, 1080, 1920, inputdict={"-pix_fmt": "yuvj420p"}
            )
            for stype in samplers:
                frame_inds = samplers[stype](ovideo.shape[0], is_train)
                imgs = [torch.from_numpy(ovideo[idx]) for idx in frame_inds]
                video[stype] = torch.stack(imgs, 0).permute(3, 0, 1, 2)
            del ovideo
        else:
            decord.bridge.set_bridge("torch")
            vreader = VideoReader(video_path)
            ### Avoid duplicated video decoding!!! Important!!!!
            all_frame_inds = []
            frame_inds = {}
            for stype in samplers:
                frame_inds[stype] = samplers[stype](len(vreader), is_train)
                all_frame_inds.append(frame_inds[stype])

            ### Each frame is only decoded one time!!!
            all_frame_inds = np.concatenate(all_frame_inds, 0)
            frame_dict = {idx: vreader[idx] for idx in np.unique(all_frame_inds)}

            for stype in samplers:
                imgs = [frame_dict[idx] for idx in frame_inds[stype]]
                video[stype] = torch.stack(imgs, 0).permute(3, 0, 1, 2)

    sampled_video = {}
    for stype, sopt in sample_types.items():
        sampled_video[stype] = get_single_view(video[stype], stype, **sopt)
    return sampled_video, frame_inds
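
# Hedged usage sketch (sampler and view settings are illustrative assumptions;
# UnifiedFrameSampler is defined below):
#
#   samplers = {"technical": UnifiedFrameSampler(32, 1), "aesthetic": UnifiedFrameSampler(32, 1)}
#   sample_types = {
#       "technical": {"fragments_h": 7, "fragments_w": 7},
#       "aesthetic": {"size_h": 224, "size_w": 224},
#   }
#   views, inds = spatial_temporal_view_decomposition("video.mp4", sample_types, samplers)
#   # views["technical"], views["aesthetic"]: [C, T, H, W] tensors; inds: per-view frame indices.
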
class UnifiedFrameSampler:
    def __init__(
        self, fsize_t, fragments_t, frame_interval=1, num_clips=1, drop_rate=0.0,
    ):
        self.fragments_t = fragments_t
        self.fsize_t = fsize_t
        self.size_t = fragments_t * fsize_t
        self.frame_interval = frame_interval
        self.num_clips = num_clips
        self.drop_rate = drop_rate

    def get_frame_indices(self, num_frames, train=False):
        tgrids = np.array(
            [num_frames // self.fragments_t * i for i in range(self.fragments_t)],
            dtype=np.int32,
        )
        tlength = num_frames // self.fragments_t

        if tlength > self.fsize_t * self.frame_interval:
            rnd_t = np.random.randint(
                0, tlength - self.fsize_t * self.frame_interval, size=len(tgrids)
            )
        else:
            rnd_t = np.zeros(len(tgrids), dtype=np.int32)

        ranges_t = (
            np.arange(self.fsize_t)[None, :] * self.frame_interval
            + rnd_t[:, None]
            + tgrids[:, None]
        )

        drop = random.sample(
            list(range(self.fragments_t)), int(self.fragments_t * self.drop_rate)
        )
        dropped_ranges_t = []
        for i, rt in enumerate(ranges_t):
            if i not in drop:
                dropped_ranges_t.append(rt)
        return np.concatenate(dropped_ranges_t)

    def __call__(self, total_frames, train=False, start_index=0):
        frame_inds = []
        for i in range(self.num_clips):
            frame_inds += [self.get_frame_indices(total_frames)]
        frame_inds = np.concatenate(frame_inds)
        frame_inds = np.mod(frame_inds + start_index, total_frames)
        return frame_inds.astype(np.int32)
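
# Hedged example: 8 temporal fragments of 4 frames each, stride 2, drawn from a
# 240-frame clip (numbers are illustrative):
#
#   sampler = UnifiedFrameSampler(fsize_t=4, fragments_t=8, frame_interval=2)
#   inds = sampler(240)  # 8 * 4 = 32 frame indices in [0, 240)
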
class ViewDecompositionDataset(torch.utils.data.Dataset):
    def __init__(self, opt):
        ## opt is a dictionary that includes options for video sampling
        super().__init__()
        self.weight = opt.get("weight", 0.5)
        self.fully_supervised = opt.get("fully_supervised", False)
        print("Fully supervised:", self.fully_supervised)

        self.video_infos = []
        self.ann_file = opt["anno_file"]
        self.data_prefix = opt["data_prefix"]
        self.opt = opt
        self.sample_types = opt["sample_types"]
        self.data_backend = opt.get("data_backend", "disk")
        self.augment = opt.get("augment", False)
        if self.data_backend == "petrel":
            from petrel_client import client

            self.client = client.Client(enable_mc=True)

        self.phase = opt["phase"]
        self.crop = opt.get("random_crop", False)
        self.mean = torch.FloatTensor([123.675, 116.28, 103.53])
        self.std = torch.FloatTensor([58.395, 57.12, 57.375])
        self.mean_semantic = torch.FloatTensor([122.77, 116.75, 104.09])
        self.std_semantic = torch.FloatTensor([68.50, 66.63, 70.32])
        self.samplers = {}
        for stype, sopt in opt["sample_types"].items():
            if "t_frag" not in sopt:
                # resized temporal sampling for TQE in COVER
                self.samplers[stype] = UnifiedFrameSampler(
                    sopt["clip_len"], sopt["num_clips"], sopt["frame_interval"]
                )
            else:
                # temporal sampling for AQE in COVER
                self.samplers[stype] = UnifiedFrameSampler(
                    sopt["clip_len"] // sopt["t_frag"],
                    sopt["t_frag"],
                    sopt["frame_interval"],
                    sopt["num_clips"],
                )
            print(
                stype + " branch sampled frames:",
                self.samplers[stype](240, self.phase == "train"),
            )
        if isinstance(self.ann_file, list):
            self.video_infos = self.ann_file
        else:
            try:
                with open(self.ann_file, "r") as fin:
                    for line in fin:
                        line_split = line.strip().split(",")
                        filename, a, t, label = line_split
                        if self.fully_supervised:
                            label = float(a), float(t), float(label)
                        else:
                            label = float(label)
                        filename = osp.join(self.data_prefix, filename)
                        self.video_infos.append(dict(filename=filename, label=label))
            except:
                #### No Label Testing
                video_filenames = []
                for (root, dirs, files) in os.walk(self.data_prefix, topdown=True):
                    for file in files:
                        if file.endswith(".mp4"):
                            video_filenames += [os.path.join(root, file)]
                print(len(video_filenames))
                video_filenames = sorted(video_filenames)
                for filename in video_filenames:
                    self.video_infos.append(dict(filename=filename, label=-1))
    def __getitem__(self, index):
        video_info = self.video_infos[index]
        filename = video_info["filename"]
        label = video_info["label"]

        try:
            ## Read Original Frames
            ## Process Frames
            data, frame_inds = spatial_temporal_view_decomposition(
                filename,
                self.sample_types,
                self.samplers,
                self.phase == "train",
                self.augment and (self.phase == "train"),
            )

            for k, v in data.items():
                if k == "technical" or k == "aesthetic":
                    data[k] = ((v.permute(1, 2, 3, 0) - self.mean) / self.std).permute(
                        3, 0, 1, 2
                    )
                elif k == "semantic":
                    data[k] = (
                        (v.permute(1, 2, 3, 0) - self.mean_semantic) / self.std_semantic
                    ).permute(3, 0, 1, 2)

            data["num_clips"] = {}
            for stype, sopt in self.sample_types.items():
                data["num_clips"][stype] = sopt["num_clips"]
            data["frame_inds"] = frame_inds
            data["gt_label"] = label
            data["name"] = filename  # osp.basename(video_info["filename"])
        except:
            # exception flow
            return {"name": filename}

        return data

    def __len__(self):
        return len(self.video_infos)
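
# Hedged end-to-end sketch. The option keys mirror the ones read in __init__;
# the concrete values and file paths below are assumptions, not a published config:
#
#   opt = {
#       "anno_file": "labels.csv",  # lines: "filename,a,t,label"
#       "data_prefix": "videos/",
#       "phase": "test",
#       "sample_types": {
#           "technical": {"fragments_h": 7, "fragments_w": 7, "fsize_h": 32, "fsize_w": 32,
#                         "clip_len": 32, "num_clips": 1, "frame_interval": 2, "aligned": 32},
#           "aesthetic": {"size_h": 224, "size_w": 224, "clip_len": 32, "num_clips": 1,
#                         "frame_interval": 2, "t_frag": 32, "aligned": 32},
#       },
#   }
#   dataset = ViewDecompositionDataset(opt)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=4)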