# NOTE: stray non-Python scraper artifacts removed from the top of this file.
""" | |
# Copyright 2020 Adobe | |
# All Rights Reserved. | |
# NOTICE: Adobe permits you to use, modify, and distribute this file in | |
# accordance with the terms of the Adobe license agreement accompanying | |
# it. | |
""" | |
import torch.utils.data as data | |
import os, glob, platform | |
import numpy as np | |
import cv2 | |
import torch | |
from src.dataset.image_translation.data_preparation import vis_landmark_on_img, vis_landmark_on_img98, vis_landmark_on_img74 | |
from torch.utils.data.dataloader import default_collate | |
from thirdparty.AdaptiveWingLoss.utils.utils import get_preds_fromhm | |
from scipy.io import wavfile as wav | |
from scipy.signal import stft | |
class image_translation_raw_dataset(data.Dataset):
    """Landmark-to-image translation pairs from raw VoxCeleb2 videos.

    Each item renders precomputed 68-point 3D landmarks onto a white
    canvas, pairs them with the corresponding video frames, and returns
    ``(image_in, image_out)`` built from randomly chosen frame pairs.
    """

    def __init__(self, num_frames=16):
        # Select dataset roots by host; the kernel release string
        # distinguishes the two clusters this code runs on.
        if platform.release() == '4.4.0-83-generic':  # stargazer
            self.src_dir = r'/mnt/ntfs/Dataset/TalkingToon/VoxCeleb2_imagetranslation/raw_fl3d'
            self.mp4_dir = r'/mnt/ntfs/Dataset/VoxCeleb2/train_set/dev/mp4'
        else:  # gypsum
            self.src_dir = r'/mnt/nfs/scratch1/yangzhou/VoxCeleb2_compressed_imagetranslation/raw_fl3d'  # raw vox with 1 per vid
            self.mp4_dir = r'/mnt/nfs/work1/kalo/yangzhou/VoxCeleb2/train_set/dev/mp4'
        self.fls_filenames = glob.glob1(self.src_dir, '*')
        # One extra frame so that `num_frames` consecutive pairs exist.
        self.num_random_frames = num_frames + 1
        print(os.name, len(self.fls_filenames))

    def __len__(self):
        return len(self.fls_filenames)

    def __getitem__(self, item):
        fls_filename = self.fls_filenames[item]
        # Load landmark file (one row per frame: frame index + 68*3 coords).
        fls = np.loadtxt(os.path.join(self.src_dir, fls_filename))
        # ================= raw VOX version ================================
        # The landmark filename encodes <id>_x_<video-name>_x_<clip-id>.
        mp4_filename = fls_filename[:-4].split('_x_')
        mp4_id = mp4_filename[0].split('_')[-1]
        mp4_vname = mp4_filename[1]
        mp4_vid = mp4_filename[2][:-3]
        video_dir = os.path.join(self.mp4_dir, mp4_id, mp4_vname, mp4_vid + '.mp4')
        # ======================================================================
        video = cv2.VideoCapture(video_dir)
        if (video.isOpened() == False):
            print('Unable to open video file')
            exit(0)
        # Skip the first frames that have no extracted landmarks.
        start_idx = (fls[0, 0]).astype(int)
        for _ in range(start_idx):
            ret, img_video = video.read()
        # Walk video and landmarks in parallel, keeping only sampled frames.
        frames = []
        random_frame_indices = np.random.permutation(fls.shape[0]-2)[0:self.num_random_frames]
        for j in range(int(fls.shape[0])):
            ret, img_video = video.read()
            if(j in random_frame_indices):
                img_fl = np.ones(shape=(224, 224, 3)) * 255
                idx = fls[j, 0]
                fl = fls[j, 1:].astype(int)
                img_fl = vis_landmark_on_img(img_fl, np.reshape(fl, (68, 3)))
                # Stack landmark render and real frame along channels (6 ch).
                frame = np.concatenate((img_fl, img_video), axis=2)
                frame = cv2.resize(frame, (256, 256))  # 256 x 256 x 6
                frames.append(frame)
        frames = np.stack(frames, axis=0).astype(np.float32)/255.0  # N x 256 x 256 x 6
        # Input: landmark render of frame t plus real frame t+1; target: frame t.
        image_in = np.concatenate([frames[0:-1, :, :, 0:3], frames[1:, :, :, 3:6]], axis=3)
        image_out = frames[0:-1, :, :, 3:6]
        # To channel-first layout (N x C x W x H).
        image_in, image_out = np.swapaxes(image_in, 1, 3), np.swapaxes(image_out, 1, 3)
        return image_in, image_out

    def my_collate(self, batch):
        # Drop failed samples (None). default_collate indexes the batch,
        # so the lazy filter object must be materialized into a list first.
        batch = list(filter(lambda x: x is not None, batch))
        return default_collate(batch)
