import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
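
# Preprocessing utilities for VACE: fit image/video inputs to a fixed latent
# token budget (`seq_len`) given (temporal, height, width) patch strides.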


class VaceImageProcessor(object):
    def __init__(self, downsample=None, seq_len=None):
        self.downsample = downsample
        self.seq_len = seq_len
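        # Assumed layout (not validated here): `downsample` is a (df, dh, dw)
        # tuple of patch strides; the image path uses only dh and dw.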

    def _pillow_convert(self, image, cvt_type='RGB'):
        if image.mode != cvt_type:
            if image.mode == 'P':
                # palette images may carry transparency; route through 'RGBA' first
                image = image.convert(f'{cvt_type}A')
            if image.mode == f'{cvt_type}A':
                # flatten the alpha channel onto a white background
                bg = Image.new(cvt_type,
                               size=(image.width, image.height),
                               color=(255, 255, 255))
                bg.paste(image, (0, 0), mask=image)
                image = bg
            else:
                image = image.convert(cvt_type)
        return image

    def _load_image(self, img_path):
        if img_path is None or img_path == '':
            return None
        img = Image.open(img_path)
        img = self._pillow_convert(img)
        return img

    def _resize_crop(self, img, oh, ow, normalize=True):
        """
        Resize, center crop, convert to tensor, and normalize.
        """
        iw, ih = img.size
        if iw != ow or ih != oh:
            # scale up just enough that both sides cover the target
            scale = max(ow / iw, oh / ih)
            img = img.resize(
                (round(scale * iw), round(scale * ih)),
                resample=Image.Resampling.LANCZOS
            )
            assert img.width >= ow and img.height >= oh

            # center crop
            x1 = (img.width - ow) // 2
            y1 = (img.height - oh) // 2
            img = img.crop((x1, y1, x1 + ow, y1 + oh))

        # normalize to [-1, 1] and add a singleton frame axis -> (C, 1, H, W)
        if normalize:
            img = TF.to_tensor(img).sub_(0.5).div_(0.5).unsqueeze(1)
        return img

    def _image_preprocess(self, img, oh, ow, normalize=True, **kwargs):
        return self._resize_crop(img, oh, ow, normalize)

    def load_image(self, data_key, **kwargs):
        return self.load_image_batch(data_key, **kwargs)

    def load_image_pair(self, data_key, data_key2, **kwargs):
        return self.load_image_batch(data_key, data_key2, **kwargs)

    def load_image_batch(self, *data_key_batch, normalize=True, seq_len=None, **kwargs):
        seq_len = self.seq_len if seq_len is None else seq_len
        imgs = []
        for data_key in data_key_batch:
            img = self._load_image(data_key)
            imgs.append(img)
        w, h = imgs[0].size
        dh, dw = self.downsample[1:]
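
        # Fit the image to the token budget: (h / dh) * (w / dw) is the latent
        # token count at native size; shrink (never enlarge) so it stays within
        # seq_len, then snap each side down to a multiple of its patch stride.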
        scale = min(1., np.sqrt(seq_len / ((h / dh) * (w / dw))))
        oh = int(h * scale) // dh * dh
        ow = int(w * scale) // dw * dw
        assert (oh // dh) * (ow // dw) <= seq_len
        imgs = [self._image_preprocess(img, oh, ow, normalize) for img in imgs]
        return *imgs, (oh, ow)
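

# A minimal usage sketch (the (4, 16, 16) stride and token budget are
# illustrative placeholders, not a shipped configuration):
#
#   proc = VaceImageProcessor(downsample=(4, 16, 16), seq_len=1024)
#   img, (oh, ow) = proc.load_image('example.png')
#   # img: (C, 1, oh, ow) tensor in [-1, 1]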


class VaceVideoProcessor(object):
    def __init__(self, downsample, min_area, max_area, min_fps, max_fps, zero_start, seq_len, keep_last, **kwargs):
        self.downsample = downsample
        self.min_area = min_area
        self.max_area = max_area
        self.min_fps = min_fps
        self.max_fps = max_fps
        self.zero_start = zero_start
        self.keep_last = keep_last
        self.seq_len = seq_len
        assert seq_len >= min_area / (self.downsample[1] * self.downsample[2])
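        # The assert guarantees the token budget can hold at least one frame
        # at the smallest area this processor may sample.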

    @staticmethod
    def resize_crop(video: torch.Tensor, oh: int, ow: int):
        """
        Resize, center crop, and normalize a decord-loaded video (torch.Tensor).

        Parameters:
            video (torch.Tensor): tensor from `reader.get_batch(frame_ids)`,
                with shape (T, H, W, C)
            oh (int): target height
            ow (int): target width

        Returns:
            torch.Tensor: normalized video in [-1, 1], with shape (C, T, H, W)
        """
        # (T, H, W, C) -> (T, C, H, W)
        video = video.permute(0, 3, 1, 2)

        ih, iw = video.shape[2:]
        if ih != oh or iw != ow:
            # scale up just enough that both sides cover the target
            scale = max(ow / iw, oh / ih)
            video = F.interpolate(
                video,
                size=(round(scale * ih), round(scale * iw)),
                mode='bicubic',
                antialias=True
            )
            assert video.size(3) >= ow and video.size(2) >= oh

            # center crop
            x1 = (video.size(3) - ow) // 2
            y1 = (video.size(2) - oh) // 2
            video = video[:, :, y1:y1 + oh, x1:x1 + ow]

        # (T, C, H, W) -> (C, T, H, W), then map uint8 [0, 255] to [-1, 1]
        video = video.transpose(0, 1).float().div_(127.5).sub_(1.)
        return video
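
    # Note: decord's torch bridge yields uint8 frames. Recent PyTorch accepts
    # uint8 input for antialiased bicubic interpolation; older versions or some
    # devices may require casting to float before `F.interpolate`.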

    def _video_preprocess(self, video, oh, ow):
        return self.resize_crop(video, oh, ow)

    def _get_frameid_bbox_default(self, fps, frame_timestamps, h, w, crop_box, rng):
        target_fps = min(fps, self.max_fps)
        duration = frame_timestamps[-1].mean()
        x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
        h, w = y2 - y1, x2 - x1
        ratio = h / w
        df, dh, dw = self.downsample

        # min/max per-frame area of the latent video, in latent tokens
        min_area_z = self.min_area / (dh * dw)
        max_area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw))
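
        # Sample an area log-uniformly between the bounds: uniform in
        # log2(sqrt(area)), then squared back, i.e. uniform in log(area).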
        rand_area_z = np.square(np.power(2, rng.uniform(
            np.log2(np.sqrt(min_area_z)),
            np.log2(np.sqrt(max_area_z))
        )))
        of = min(
            (int(duration * target_fps) - 1) // df + 1,
            int(self.seq_len / rand_area_z)
        )
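
        # Trade frames against area under the fixed budget: given the sampled
        # frame count, keep the largest latent area that still fits in seq_len.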
        target_area_z = min(max_area_z, int(self.seq_len / of))
        oh = round(np.sqrt(target_area_z * ratio))
        ow = int(target_area_z / oh)
        of = (of - 1) * df + 1  # latent frame count -> pixel frame count
        oh *= dh
        ow *= dw

        # sample a window of `of` frames and map timestamps to frame ids
        target_duration = of / target_fps
        begin = 0. if self.zero_start else rng.uniform(0, duration - target_duration)
        timestamps = np.linspace(begin, begin + target_duration, of)
        frame_ids = np.argmax(np.logical_and(
            timestamps[:, None] >= frame_timestamps[None, :, 0],
            timestamps[:, None] < frame_timestamps[None, :, 1]
        ), axis=1).tolist()
        return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
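
    # `keep_last` variant: span the whole clip instead of sampling a random
    # window, deriving the fps from the chosen frame count so the final frame
    # is always included.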
    def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w, crop_box, rng):
        duration = frame_timestamps[-1].mean()
        x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
        h, w = y2 - y1, x2 - x1
        ratio = h / w
        df, dh, dw = self.downsample

        # min/max per-frame area of the latent video, in latent tokens
        min_area_z = self.min_area / (dh * dw)
        max_area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw))

        # log-uniform area sample, as in the default variant
        rand_area_z = np.square(np.power(2, rng.uniform(
            np.log2(np.sqrt(min_area_z)),
            np.log2(np.sqrt(max_area_z))
        )))

        of = min(
            (len(frame_timestamps) - 1) // df + 1,
            int(self.seq_len / rand_area_z)
        )

        # deduce the target shape of the latent video
        target_area_z = min(max_area_z, int(self.seq_len / of))
        oh = round(np.sqrt(target_area_z * ratio))
        ow = int(target_area_z / oh)
        of = (of - 1) * df + 1  # latent frame count -> pixel frame count
        oh *= dh
        ow *= dw

        # spread `of` samples across the full duration; the fps follows from that
        target_duration = duration
        target_fps = of / target_duration
        timestamps = np.linspace(0., target_duration, of)
        frame_ids = np.argmax(np.logical_and(
            timestamps[:, None] >= frame_timestamps[None, :, 0],
            timestamps[:, None] <= frame_timestamps[None, :, 1]
        ), axis=1).tolist()

        return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps

    def _get_frameid_bbox(self, fps, frame_timestamps, h, w, crop_box, rng):
        if self.keep_last:
            return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h, w, crop_box, rng)
        else:
            return self._get_frameid_bbox_default(fps, frame_timestamps, h, w, crop_box, rng)

    def load_video(self, data_key, crop_box=None, seed=2024, **kwargs):
        return self.load_video_batch(data_key, crop_box=crop_box, seed=seed, **kwargs)

    def load_video_pair(self, data_key, data_key2, crop_box=None, seed=2024, **kwargs):
        return self.load_video_batch(data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs)

    def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, **kwargs):
        rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000)

        # lazy import so decord is only required for the video path
        import decord
        decord.bridge.set_bridge('torch')
        readers = []
        for data_k in data_key_batch:
            reader = decord.VideoReader(data_k)
            readers.append(reader)

        fps = readers[0].get_avg_fps()
        length = min([len(r) for r in readers])
        frame_timestamps = [readers[0].get_frame_timestamp(i) for i in range(length)]
        frame_timestamps = np.array(frame_timestamps, dtype=np.float32)
        h, w = readers[0].next().shape[:2]
        frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, frame_timestamps, h, w, crop_box, rng)

        # crop, resize, and normalize every video with the shared parameters
        videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers]
        videos = [self._video_preprocess(video, oh, ow) for video in videos]
        return *videos, frame_ids, (oh, ow), fps
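

# A minimal usage sketch (constructor values are hypothetical placeholders,
# not a shipped configuration):
#
#   vproc = VaceVideoProcessor(downsample=(4, 16, 16), min_area=480 * 832,
#                              max_area=480 * 832, min_fps=16, max_fps=16,
#                              zero_start=True, seq_len=32760, keep_last=True)
#   video, frame_ids, (oh, ow), fps = vproc.load_video('example.mp4')
#   # video: (C, T, oh, ow) tensor in [-1, 1]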
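

# `prepare_source` normalizes VACE conditioning inputs in place: missing
# video/mask slots become a zero video plus an all-ones mask, and each
# reference image is letterboxed onto a white canvas at the target size.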
def prepare_source(src_video, src_mask, src_ref_images, num_frames, image_size, device):
    for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
        if sub_src_video is None and sub_src_mask is None:
            src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
            src_mask[i] = torch.ones((1, num_frames, image_size[0], image_size[1]), device=device)
    for i, ref_images in enumerate(src_ref_images):
        if ref_images is not None:
            for j, ref_img in enumerate(ref_images):
                if ref_img is not None and ref_img.shape[-2:] != image_size:
                    canvas_height, canvas_width = image_size
                    ref_height, ref_width = ref_img.shape[-2:]
                    white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device)
                    # fit inside the canvas, preserving aspect ratio
                    scale = min(canvas_height / ref_height, canvas_width / ref_width)
                    new_height = int(ref_height * scale)
                    new_width = int(ref_width * scale)
                    resized_image = F.interpolate(
                        ref_img.squeeze(1).unsqueeze(0),
                        size=(new_height, new_width),
                        mode='bilinear',
                        align_corners=False
                    ).squeeze(0).unsqueeze(1)
                    # paste centered on the white canvas
                    top = (canvas_height - new_height) // 2
                    left = (canvas_width - new_width) // 2
                    white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image
                    src_ref_images[i][j] = white_canvas
    return src_video, src_mask, src_ref_images