class image_translation_raw74_dataset(data.Dataset):
    """Raw VoxCeleb2 frame pairs plus the FAN-predicted landmarks.

    Unlike ``image_translation_raw_dataset`` this variant returns the raw
    frames (no landmark rendering) together with the landmark coordinates
    rescaled to the 256x256 output resolution.
    """

    def __init__(self, num_frames=16):
        # Select dataset roots by host (see image_translation_raw_dataset).
        if platform.release() == '4.4.0-83-generic':  # stargazer
            self.src_dir = r'/mnt/ntfs/Dataset/TalkingToon/VoxCeleb2_imagetranslation/raw_fl3d'
            self.mp4_dir = r'/mnt/ntfs/Dataset/VoxCeleb2/train_set/dev/mp4'
        else:  # gypsum
            self.src_dir = r'/mnt/nfs/scratch1/yangzhou/VoxCeleb2_compressed_imagetranslation/raw_fl3d'  # raw vox with 1 per vid
            self.mp4_dir = r'/mnt/nfs/work1/kalo/yangzhou/VoxCeleb2/train_set/dev/mp4'
        self.fls_filenames = glob.glob1(self.src_dir, '*')
        # One extra frame so that `num_frames` consecutive pairs exist.
        self.num_random_frames = num_frames + 1
        print(os.name, len(self.fls_filenames))

    def __len__(self):
        return len(self.fls_filenames)

    def __getitem__(self, item):
        fls_filename = self.fls_filenames[item]
        # Load landmark file (one row per frame: frame index + 68*3 coords).
        fls = np.loadtxt(os.path.join(self.src_dir, fls_filename))
        # ================= raw VOX version ================================
        # The landmark filename encodes <id>_x_<video-name>_x_<clip-id>.
        mp4_filename = fls_filename[:-4].split('_x_')
        mp4_id = mp4_filename[0].split('_')[-1]
        mp4_vname = mp4_filename[1]
        mp4_vid = mp4_filename[2][:-3]
        video_dir = os.path.join(self.mp4_dir, mp4_id, mp4_vname, mp4_vid + '.mp4')
        # ======================================================================
        video = cv2.VideoCapture(video_dir)
        if (video.isOpened() == False):
            print('Unable to open video file')
            exit(0)
        # Skip the first frames that have no extracted landmarks.
        start_idx = (fls[0, 0]).astype(int)
        for _ in range(start_idx):
            ret, img_video = video.read()
        # Walk video and landmarks in parallel, keeping only sampled frames.
        frames = []
        fan_predict_landmarks = []
        random_frame_indices = np.random.permutation(fls.shape[0]-2)[0:self.num_random_frames]
        for j in range(int(fls.shape[0])):
            ret, img_video = video.read()
            if(j in random_frame_indices):
                # Rescale landmarks from the 224 canvas to 256 output space.
                fl = fls[j, 1:] / 224. * 256.
                fan_predict_landmarks.append(np.reshape(fl, (68, 3)))
                img_video = cv2.resize(img_video, (256, 256))
                frames.append(img_video.transpose((2, 0, 1)))
        fan_predict_landmarks = np.stack(fan_predict_landmarks, axis=0)
        frames = np.stack(frames, axis=0).astype(np.float32) / 255.0
        # Input is the later frame of each pair; target is the earlier one.
        image_in = frames[1:, :, :]
        image_out = frames[0:-1, :, :]  # N x 3 x 256 x 256
        return image_in, image_out, fan_predict_landmarks[0:-1]

    def my_collate(self, batch):
        # Drop failed samples (None). default_collate indexes the batch,
        # so the lazy filter object must be materialized into a list first.
        batch = list(filter(lambda x: x is not None, batch))
        return default_collate(batch)
class image_translation_raw_test_dataset(data.Dataset):
    """Test-time variant of the raw-VOX dataset.

    Smooths the landmark trajectory with a Savitzky-Golay filter and
    conditions every frame on a fixed "random face" taken from a
    neighboring clip, returning full-sequence (image_in, image_out).
    """

    def __init__(self, num_frames=16):
        # Select dataset roots by host (see image_translation_raw_dataset).
        if platform.release() == '4.4.0-83-generic':
            self.src_dir = r'/mnt/ntfs/Dataset/TalkingToon/VoxCeleb2_imagetranslation/raw_fl3d'
            self.mp4_dir = r'/mnt/ntfs/Dataset/VoxCeleb2/train_set/dev/mp4'
        else:
            self.src_dir = r'/mnt/nfs/scratch1/yangzhou/VoxCeleb2_compressed_imagetranslation/raw_fl3d'
            self.mp4_dir = r'/mnt/nfs/work1/kalo/yangzhou/VoxCeleb2/train_set/dev/mp4'
        self.fls_filenames = glob.glob1(self.src_dir, '*')
        self.num_random_frames = num_frames + 1
        print(os.name, len(self.fls_filenames))

    def __len__(self):
        return len(self.fls_filenames)

    def __getitem__(self, item):
        fls_filename = self.fls_filenames[item]
        # Load landmark file and temporally smooth the trajectory.
        fls = np.loadtxt(os.path.join(self.src_dir, fls_filename))
        from scipy.signal import savgol_filter
        fls = savgol_filter(fls, 11, 3, axis=0)
        # Reference face from a neighboring clip (previous index).
        random_fls_filename = self.fls_filenames[max(item - 1, 0)]
        mp4_filename = random_fls_filename[:-4].split('_x_')
        mp4_id = mp4_filename[0].split('_')[-1]
        mp4_vname = mp4_filename[1]
        mp4_vid = mp4_filename[2][:-3]
        random_video_dir = os.path.join(self.mp4_dir, mp4_id, mp4_vname, mp4_vid + '.mp4')
        print('============================\nvideo_dir : ' + random_video_dir, item)
        random_video = cv2.VideoCapture(random_video_dir)
        if (random_video.isOpened() == False):
            print('Unable to open video file')
            exit(0)
        _, random_face = random_video.read()
        # ================= raw VOX version ================================
        mp4_filename = fls_filename[:-4].split('_x_')
        mp4_id = mp4_filename[0].split('_')[-1]
        mp4_vname = mp4_filename[1]
        mp4_vid = mp4_filename[2][:-3]
        video_dir = os.path.join(self.mp4_dir, mp4_id, mp4_vname, mp4_vid + '.mp4')
        # ======================================================================
        video = cv2.VideoCapture(video_dir)
        if (video.isOpened() == False):
            print('Unable to open video file')
            exit(0)
        # Skip the first frames that have no extracted landmarks.
        start_idx = (fls[0, 0]).astype(int)
        for _ in range(start_idx):
            ret, img_video = video.read()
        # Walk video and landmarks in parallel over the whole clip.
        frames = []
        for j in range(int(fls.shape[0])-2):
            ret, img_video = video.read()
            img_fl = np.ones(shape=(224, 224, 3)) * 255
            idx = fls[j, 0]
            fl = fls[j, 1:].astype(int)
            img_fl = vis_landmark_on_img(img_fl, np.reshape(fl, (68, 3)))
            # Stack landmark render, reference face and real frame (9 ch).
            frame = np.concatenate((img_fl, random_face, img_video), axis=2)
            frame = cv2.resize(frame, (256, 256))
            frames.append(frame)
        frames = np.stack(frames, axis=0).astype(np.float32)/255.0  # N x 256 x 256 x 9
        image_in = frames[:, :, :, 0:6]
        image_out = frames[:, :, :, 6:9]
        # To channel-first layout (N x C x W x H).
        image_in, image_out = np.swapaxes(image_in, 1, 3), np.swapaxes(image_out, 1, 3)
        return image_in, image_out

    def my_collate(self, batch):
        # Drop failed samples (None). default_collate indexes the batch,
        # so the lazy filter object must be materialized into a list first.
        batch = list(filter(lambda x: x is not None, batch))
        return default_collate(batch)
class image_translation_preprocessed_dataset(data.Dataset):
    """Landmark-to-image pairs from the preprocessed VoxCeleb2 set.

    Landmarks were extracted at a higher frame rate than the mp4s, so
    frame ``j`` of the video maps to landmark row ``j * fps_scale``.
    """

    def __init__(self, num_frames=16):
        # Select dataset roots by host (see image_translation_raw_dataset).
        if platform.release() == '4.4.0-83-generic':
            self.src_dir = r'/mnt/ntfs/Dataset/TalkingToon/VoxCeleb2_imagetranslation/raw_fl3d'
            self.mp4_dir = r'/mnt/ntfs/Dataset/VoxCeleb2/train_set/dev/mp4'
        else:
            self.src_dir = r'/mnt/nfs/scratch1/yangzhou/PreprocessedVox_imagetranslation/raw_fl3d'  # first order
            self.mp4_dir = r'/mnt/nfs/scratch1/yangzhou/PreprocessedVox_mp4'
        self.fls_filenames = glob.glob1(self.src_dir, '*')
        self.num_random_frames = num_frames + 1
        # Landmark rows per video frame.
        self.fps_scale = 2.5
        print(os.name, len(self.fls_filenames))

    def __len__(self):
        return len(self.fls_filenames)

    def __getitem__(self, item):
        fls_filename = self.fls_filenames[item]
        # Load landmark file (one row per landmark frame).
        fls = np.loadtxt(os.path.join(self.src_dir, fls_filename))
        # ================= preprocessed VOX version =======================
        # Strip prefix/suffix of the landmark filename to get the clip path.
        video_dir = os.path.join(self.mp4_dir, fls_filename[10:-7]+'.mp4')
        # ======================================================================
        video = cv2.VideoCapture(video_dir)
        if (video.isOpened() == False):
            print('Unable to open video file')
            exit(0)
        # Skip the first frames that have no extracted landmarks.
        start_idx = (fls[0, 0] // self.fps_scale).astype(int)
        for _ in range(start_idx):
            ret, img_video = video.read()
        # Walk video frames; map each to its landmark row via fps_scale.
        frames = []
        random_frame_indices = np.random.permutation(int(fls.shape[0]//self.fps_scale)-2)[0:self.num_random_frames]
        for j in range(int(fls.shape[0]//self.fps_scale)):
            ret, img_video = video.read()
            if(j in random_frame_indices):
                img_fl = np.ones(shape=(256, 256, 3)) * 255
                idx = fls[int(j*self.fps_scale), 0]
                fl = fls[int(j*self.fps_scale), 1:].astype(int)
                img_fl = vis_landmark_on_img(img_fl, np.reshape(fl, (68, 3)))
                # Stack landmark render and real frame along channels (6 ch).
                frame = np.concatenate((img_fl, img_video), axis=2)
                frames.append(frame)
        frames = np.stack(frames, axis=0).astype(np.float32)/255.0  # N x 256 x 256 x 6
        # Input: landmark render of frame t plus real frame t+1; target: frame t.
        image_in = np.concatenate([frames[0:-1, :, :, 0:3], frames[1:, :, :, 3:6]], axis=3)
        image_out = frames[0:-1, :, :, 3:6]
        image_in, image_out = np.swapaxes(image_in, 1, 3), np.swapaxes(image_out, 1, 3)
        return image_in, image_out

    def my_collate(self, batch):
        # Drop failed samples (None). default_collate indexes the batch,
        # so the lazy filter object must be materialized into a list first.
        batch = list(filter(lambda x: x is not None, batch))
        return default_collate(batch)
class image_translation_preprocessed_test_dataset(data.Dataset):
    """Test-time variant of the preprocessed-VOX dataset.

    Smooths landmarks with a Savitzky-Golay filter and conditions every
    frame on a fixed reference face from a neighboring clip.
    """

    def __init__(self, num_frames=16):
        # Select dataset roots by host (see image_translation_raw_dataset).
        if platform.release() == '4.4.0-83-generic':
            self.src_dir = r'/mnt/ntfs/Dataset/TalkingToon/VoxCeleb2_imagetranslation/raw_fl3d'
            self.mp4_dir = r'/mnt/ntfs/Dataset/VoxCeleb2/train_set/dev/mp4'
        else:
            self.src_dir = r'/mnt/nfs/scratch1/yangzhou/PreprocessedVox_imagetranslation/raw_fl3d'
            self.mp4_dir = r'/mnt/nfs/scratch1/yangzhou/PreprocessedVox_mp4'
        self.fls_filenames = glob.glob1(self.src_dir, '*')
        self.num_random_frames = num_frames + 1
        # Landmark rows per video frame.
        self.fps_scale = 2.5
        print(os.name, len(self.fls_filenames))

    def __len__(self):
        return len(self.fls_filenames)

    def __getitem__(self, item):
        fls_filename = self.fls_filenames[item]
        # Load landmark file and temporally smooth the trajectory.
        fls = np.loadtxt(os.path.join(self.src_dir, fls_filename))
        from scipy.signal import savgol_filter
        fls = savgol_filter(fls, 11, 3, axis=0)
        # Reference face from a neighboring clip (previous index).
        random_fls_filename = self.fls_filenames[max(item-1, 0)]
        random_video_dir = os.path.join(self.mp4_dir, random_fls_filename[10:-7] + '.mp4')
        random_video = cv2.VideoCapture(random_video_dir)
        if (random_video.isOpened() == False):
            print('Unable to open video file')
            exit(0)
        # NOTE(review): random_face is not resized here — preprocessed
        # clips are presumably already 256x256; confirm if reused elsewhere.
        _, random_face = random_video.read()
        # ================= preprocessed VOX version =======================
        video_dir = os.path.join(self.mp4_dir, fls_filename[10:-7]+'.mp4')
        # ======================================================================
        video = cv2.VideoCapture(video_dir)
        if (video.isOpened() == False):
            print('Unable to open video file')
            exit(0)
        # Skip the first frames that have no extracted landmarks.
        start_idx = (fls[0, 0] // self.fps_scale).astype(int)
        for _ in range(start_idx):
            ret, img_video = video.read()
        # Walk video frames; map each to its landmark row via fps_scale.
        frames = []
        for j in range(int(fls.shape[0]//self.fps_scale)):
            ret, img_video = video.read()
            img_fl = np.ones(shape=(256, 256, 3)) * 255
            idx = fls[int(j*self.fps_scale), 0]
            fl = fls[int(j*self.fps_scale), 1:].astype(int)
            img_fl = vis_landmark_on_img(img_fl, np.reshape(fl, (68, 3)))
            # Stack landmark render, reference face and real frame (9 ch).
            frame = np.concatenate((img_fl, random_face, img_video), axis=2)
            frames.append(frame)
        frames = np.stack(frames, axis=0).astype(np.float32)/255.0  # N x 256 x 256 x 9
        image_in = frames[:, :, :, 0:6]
        image_out = frames[:, :, :, 6:9]
        # To channel-first layout (N x C x W x H).
        image_in, image_out = np.swapaxes(image_in, 1, 3), np.swapaxes(image_out, 1, 3)
        return image_in, image_out

    def my_collate(self, batch):
        # Drop failed samples (None). default_collate indexes the batch,
        # so the lazy filter object must be materialized into a list first.
        batch = list(filter(lambda x: x is not None, batch))
        return default_collate(batch)
class image_translation_raw98_dataset(data.Dataset):
    """
    Online landmark extraction with AWings
    Landmark setting: 98 landmarks

    Returns only raw frame pairs; landmark alignment is done outside in
    train_pass() (see __getitem__ docstring).
    """

    def __init__(self, num_frames=1):
        # Select dataset roots by host (see image_translation_raw_dataset).
        if platform.release() == '4.4.0-83-generic':  # stargazer
            self.src_dir = r'/mnt/ntfs/Dataset/TalkingToon/VoxCeleb2_imagetranslation'
            self.mp4_dir = r'/mnt/ntfs/Dataset/VoxCeleb2/train_set/dev/mp4'
        else:
            self.src_dir = r'/mnt/nfs/scratch1/yangzhou/VoxCeleb2_compressed_imagetranslation'
            self.mp4_dir = r'/mnt/nfs/work1/kalo/yangzhou/VoxCeleb2/train_set/dev/mp4'
        # Filenames come from a precomputed index (col 1), not a glob.
        self.fls_filenames = np.loadtxt(os.path.join(self.src_dir, 'filename_index.txt'), dtype=str)[:, 1]
        self.num_random_frames = num_frames + 1
        print(os.name, self.fls_filenames.shape)

    def __len__(self):
        return self.fls_filenames.shape[0]

    def __getitem__(self, item):
        """
        Get landmark alignment outside in train_pass()
        """
        # Retry up to 5 times with the following index if a clip is missing.
        for i in range(5):
            fls_filename = self.fls_filenames[(item+i)%self.fls_filenames.shape[0]]
            # ================= raw VOX version ================================
            mp4_filename = fls_filename[:-4].split('_x_')
            mp4_id = mp4_filename[0].split('_')[-1]
            mp4_vname = mp4_filename[1]
            mp4_vid = mp4_filename[2]
            video_dir = os.path.join(self.mp4_dir, mp4_id, mp4_vname, mp4_vid + '.mp4')
            # ======================================================================
            video = cv2.VideoCapture(video_dir)
            if (video.isOpened() == False):
                print('Unable to open video file')
            else:
                break
        length = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        # Sequentially read frames, keeping only the sampled indices.
        frames = []
        random_frame_indices = np.random.permutation(length-2)[0:self.num_random_frames]
        for j in range(length):
            ret, img = video.read()
            if(j in random_frame_indices):
                img_video = cv2.resize(img, (256, 256))
                frames.append(img_video.transpose((2, 0, 1)))
        frames = np.stack(frames, axis=0).astype(np.float32)/255.0
        # Input is the later frame of each pair; target is the earlier one.
        image_in = frames[1:, :, :]
        image_out = frames[0:-1, :, :]  # N x 3 x 256 x 256
        return image_in, image_out

    def __getitem_along_with_fa__(self, item):
        """
        Online get landmark alignment (deprecated)
        (can only run under num_works=0)

        NOTE(review): relies on self.model / self.device which are never
        set in this class — they must be injected externally before use.
        """
        fls_filename = self.fls_filenames[item]
        # ================= raw VOX version ================================
        mp4_filename = fls_filename[:-4].split('_x_')
        mp4_id = mp4_filename[0].split('_')[-1]
        mp4_vname = mp4_filename[1]
        mp4_vid = mp4_filename[2]
        video_dir = os.path.join(self.mp4_dir, mp4_id, mp4_vname, mp4_vid + '.mp4')
        # ======================================================================
        video = cv2.VideoCapture(video_dir)
        if (video.isOpened() == False):
            print('Unable to open video file')
            exit(0)
        length = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        # Sequentially read frames, keeping only the sampled indices.
        frames = []
        random_frame_indices = np.random.permutation(length-2)[0:self.num_random_frames]
        for j in range(length):
            ret, img = video.read()
            if(j in random_frame_indices):
                # Online landmark prediction on the resized frame.
                img_video = cv2.resize(img, (256, 256))
                img = img_video.transpose((2, 0, 1)) / 255.0
                inputs = torch.tensor(img, dtype=torch.float32, requires_grad=False).unsqueeze(0).to(self.device)
                with torch.no_grad():
                    outputs, boundary_channels = self.model(inputs)
                pred_heatmap = outputs[-1][:, :-1, :, :][0].detach().cpu()
                pred_landmarks, _ = get_preds_fromhm(pred_heatmap.unsqueeze(0))
                # Heatmaps are 1/4 resolution; scale coordinates back up.
                pred_landmarks = pred_landmarks.squeeze().numpy() * 4
                img_fl = np.ones(shape=(256, 256, 3)) * 255
                # NOTE(review): img_fl is already 0..255, so `img_fl * 255.0`
                # looks wrong — deprecated path, left unchanged.
                img_fl = vis_landmark_on_img98(img_fl * 255.0, pred_landmarks)  # 98x2
                frame = np.concatenate((img_fl, img_video), axis=2)
                frames.append(frame)
        frames = np.stack(frames, axis=0).astype(np.float32)/255.0  # N x 256 x 256 x 6
        image_in = np.concatenate([frames[0:-1, :, :, 0:3], frames[1:, :, :, 3:6]], axis=3)
        image_out = frames[0:-1, :, :, 3:6]
        image_in, image_out = np.swapaxes(image_in, 1, 3), np.swapaxes(image_out, 1, 3)
        return image_in, image_out

    def my_collate(self, batch):
        # Drop failed samples (None). default_collate indexes the batch,
        # so the lazy filter object must be materialized into a list first.
        batch = list(filter(lambda x: x is not None, batch))
        return default_collate(batch)
class image_translation_preprocessed98_dataset(data.Dataset):
    """Raw frame pairs from the preprocessed VoxCeleb2 set (98-landmark
    pipeline); landmark rendering happens outside the dataset."""

    def __init__(self, num_frames=16):
        # Select dataset roots by host (see image_translation_raw_dataset).
        if platform.release() == '4.4.0-83-generic':
            self.src_dir = r'/mnt/ntfs/Dataset/TalkingToon/VoxCeleb2_imagetranslation'
            self.mp4_dir = r'/mnt/ntfs/Dataset/VoxCeleb2/train_set/dev/mp4'
        else:
            self.src_dir = r'/mnt/nfs/scratch1/yangzhou/PreprocessedVox_imagetranslation/raw_fl3d'  # first order
            self.mp4_dir = r'/mnt/nfs/scratch1/yangzhou/PreprocessedVox_mp4'
        self.fls_filenames = glob.glob1(self.src_dir, '*')
        self.num_random_frames = num_frames + 1
        print(os.name, len(self.fls_filenames))

    def __len__(self):
        return len(self.fls_filenames)

    def __getitem__(self, item):
        fls_filename = self.fls_filenames[item]
        # ================= preprocessed VOX version =======================
        video_dir = os.path.join(self.mp4_dir, fls_filename[10:-7]+'.mp4')
        # ======================================================================
        video = cv2.VideoCapture(video_dir)
        if (video.isOpened() == False):
            print('Unable to open video file')
            exit(0)
        length = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        # Sequentially read frames, keeping only the sampled indices.
        frames = []
        random_frame_indices = np.random.permutation(length-2)[0:self.num_random_frames]
        for j in range(length):
            ret, img_video = video.read()
            if(j in random_frame_indices):
                img_video = cv2.resize(img_video, (256, 256))
                frames.append(img_video.transpose((2, 0, 1)))
        frames = np.stack(frames, axis=0).astype(np.float32)/255.0
        # Input is the later frame of each pair; target is the earlier one.
        image_in = frames[1:, :, :]
        image_out = frames[0:-1, :, :]  # N x 3 x 256 x 256
        return image_in, image_out

    def my_collate(self, batch):
        # Drop failed samples (None). default_collate indexes the batch,
        # so the lazy filter object must be materialized into a list first.
        batch = list(filter(lambda x: x is not None, batch))
        return default_collate(batch)
class image_translation_raw98_test_dataset(data.Dataset):
    """Test-time raw-VOX dataset for the 98-landmark pipeline: pairs a
    fixed reference face (from another clip) with every frame of a clip."""

    def __init__(self, num_frames=16):
        # Select dataset roots by host (see image_translation_raw_dataset).
        if platform.release() == '4.4.0-83-generic':
            self.src_dir = r'/mnt/ntfs/Dataset/TalkingToon/VoxCeleb2_imagetranslation'
            self.mp4_dir = r'/mnt/ntfs/Dataset/VoxCeleb2/train_set/dev/mp4'
        else:
            self.src_dir = r'/mnt/nfs/scratch1/yangzhou/VoxCeleb2_compressed_imagetranslation'
            self.mp4_dir = r'/mnt/nfs/work1/kalo/yangzhou/VoxCeleb2/train_set/dev/mp4'
        # Filenames come from a precomputed index (col 1), not a glob.
        self.fls_filenames = np.loadtxt(os.path.join(self.src_dir, 'filename_index.txt'), dtype=str)[:, 1]
        self.num_random_frames = num_frames + 1
        print(os.name, len(self.fls_filenames))

    def __len__(self):
        return len(self.fls_filenames)

    def __getitem__(self, item):
        fls_filename = self.fls_filenames[item]
        # Reference face from a clip 10 indices earlier.
        random_fls_filename = self.fls_filenames[max(item - 10, 0)]
        mp4_filename = random_fls_filename[:-4].split('_x_')
        mp4_id = mp4_filename[0].split('_')[-1]
        mp4_vname = mp4_filename[1]
        mp4_vid = mp4_filename[2]
        random_video_dir = os.path.join(self.mp4_dir, mp4_id, mp4_vname, mp4_vid + '.mp4')
        print('============================\nvideo_dir : ' + random_video_dir, item)
        random_video = cv2.VideoCapture(random_video_dir)
        if (random_video.isOpened() == False):
            print('Unable to open video file')
            exit(0)
        _, random_face = random_video.read()
        random_face = cv2.resize(random_face, (256, 256))
        # ================= raw VOX version ================================
        mp4_filename = fls_filename[:-4].split('_x_')
        mp4_id = mp4_filename[0].split('_')[-1]
        mp4_vname = mp4_filename[1]
        mp4_vid = mp4_filename[2]
        video_dir = os.path.join(self.mp4_dir, mp4_id, mp4_vname, mp4_vid + '.mp4')
        # ======================================================================
        video = cv2.VideoCapture(video_dir)
        if (video.isOpened() == False):
            print('Unable to open video file')
            exit(0)
        length = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        # Pair the fixed reference face with every frame of the clip.
        frames = []
        for j in range(length):
            ret, img_video = video.read()
            img_video = cv2.resize(img_video, (256, 256))
            frame = np.concatenate((random_face, img_video), axis=2)
            frames.append(frame.transpose((2, 0, 1)))
        frames = np.stack(frames, axis=0).astype(np.float32) / 255.0  # N x 6 x 256 x 256
        image_in = frames[:, 0:3]
        image_out = frames[:, 3:6]
        return image_in, image_out

    def my_collate(self, batch):
        # Drop failed samples (None). default_collate indexes the batch,
        # so the lazy filter object must be materialized into a list first.
        batch = list(filter(lambda x: x is not None, batch))
        return default_collate(batch)
class image_translation_preprocessed98_test_dataset(data.Dataset):
    """Test-time preprocessed-VOX dataset for the 98-landmark pipeline:
    pairs a fixed reference face with every frame of a clip."""

    def __init__(self, num_frames=16):
        # Select dataset roots by host (see image_translation_raw_dataset).
        if platform.release() == '4.4.0-83-generic':
            self.src_dir = r'/mnt/ntfs/Dataset/TalkingToon/VoxCeleb2_imagetranslation/raw_fl3d'
            self.mp4_dir = r'/mnt/ntfs/Dataset/VoxCeleb2/train_set/dev/mp4'
        else:
            self.src_dir = r'/mnt/nfs/scratch1/yangzhou/PreprocessedVox_imagetranslation/raw_fl3d'
            self.mp4_dir = r'/mnt/nfs/scratch1/yangzhou/PreprocessedVox_mp4'
        self.fls_filenames = glob.glob1(self.src_dir, '*')
        self.num_random_frames = num_frames + 1
        print(os.name, len(self.fls_filenames))

    def __len__(self):
        return len(self.fls_filenames)

    def __getitem__(self, item):
        fls_filename = self.fls_filenames[item]
        # Reference face from a clip 10 indices earlier.
        random_fls_filename = self.fls_filenames[max(item-10, 0)]
        random_video_dir = os.path.join(self.mp4_dir, random_fls_filename[10:-7] + '.mp4')
        random_video = cv2.VideoCapture(random_video_dir)
        if (random_video.isOpened() == False):
            print('Unable to open video file')
            exit(0)
        # NOTE(review): random_face is not resized here — preprocessed
        # clips are presumably already 256x256; confirm before reuse.
        _, random_face = random_video.read()
        # ================= preprocessed VOX version =======================
        video_dir = os.path.join(self.mp4_dir, fls_filename[10:-7]+'.mp4')
        # ======================================================================
        video = cv2.VideoCapture(video_dir)
        if (video.isOpened() == False):
            print('Unable to open video file')
            exit(0)
        length = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        # Pair the fixed reference face with every frame of the clip.
        frames = []
        for j in range(length):
            ret, img_video = video.read()
            img_video = cv2.resize(img_video, (256, 256))
            frame = np.concatenate((random_face, img_video), axis=2)
            frames.append(frame.transpose((2, 0, 1)))
        frames = np.stack(frames, axis=0).astype(np.float32)/255.0  # N x 6 x 256 x 256
        image_in = frames[:, 0:3]
        image_out = frames[:, 3:6]
        return image_in, image_out

    def my_collate(self, batch):
        # Drop failed samples (None). default_collate indexes the batch,
        # so the lazy filter object must be materialized into a list first.
        batch = list(filter(lambda x: x is not None, batch))
        return default_collate(batch)
class image_translation_raw98_with_audio_dataset(data.Dataset):
    """
    Online landmark extraction with AWings
    Landmark setting: 98 landmarks

    Adds per-frame audio spectrogram clips (extracted via ffmpeg + STFT)
    to the raw frame pairs.
    """

    def __init__(self, num_frames=1):
        # Select dataset roots by host (see image_translation_raw_dataset).
        if platform.release() == '4.4.0-83-generic':  # stargazer
            self.src_dir = r'/mnt/ntfs/Dataset/TalkingToon/VoxCeleb2_imagetranslation'
            self.mp4_dir = r'/mnt/ntfs/Dataset/VoxCeleb2/train_set/dev/mp4'
        else:
            self.src_dir = r'/mnt/nfs/scratch1/yangzhou/VoxCeleb2_compressed_imagetranslation'
            self.mp4_dir = r'/mnt/nfs/work1/kalo/yangzhou/VoxCeleb2/train_set/dev/mp4'
        # Filenames come from a precomputed index (col 1), not a glob.
        self.fls_filenames = np.loadtxt(os.path.join(self.src_dir, 'filename_index.txt'), dtype=str)[:, 1]
        self.num_random_frames = num_frames + 1
        print(os.name, self.fls_filenames.shape)

    def __len__(self):
        return self.fls_filenames.shape[0]

    def __getitem__(self, item):
        """
        Get landmark alignment outside in train_pass()
        """
        # Retry up to 5 times with the following index if a clip is missing.
        for i in range(5):
            fls_filename = self.fls_filenames[(item+i)%self.fls_filenames.shape[0]]
            # ================= raw VOX version ================================
            mp4_filename = fls_filename[:-4].split('_x_')
            mp4_id = mp4_filename[0].split('_')[-1]
            mp4_vname = mp4_filename[1]
            mp4_vid = mp4_filename[2]
            video_dir = os.path.join(self.mp4_dir, mp4_id, mp4_vname, mp4_vid + '.mp4')
            # ======================================================================
            video = cv2.VideoCapture(video_dir)
            if (video.isOpened() == False):
                print('Unable to open video file')
            else:
                break
        length = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        # Sample frames, offset by 5 so every frame has 0.2s of audio context
        # on both sides.
        frames = []
        random_frame_indices = np.random.permutation(max(1, length-12))[0:self.num_random_frames]
        random_frame_indices = [idx + 5 for idx in random_frame_indices]
        for j in range(length):
            ret, img = video.read()
            if(j in random_frame_indices):
                img_video = cv2.resize(img, (256, 256))
                frames.append(img_video.transpose((2, 0, 1)))
        frames = np.stack(frames, axis=0).astype(np.float32)/255.0
        # Input is the later frame of each pair; target is the earlier one.
        image_in = frames[1:, :, :]
        image_out = frames[0:-1, :, :]  # N x 3 x 256 x 256
        # Extract mono 16 kHz audio next to the mp4, STFT it, then clean up.
        os.system('ffmpeg -y -loglevel error -i {} -vn -ar 16000 -ac 1 {}'.format(
            video_dir, video_dir.replace('.mp4', '.wav')
        ))
        sample_rate, samples = wav.read(video_dir.replace('.mp4', '.wav'))
        assert (sample_rate == 16000)
        if (len(samples.shape) > 1):
            samples = samples[:, 0]  # pick mono
        # 1 frame = 1/25 s * 16k = 640 samples; hop of 80 samples gives
        # 200 STFT columns per second, i.e. 8 columns per video frame.
        f, t, Zxx = stft(samples, fs=sample_rate, nperseg=640, noverlap=560)
        stft_abs = np.log(np.abs(Zxx) ** 2 + 1e-10)  # log power spectrum
        stft_abs = stft_abs / np.max(stft_abs)
        os.remove(video_dir.replace('.mp4', '.wav'))
        # For each sampled frame take +-5 frames (0.2 s) of audio context.
        audio_in = []
        for fidx in random_frame_indices:  # renamed: was shadowing `item`
            sel_audio_clip = stft_abs[:, (fidx-5)*8:(fidx+5)*8]
            assert sel_audio_clip.shape[1] == 80
            audio_in.append(np.expand_dims(cv2.resize(sel_audio_clip, (256, 256)), axis=0))
        audio_in = np.stack(audio_in[0:-1], axis=0).astype(np.float32)
        return image_in, image_out, audio_in

    def my_collate(self, batch):
        # Drop failed samples (None). default_collate indexes the batch,
        # so the lazy filter object must be materialized into a list first.
        batch = list(filter(lambda x: x is not None, batch))
        return default_collate(batch)
class image_translation_raw98_with_audio_test_dataset(data.Dataset):
    """
    Online landmark extraction with AWings
    Landmark setting: 98 landmarks

    Test-time variant: pairs a fixed reference face with every frame of a
    clip plus that frame's audio spectrogram context.
    """

    def __init__(self, num_frames=1):
        # Select dataset roots by host (see image_translation_raw_dataset).
        if platform.release() == '4.4.0-83-generic':  # stargazer
            self.src_dir = r'/mnt/ntfs/Dataset/TalkingToon/VoxCeleb2_imagetranslation'
            self.mp4_dir = r'/mnt/ntfs/Dataset/VoxCeleb2/train_set/dev/mp4'
        else:
            self.src_dir = r'/mnt/nfs/scratch1/yangzhou/VoxCeleb2_compressed_imagetranslation'
            self.mp4_dir = r'/mnt/nfs/work1/kalo/yangzhou/VoxCeleb2/train_set/dev/mp4'
        # Filenames come from a precomputed index (col 1), not a glob.
        self.fls_filenames = np.loadtxt(os.path.join(self.src_dir, 'filename_index.txt'), dtype=str)[:, 1]
        self.num_random_frames = num_frames + 1
        print(os.name, self.fls_filenames.shape)

    def __len__(self):
        return self.fls_filenames.shape[0]

    def __getitem__(self, item):
        """
        Get landmark alignment outside in train_pass()
        """
        # Reference face from a clip 10 indices earlier.
        random_fls_filename = self.fls_filenames[max(item - 10, 0)]
        mp4_filename = random_fls_filename[:-4].split('_x_')
        mp4_id = mp4_filename[0].split('_')[-1]
        mp4_vname = mp4_filename[1]
        mp4_vid = mp4_filename[2]
        random_video_dir = os.path.join(self.mp4_dir, mp4_id, mp4_vname, mp4_vid + '.mp4')
        print('============================\nvideo_dir : ' + random_video_dir, item)
        random_video = cv2.VideoCapture(random_video_dir)
        if (random_video.isOpened() == False):
            print('Unable to open video file')
            exit(0)
        _, random_face = random_video.read()
        random_face = cv2.resize(random_face, (256, 256))
        fls_filename = self.fls_filenames[item]
        # ================= raw VOX version ================================
        mp4_filename = fls_filename[:-4].split('_x_')
        mp4_id = mp4_filename[0].split('_')[-1]
        mp4_vname = mp4_filename[1]
        mp4_vid = mp4_filename[2]
        video_dir = os.path.join(self.mp4_dir, mp4_id, mp4_vname, mp4_vid + '.mp4')
        # ======================================================================
        video = cv2.VideoCapture(video_dir)
        if (video.isOpened() == False):
            print('Unable to open video file')
            # Bail out like the sibling datasets do; previously execution
            # fell through and crashed later on an empty frame stack.
            exit(0)
        length = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        # Use frames [5, length-5) so every frame has 0.2 s of audio context.
        frames = []
        for j in range(5, length-5):
            ret, img_video = video.read()
            img_video = cv2.resize(img_video, (256, 256))
            frame = np.concatenate((random_face, img_video), axis=2)
            frames.append(frame.transpose((2, 0, 1)))
        frames = np.stack(frames, axis=0).astype(np.float32) / 255.0  # N x 6 x 256 x 256
        image_in = frames[:, 0:3]
        image_out = frames[:, 3:6]
        # Extract mono 16 kHz audio next to the mp4, STFT it, then clean up.
        os.system('ffmpeg -y -loglevel error -i {} -vn -ar 16000 -ac 1 {}'.format(
            video_dir, video_dir.replace('.mp4', '.wav')
        ))
        sample_rate, samples = wav.read(video_dir.replace('.mp4', '.wav'))
        assert (sample_rate == 16000)
        if (len(samples.shape) > 1):
            samples = samples[:, 0]  # pick mono
        # 1 frame = 1/25 s * 16k = 640 samples; hop of 80 samples gives
        # 200 STFT columns per second, i.e. 8 columns per video frame.
        f, t, Zxx = stft(samples, fs=sample_rate, nperseg=640, noverlap=560)
        stft_abs = np.log(np.abs(Zxx) ** 2 + 1e-10)  # log power spectrum
        stft_abs = stft_abs / np.max(stft_abs)
        os.remove(video_dir.replace('.mp4', '.wav'))
        # For each used frame take +-5 frames (0.2 s) of audio context.
        audio_in = []
        for fidx in range(5, length-5):  # renamed: was shadowing `item`
            sel_audio_clip = stft_abs[:, (fidx-5)*8:(fidx+5)*8]
            assert sel_audio_clip.shape[1] == 80
            audio_in.append(np.expand_dims(cv2.resize(sel_audio_clip, (256, 256)), axis=0))
        audio_in = np.stack(audio_in, axis=0).astype(np.float32)
        return image_in, image_out, audio_in

    def my_collate(self, batch):
        # Drop failed samples (None). default_collate indexes the batch,
        # so the lazy filter object must be materialized into a list first.
        batch = list(filter(lambda x: x is not None, batch))
        return default_collate(batch)
if __name__ == '__main__':
    # Smoke test: iterate the raw-VOX dataset once and report batch shapes.
    dataset = image_translation_raw_dataset()
    loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)
    print(len(dataset))
    for step, batch in enumerate(loader):
        print(step, batch[0].shape, batch[1].shape)