# marlenezw's picture
# changing face alignment and removing its docker file.
# 22257c4
"""
# Copyright 2020 Adobe
# All Rights Reserved.
# NOTICE: Adobe permits you to use, modify, and distribute this file in
# accordance with the terms of the Adobe license agreement accompanying
# it.
"""
import torch.utils.data as data
import os, glob, platform
import numpy as np
import cv2
import torch
from src.dataset.image_translation.data_preparation import vis_landmark_on_img, vis_landmark_on_img98, vis_landmark_on_img74
from torch.utils.data.dataloader import default_collate
from thirdparty.AdaptiveWingLoss.utils.utils import get_preds_fromhm
from scipy.io import wavfile as wav
from scipy.signal import stft
class image_translation_raw_dataset(data.Dataset):
    """Image-translation training set built from raw VoxCeleb2 videos.

    Each item samples ``num_frames + 1`` random frames from one video,
    rasterizes the precomputed 68-point 3D facial landmarks beside each
    frame, and returns consecutive (input, target) frame pairs.
    """

    def __init__(self, num_frames=16):
        # Dataset roots depend on the host machine (stargazer vs. gypsum cluster).
        if platform.release() == '4.4.0-83-generic': # stargazer
            self.src_dir = r'/mnt/ntfs/Dataset/TalkingToon/VoxCeleb2_imagetranslation/raw_fl3d'
            self.mp4_dir = r'/mnt/ntfs/Dataset/VoxCeleb2/train_set/dev/mp4'
        else: # gypsum
            self.src_dir = r'/mnt/nfs/scratch1/yangzhou/VoxCeleb2_compressed_imagetranslation/raw_fl3d' # raw vox with 1 per vid
            self.mp4_dir = r'/mnt/nfs/work1/kalo/yangzhou/VoxCeleb2/train_set/dev/mp4'
        self.fls_filenames = glob.glob1(self.src_dir, '*')
        # One extra frame is drawn so that num_frames consecutive (in, out) pairs exist.
        self.num_random_frames = num_frames + 1
        print(os.name, len(self.fls_filenames))

    def __len__(self):
        return len(self.fls_filenames)

    def __getitem__(self, item):
        """Return (image_in, image_out) float32 arrays in [0, 1].

        image_in is N x 6 x 256 x 256 (landmark drawing of frame t stacked
        with video frame t+1); image_out is N x 3 x 256 x 256 (frame t).
        """
        fls_filename = self.fls_filenames[item]
        # load landmark file: column 0 is the frame index, columns 1+ hold flattened 68x3 landmarks
        fls = np.loadtxt(os.path.join(self.src_dir, fls_filename))
        # load mp4 file
        # ================= raw VOX version ================================
        # Landmark filenames encode '<...>_<id>_x_<vname>_x_<vid>...' and are
        # mapped back onto the VoxCeleb2 directory layout here.
        mp4_filename = fls_filename[:-4].split('_x_')
        mp4_id = mp4_filename[0].split('_')[-1]
        mp4_vname = mp4_filename[1]
        mp4_vid = mp4_filename[2][:-3]
        video_dir = os.path.join(self.mp4_dir, mp4_id, mp4_vname, mp4_vid + '.mp4')
        # print('============================\nvideo_dir : ' + video_dir, item)
        # ======================================================================
        video = cv2.VideoCapture(video_dir)
        if (video.isOpened() == False):
            print('Unable to open video file')
            exit(0)  # NOTE(review): exits the worker with success status — consider raising instead
        # skip first several frames due to landmark extraction
        start_idx = (fls[0, 0]).astype(int)
        for _ in range(start_idx):
            ret, img_video = video.read()
        # save video and landmark in parallel
        frames = []
        random_frame_indices = np.random.permutation(fls.shape[0]-2)[0:self.num_random_frames]
        for j in range(int(fls.shape[0])):
            ret, img_video = video.read()
            if(j in random_frame_indices):
                # Draw landmarks on a blank 224x224 canvas (raw frames are
                # assumed to be 224x224 too — TODO confirm against the data).
                img_fl = np.ones(shape=(224, 224, 3)) * 255
                idx = fls[j, 0]
                fl = fls[j, 1:].astype(int)
                img_fl = vis_landmark_on_img(img_fl, np.reshape(fl, (68, 3)))
                # Channel-stack landmark image and video frame: H x W x 6.
                frame = np.concatenate((img_fl, img_video), axis=2)
                frame = cv2.resize(frame, (256, 256)) # 256 x 256 6
                frames.append(frame)
        frames = np.stack(frames, axis=0).astype(np.float32)/255.0 # N x 256 x 256 x 6
        image_in = np.concatenate([frames[0:-1, :, :, 0:3], frames[1:, :, :, 3:6]], axis=3)
        image_out = frames[0:-1, :, :, 3:6]
        # NHWC -> NCHW (swapping axes 1 and 3 also transposes H/W; harmless for square frames)
        image_in, image_out = np.swapaxes(image_in, 1, 3), np.swapaxes(image_out, 1, 3)
        return image_in, image_out

    def my_collate(self, batch):
        """Collate that drops None items (failed reads) before batching.

        Fix: in Python 3 ``filter`` returns a lazy iterator, but
        ``default_collate`` indexes its argument, so materialize a list first.
        """
        batch = list(filter(lambda x: x is not None, batch))
        return default_collate(batch)
class image_translation_raw74_dataset(data.Dataset):
    """Raw VoxCeleb2 frames plus the FAN-predicted landmark coordinates.

    Unlike image_translation_raw_dataset, landmark images are not drawn here;
    the raw coordinates (rescaled from 224 to 256 space) are returned so the
    caller can rasterize them itself.
    """

    def __init__(self, num_frames=16):
        # Dataset roots depend on the host machine (stargazer vs. gypsum cluster).
        if platform.release() == '4.4.0-83-generic': # stargazer
            self.src_dir = r'/mnt/ntfs/Dataset/TalkingToon/VoxCeleb2_imagetranslation/raw_fl3d'
            self.mp4_dir = r'/mnt/ntfs/Dataset/VoxCeleb2/train_set/dev/mp4'
        else: # gypsum
            self.src_dir = r'/mnt/nfs/scratch1/yangzhou/VoxCeleb2_compressed_imagetranslation/raw_fl3d' # raw vox with 1 per vid
            self.mp4_dir = r'/mnt/nfs/work1/kalo/yangzhou/VoxCeleb2/train_set/dev/mp4'
        self.fls_filenames = glob.glob1(self.src_dir, '*')
        # One extra frame so num_frames consecutive (in, out) pairs exist.
        self.num_random_frames = num_frames + 1
        print(os.name, len(self.fls_filenames))

    def __len__(self):
        return len(self.fls_filenames)

    def __getitem__(self, item):
        """Return (image_in, image_out, landmarks).

        image_in/image_out are N x 3 x 256 x 256 float32 in [0, 1];
        landmarks is N x 68 x 3 (class name says 74, but the reshape below
        is 68x3 — presumably a leftover name; verify against callers).
        """
        fls_filename = self.fls_filenames[item]
        # load landmark file (column 0: frame index, columns 1+: flattened 68x3)
        fls = np.loadtxt(os.path.join(self.src_dir, fls_filename))
        # load mp4 file
        # ================= raw VOX version ================================
        mp4_filename = fls_filename[:-4].split('_x_')
        mp4_id = mp4_filename[0].split('_')[-1]
        mp4_vname = mp4_filename[1]
        mp4_vid = mp4_filename[2][:-3]
        video_dir = os.path.join(self.mp4_dir, mp4_id, mp4_vname, mp4_vid + '.mp4')
        # print('============================\nvideo_dir : ' + video_dir, item)
        # ======================================================================
        video = cv2.VideoCapture(video_dir)
        if (video.isOpened() == False):
            print('Unable to open video file')
            exit(0)  # NOTE(review): exits the worker with success status — consider raising instead
        # skip first several frames due to landmark extraction
        start_idx = (fls[0, 0]).astype(int)
        for _ in range(start_idx):
            ret, img_video = video.read()
        # save video and landmark in parallel
        frames = []
        fan_predict_landmarks = []
        random_frame_indices = np.random.permutation(fls.shape[0]-2)[0:self.num_random_frames]
        for j in range(int(fls.shape[0])):
            ret, img_video = video.read()
            if(j in random_frame_indices):
                # Rescale landmark coordinates from the 224x224 detection
                # space to the 256x256 training resolution.
                fl = fls[j, 1:] / 224. * 256.
                fan_predict_landmarks.append(np.reshape(fl, (68, 3)))
                img_video = cv2.resize(img_video, (256, 256))
                frames.append(img_video.transpose((2, 0, 1)))  # HWC -> CHW
        fan_predict_landmarks = np.stack(fan_predict_landmarks, axis=0)
        frames = np.stack(frames, axis=0).astype(np.float32) / 255.0
        image_in = frames[1:, :, :]
        image_out = frames[0:-1, :, :] # N x 3 x 256 x 256
        return image_in, image_out, fan_predict_landmarks[0:-1]

    def my_collate(self, batch):
        """Collate that drops None items (failed reads) before batching.

        Fix: in Python 3 ``filter`` returns a lazy iterator, but
        ``default_collate`` indexes its argument, so materialize a list first.
        """
        batch = list(filter(lambda x: x is not None, batch))
        return default_collate(batch)
class image_translation_raw_test_dataset(data.Dataset):
    """Test-time variant of image_translation_raw_dataset.

    Landmarks are temporally smoothed, all frames are used (no random
    sampling), and a reference face taken from a *different* video is
    stacked into the input so the model must re-render identity.
    """

    def __init__(self, num_frames=16):
        # Dataset roots depend on the host machine (stargazer vs. gypsum cluster).
        if platform.release() == '4.4.0-83-generic':
            self.src_dir = r'/mnt/ntfs/Dataset/TalkingToon/VoxCeleb2_imagetranslation/raw_fl3d'
            self.mp4_dir = r'/mnt/ntfs/Dataset/VoxCeleb2/train_set/dev/mp4'
        else:
            self.src_dir = r'/mnt/nfs/scratch1/yangzhou/VoxCeleb2_compressed_imagetranslation/raw_fl3d'
            self.mp4_dir = r'/mnt/nfs/work1/kalo/yangzhou/VoxCeleb2/train_set/dev/mp4'
        self.fls_filenames = glob.glob1(self.src_dir, '*')
        self.num_random_frames = num_frames + 1
        print(os.name, len(self.fls_filenames))

    def __len__(self):
        return len(self.fls_filenames)

    def __getitem__(self, item):
        """Return (image_in, image_out): N x 6 x 256 x 256 and N x 3 x 256 x 256 float32 in [0, 1]."""
        fls_filename = self.fls_filenames[item]
        # load landmark file (column 0: frame index, columns 1+: flattened 68x3)
        fls = np.loadtxt(os.path.join(self.src_dir, fls_filename))
        # Temporally smooth the landmark tracks for a stable test rendering.
        from scipy.signal import savgol_filter
        fls = savgol_filter(fls, 11, 3, axis=0)
        # load random face: first frame of the previous item's video, so the
        # reference identity differs from the driving video.
        random_fls_filename = self.fls_filenames[max(item - 1, 0)]
        mp4_filename = random_fls_filename[:-4].split('_x_')
        mp4_id = mp4_filename[0].split('_')[-1]
        mp4_vname = mp4_filename[1]
        mp4_vid = mp4_filename[2][:-3]
        random_video_dir = os.path.join(self.mp4_dir, mp4_id, mp4_vname, mp4_vid + '.mp4')
        print('============================\nvideo_dir : ' + random_video_dir, item)
        random_video = cv2.VideoCapture(random_video_dir)
        if (random_video.isOpened() == False):
            print('Unable to open video file')
            exit(0)  # NOTE(review): exits with success status — consider raising instead
        _, random_face = random_video.read()
        # load mp4 file
        # ================= raw VOX version ================================
        mp4_filename = fls_filename[:-4].split('_x_')
        mp4_id = mp4_filename[0].split('_')[-1]
        mp4_vname = mp4_filename[1]
        mp4_vid = mp4_filename[2][:-3]
        video_dir = os.path.join(self.mp4_dir, mp4_id, mp4_vname, mp4_vid + '.mp4')
        # print('============================\nvideo_dir : ' + video_dir, item)
        # ======================================================================
        video = cv2.VideoCapture(video_dir)
        if (video.isOpened() == False):
            print('Unable to open video file')
            exit(0)  # NOTE(review): exits with success status — consider raising instead
        # skip first several frames due to landmark extraction
        start_idx = (fls[0, 0]).astype(int)
        for _ in range(start_idx):
            ret, img_video = video.read()
        # save video and landmark in parallel
        frames = []
        for j in range(int(fls.shape[0])-2):
            ret, img_video = video.read()
            img_fl = np.ones(shape=(224, 224, 3)) * 255
            idx = fls[j, 0]
            fl = fls[j, 1:].astype(int)
            img_fl = vis_landmark_on_img(img_fl, np.reshape(fl, (68, 3)))
            # print(img_fl.shape, random_face.shape, img_video.shape)
            # 9-channel stack: landmark drawing + reference face + target frame.
            frame = np.concatenate((img_fl, random_face, img_video), axis=2)
            frame = cv2.resize(frame, (256, 256)) # 256 x 256 6
            frames.append(frame)
        frames = np.stack(frames, axis=0).astype(np.float32)/255.0 # N x 256 x 256 x 6
        image_in = frames[:, :, :, 0:6]
        image_out = frames[:, :, :, 6:9]
        # NHWC -> NCHW (also transposes H/W; harmless for square frames)
        image_in, image_out = np.swapaxes(image_in, 1, 3), np.swapaxes(image_out, 1, 3)
        return image_in, image_out

    def my_collate(self, batch):
        """Collate that drops None items (failed reads) before batching.

        Fix: in Python 3 ``filter`` returns a lazy iterator, but
        ``default_collate`` indexes its argument, so materialize a list first.
        """
        batch = list(filter(lambda x: x is not None, batch))
        return default_collate(batch)
class image_translation_preprocessed_dataset(data.Dataset):
    """Training set over the preprocessed (256x256) VoxCeleb2 videos.

    Same sampling scheme as image_translation_raw_dataset, but video paths
    come directly from the landmark filename and landmark rows are indexed
    at ``fps_scale`` rows per video frame.
    """

    def __init__(self, num_frames=16):
        # Dataset roots depend on the host machine (stargazer vs. gypsum cluster).
        if platform.release() == '4.4.0-83-generic':
            self.src_dir = r'/mnt/ntfs/Dataset/TalkingToon/VoxCeleb2_imagetranslation/raw_fl3d'
            self.mp4_dir = r'/mnt/ntfs/Dataset/VoxCeleb2/train_set/dev/mp4'
        else:
            self.src_dir = r'/mnt/nfs/scratch1/yangzhou/PreprocessedVox_imagetranslation/raw_fl3d' # first order
            self.mp4_dir = r'/mnt/nfs/scratch1/yangzhou/PreprocessedVox_mp4'
        self.fls_filenames = glob.glob1(self.src_dir, '*')
        # One extra frame so num_frames consecutive (in, out) pairs exist.
        self.num_random_frames = num_frames + 1
        # Landmark rows per video frame (landmarks sampled at a higher rate
        # than the video — presumably 62.5 vs 25 fps; confirm upstream).
        self.fps_scale = 2.5
        print(os.name, len(self.fls_filenames))

    def __len__(self):
        return len(self.fls_filenames)

    def __getitem__(self, item):
        """Return (image_in, image_out): N x 6 x 256 x 256 and N x 3 x 256 x 256 float32 in [0, 1]."""
        fls_filename = self.fls_filenames[item]
        # load landmark file (column 0: frame index, columns 1+: flattened 68x3)
        fls = np.loadtxt(os.path.join(self.src_dir, fls_filename))
        # # ================= preprocessed VOX version ================================
        # fls_filename embeds the mp4 name between a 10-char prefix and a 7-char suffix.
        video_dir = os.path.join(self.mp4_dir, fls_filename[10:-7]+'.mp4')
        # ======================================================================
        video = cv2.VideoCapture(video_dir)
        if (video.isOpened() == False):
            print('Unable to open video file')
            exit(0)  # NOTE(review): exits with success status — consider raising instead
        # skip first several frames due to landmark extraction
        start_idx = (fls[0, 0] // self.fps_scale).astype(int)
        for _ in range(start_idx):
            ret, img_video = video.read()
        # save video and landmark in parallel
        frames = []
        random_frame_indices = np.random.permutation(int(fls.shape[0]//self.fps_scale)-2)[0:self.num_random_frames]
        for j in range(int(fls.shape[0]//self.fps_scale)):
            ret, img_video = video.read()
            if(j in random_frame_indices):
                img_fl = np.ones(shape=(256, 256, 3)) * 255
                # Map video frame j back to its landmark row via fps_scale.
                idx = fls[int(j*self.fps_scale), 0]
                fl = fls[int(j*self.fps_scale), 1:].astype(int)
                img_fl = vis_landmark_on_img(img_fl, np.reshape(fl, (68, 3)))
                # Channel-stack landmark image and video frame: H x W x 6.
                frame = np.concatenate((img_fl, img_video), axis=2)
                frames.append(frame)
        frames = np.stack(frames, axis=0).astype(np.float32)/255.0 # N x 256 x 256 x 6
        image_in = np.concatenate([frames[0:-1, :, :, 0:3], frames[1:, :, :, 3:6]], axis=3)
        image_out = frames[0:-1, :, :, 3:6]
        # NHWC -> NCHW (also transposes H/W; harmless for square frames)
        image_in, image_out = np.swapaxes(image_in, 1, 3), np.swapaxes(image_out, 1, 3)
        return image_in, image_out

    def my_collate(self, batch):
        """Collate that drops None items (failed reads) before batching.

        Fix: in Python 3 ``filter`` returns a lazy iterator, but
        ``default_collate`` indexes its argument, so materialize a list first.
        """
        batch = list(filter(lambda x: x is not None, batch))
        return default_collate(batch)
class image_translation_preprocessed_test_dataset(data.Dataset):
    """Test-time variant over preprocessed VoxCeleb2 videos.

    Uses smoothed landmarks, all frames, and a reference face drawn from a
    neighboring item's video so identity differs from the driving video.
    """

    def __init__(self, num_frames=16):
        # Dataset roots depend on the host machine (stargazer vs. gypsum cluster).
        if platform.release() == '4.4.0-83-generic':
            self.src_dir = r'/mnt/ntfs/Dataset/TalkingToon/VoxCeleb2_imagetranslation/raw_fl3d'
            self.mp4_dir = r'/mnt/ntfs/Dataset/VoxCeleb2/train_set/dev/mp4'
        else:
            # self.src_dir = r'/mnt/nfs/scratch1/yangzhou/VoxCeleb2_imagetranslation/raw_fl3d'
            # self.mp4_dir = r'/mnt/nfs/work1/kalo/yangzhou/VoxCeleb2/train_set/dev/mp4'
            self.src_dir = r'/mnt/nfs/scratch1/yangzhou/PreprocessedVox_imagetranslation/raw_fl3d'
            self.mp4_dir = r'/mnt/nfs/scratch1/yangzhou/PreprocessedVox_mp4'
        self.fls_filenames = glob.glob1(self.src_dir, '*')
        self.num_random_frames = num_frames + 1
        # Landmark rows per video frame (see training dataset).
        self.fps_scale = 2.5
        print(os.name, len(self.fls_filenames))

    def __len__(self):
        return len(self.fls_filenames)

    def __getitem__(self, item):
        """Return (image_in, image_out): N x 6 x 256 x 256 and N x 3 x 256 x 256 float32 in [0, 1]."""
        fls_filename = self.fls_filenames[item]
        # load landmark file (column 0: frame index, columns 1+: flattened 68x3)
        fls = np.loadtxt(os.path.join(self.src_dir, fls_filename))
        # Temporally smooth the landmark tracks for a stable test rendering.
        from scipy.signal import savgol_filter
        fls = savgol_filter(fls, 11, 3, axis=0)
        # load random face: first frame of the previous item's video.
        random_fls_filename = self.fls_filenames[max(item-1, 0)]
        # random_fls_filename = self.fls_filenames[max(item-1, 0)]
        random_video_dir = os.path.join(self.mp4_dir, random_fls_filename[10:-7] + '.mp4')
        random_video = cv2.VideoCapture(random_video_dir)
        if (random_video.isOpened() == False):
            print('Unable to open video file')
            exit(0)  # NOTE(review): exits with success status — consider raising instead
        _, random_face = random_video.read()
        # # ================= preprocessed VOX version ================================
        video_dir = os.path.join(self.mp4_dir, fls_filename[10:-7]+'.mp4')
        # ======================================================================
        video = cv2.VideoCapture(video_dir)
        if (video.isOpened() == False):
            print('Unable to open video file')
            exit(0)  # NOTE(review): exits with success status — consider raising instead
        # skip first several frames due to landmark extraction
        start_idx = (fls[0, 0] // self.fps_scale).astype(int)
        for _ in range(start_idx):
            ret, img_video = video.read()
        # save video and landmark in parallel
        frames = []
        for j in range(int(fls.shape[0]//self.fps_scale)):
            ret, img_video = video.read()
            # img_fl = np.ones(shape=(224, 224, 3)) * 255
            img_fl = np.ones(shape=(256, 256, 3)) * 255
            # Map video frame j back to its landmark row via fps_scale.
            idx = fls[int(j*self.fps_scale), 0]
            fl = fls[int(j*self.fps_scale), 1:].astype(int)
            img_fl = vis_landmark_on_img(img_fl, np.reshape(fl, (68, 3)))
            # 9-channel stack: landmark drawing + reference face + target frame.
            frame = np.concatenate((img_fl, random_face, img_video), axis=2)
            # frame = cv2.resize(frame, (256, 256)) # 256 x 256 6
            frames.append(frame)
        frames = np.stack(frames, axis=0).astype(np.float32)/255.0 # N x 256 x 256 x 9
        image_in = frames[:, :, :, 0:6]
        image_out = frames[:, :, :, 6:9]
        # NHWC -> NCHW (also transposes H/W; harmless for square frames)
        image_in, image_out = np.swapaxes(image_in, 1, 3), np.swapaxes(image_out, 1, 3)
        return image_in, image_out

    def my_collate(self, batch):
        """Collate that drops None items (failed reads) before batching.

        Fix: in Python 3 ``filter`` returns a lazy iterator, but
        ``default_collate`` indexes its argument, so materialize a list first.
        """
        batch = list(filter(lambda x: x is not None, batch))
        return default_collate(batch)
class image_translation_raw98_dataset(data.Dataset):
    """
    Online landmark extraction with AWings
    Landmark setting: 98 landmarks

    __getitem__ only returns frame pairs; the 98-point landmark alignment is
    computed outside in train_pass().
    """

    def __init__(self, num_frames=1):
        # Dataset roots depend on the host machine (stargazer vs. gypsum cluster).
        if platform.release() == '4.4.0-83-generic': # stargazer
            self.src_dir = r'/mnt/ntfs/Dataset/TalkingToon/VoxCeleb2_imagetranslation'
            self.mp4_dir = r'/mnt/ntfs/Dataset/VoxCeleb2/train_set/dev/mp4'
        else:
            self.src_dir = r'/mnt/nfs/scratch1/yangzhou/VoxCeleb2_compressed_imagetranslation'
            self.mp4_dir = r'/mnt/nfs/work1/kalo/yangzhou/VoxCeleb2/train_set/dev/mp4'
        # self.fls_filenames = glob.glob1(self.src_dir, '*')
        # Filenames come from a precomputed index (second column of the file).
        self.fls_filenames = np.loadtxt(os.path.join(self.src_dir, 'filename_index.txt'), dtype=str)[:, 1]
        self.num_random_frames = num_frames + 1
        print(os.name, self.fls_filenames.shape)

    def __len__(self):
        return self.fls_filenames.shape[0]

    def __getitem__(self, item):
        """
        Get landmark alignment outside in train_pass()

        Returns (image_in, image_out), each N x 3 x 256 x 256 float32 in [0, 1].
        """
        # Retry up to 5 times, falling through to the next index when a video
        # fails to open. If all 5 fail, the last (unopened) capture is used
        # and the np.stack below will raise — acceptable for training data.
        for i in range(5):
            fls_filename = self.fls_filenames[(item+i)%self.fls_filenames.shape[0]]
            # load mp4 file
            # ================= raw VOX version ================================
            mp4_filename = fls_filename[:-4].split('_x_')
            mp4_id = mp4_filename[0].split('_')[-1]
            mp4_vname = mp4_filename[1]
            mp4_vid = mp4_filename[2]
            video_dir = os.path.join(self.mp4_dir, mp4_id, mp4_vname, mp4_vid + '.mp4')
            # print('============================\nvideo_dir : ' + video_dir, item)
            # ======================================================================
            video = cv2.VideoCapture(video_dir)
            if (video.isOpened() == False):
                print('Unable to open video file')
            else:
                break
        length = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        # save video and landmark in parallel
        frames = []
        random_frame_indices = np.random.permutation(length-2)[0:self.num_random_frames]
        for j in range(length):
            ret, img = video.read()
            if(j in random_frame_indices):
                img_video = cv2.resize(img, (256, 256))
                frames.append(img_video.transpose((2, 0, 1)))  # HWC -> CHW
        frames = np.stack(frames, axis=0).astype(np.float32)/255.0
        image_in = frames[1:, :, :]
        image_out = frames[0:-1, :, :] # N x 3 x 256 x 256
        return image_in, image_out

    def __getitem_along_with_fa__(self, item):
        """
        Online get landmark alignment (deprecated)
        (can only run under num_works=0)

        NOTE(review): requires self.model and self.device to be assigned
        externally before use — neither is set in __init__.
        """
        fls_filename = self.fls_filenames[item]
        # load mp4 file
        # ================= raw VOX version ================================
        mp4_filename = fls_filename[:-4].split('_x_')
        mp4_id = mp4_filename[0].split('_')[-1]
        mp4_vname = mp4_filename[1]
        mp4_vid = mp4_filename[2]
        video_dir = os.path.join(self.mp4_dir, mp4_id, mp4_vname, mp4_vid + '.mp4')
        # print('============================\nvideo_dir : ' + video_dir, item)
        # ======================================================================
        video = cv2.VideoCapture(video_dir)
        if (video.isOpened() == False):
            print('Unable to open video file')
            exit(0)  # NOTE(review): exits with success status — consider raising instead
        length = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        # save video and landmark in parallel
        frames = []
        random_frame_indices = np.random.permutation(length-2)[0:self.num_random_frames]
        for j in range(length):
            ret, img = video.read()
            if(j in random_frame_indices):
                # online landmark
                img_video = cv2.resize(img, (256, 256))
                img = img_video.transpose((2, 0, 1)) / 255.0
                inputs = torch.tensor(img, dtype=torch.float32, requires_grad=False).unsqueeze(0).to(self.device)
                with torch.no_grad():
                    outputs, boundary_channels = self.model(inputs)
                # Drop the last (boundary) channel and decode heatmaps to coordinates.
                pred_heatmap = outputs[-1][:, :-1, :, :][0].detach().cpu()
                pred_landmarks, _ = get_preds_fromhm(pred_heatmap.unsqueeze(0))
                # Heatmaps are 1/4 resolution, hence the x4 upscale.
                pred_landmarks = pred_landmarks.squeeze().numpy() * 4
                img_fl = np.ones(shape=(256, 256, 3)) * 255
                img_fl = vis_landmark_on_img98(img_fl * 255.0, pred_landmarks) # 98x2
                frame = np.concatenate((img_fl, img_video), axis=2)
                frames.append(frame)
        frames = np.stack(frames, axis=0).astype(np.float32)/255.0 # N x 256 x 256 x 6
        image_in = np.concatenate([frames[0:-1, :, :, 0:3], frames[1:, :, :, 3:6]], axis=3)
        image_out = frames[0:-1, :, :, 3:6]
        image_in, image_out = np.swapaxes(image_in, 1, 3), np.swapaxes(image_out, 1, 3)
        return image_in, image_out

    def my_collate(self, batch):
        """Collate that drops None items (failed reads) before batching.

        Fix: in Python 3 ``filter`` returns a lazy iterator, but
        ``default_collate`` indexes its argument, so materialize a list first.
        """
        batch = list(filter(lambda x: x is not None, batch))
        return default_collate(batch)
class image_translation_preprocessed98_dataset(data.Dataset):
    """98-landmark training set over preprocessed VoxCeleb2 videos.

    Returns raw consecutive frame pairs only; landmark rendering happens
    outside the dataset.
    """

    def __init__(self, num_frames=16):
        # Dataset roots depend on the host machine (stargazer vs. gypsum cluster).
        if platform.release() == '4.4.0-83-generic':
            self.src_dir = r'/mnt/ntfs/Dataset/TalkingToon/VoxCeleb2_imagetranslation'
            self.mp4_dir = r'/mnt/ntfs/Dataset/VoxCeleb2/train_set/dev/mp4'
        else:
            self.src_dir = r'/mnt/nfs/scratch1/yangzhou/PreprocessedVox_imagetranslation/raw_fl3d' # first order
            self.mp4_dir = r'/mnt/nfs/scratch1/yangzhou/PreprocessedVox_mp4'
        self.fls_filenames = glob.glob1(self.src_dir, '*')
        # One extra frame so num_frames consecutive (in, out) pairs exist.
        self.num_random_frames = num_frames + 1
        print(os.name, len(self.fls_filenames))

    def __len__(self):
        return len(self.fls_filenames)

    def __getitem__(self, item):
        """Return (image_in, image_out), each N x 3 x 256 x 256 float32 in [0, 1]."""
        fls_filename = self.fls_filenames[item]
        # # ================= preprocessed VOX version ================================
        # fls_filename embeds the mp4 name between a 10-char prefix and a 7-char suffix.
        video_dir = os.path.join(self.mp4_dir, fls_filename[10:-7]+'.mp4')
        # ======================================================================
        video = cv2.VideoCapture(video_dir)
        if (video.isOpened() == False):
            print('Unable to open video file')
            exit(0)  # NOTE(review): exits with success status — consider raising instead
        length = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        # save video and landmark in parallel
        frames = []
        random_frame_indices = np.random.permutation(length-2)[0:self.num_random_frames]
        for j in range(length):
            ret, img_video = video.read()
            if(j in random_frame_indices):
                img_video = cv2.resize(img_video, (256, 256))
                frames.append(img_video.transpose((2, 0, 1)))  # HWC -> CHW
        frames = np.stack(frames, axis=0).astype(np.float32)/255.0 # N x 256 x 256 x 6
        image_in = frames[1:, :, :]
        image_out = frames[0:-1, :, :] # N x 3 x 256 x 256
        return image_in, image_out

    def my_collate(self, batch):
        """Collate that drops None items (failed reads) before batching.

        Fix: in Python 3 ``filter`` returns a lazy iterator, but
        ``default_collate`` indexes its argument, so materialize a list first.
        """
        batch = list(filter(lambda x: x is not None, batch))
        return default_collate(batch)
class image_translation_raw98_test_dataset(data.Dataset):
    """98-landmark test set over raw VoxCeleb2 videos.

    Pairs every frame of the driving video with a fixed reference face taken
    from a different item's video.
    """

    def __init__(self, num_frames=16):
        # Dataset roots depend on the host machine (stargazer vs. gypsum cluster).
        if platform.release() == '4.4.0-83-generic':
            self.src_dir = r'/mnt/ntfs/Dataset/TalkingToon/VoxCeleb2_imagetranslation'
            self.mp4_dir = r'/mnt/ntfs/Dataset/VoxCeleb2/train_set/dev/mp4'
        else:
            self.src_dir = r'/mnt/nfs/scratch1/yangzhou/VoxCeleb2_compressed_imagetranslation'
            self.mp4_dir = r'/mnt/nfs/work1/kalo/yangzhou/VoxCeleb2/train_set/dev/mp4'
        # self.fls_filenames = glob.glob1(self.src_dir, '*')
        # Filenames come from a precomputed index (second column of the file).
        self.fls_filenames = np.loadtxt(os.path.join(self.src_dir, 'filename_index.txt'), dtype=str)[:, 1]
        self.num_random_frames = num_frames + 1
        print(os.name, len(self.fls_filenames))

    def __len__(self):
        return len(self.fls_filenames)

    def __getitem__(self, item):
        """Return (image_in, image_out), each N x 3 x 256 x 256 float32 in [0, 1]."""
        fls_filename = self.fls_filenames[item]
        # load random face: first frame of a video 10 items back, so the
        # reference identity differs from the driving video.
        random_fls_filename = self.fls_filenames[max(item - 10, 0)]
        mp4_filename = random_fls_filename[:-4].split('_x_')
        mp4_id = mp4_filename[0].split('_')[-1]
        mp4_vname = mp4_filename[1]
        mp4_vid = mp4_filename[2]
        random_video_dir = os.path.join(self.mp4_dir, mp4_id, mp4_vname, mp4_vid + '.mp4')
        print('============================\nvideo_dir : ' + random_video_dir, item)
        random_video = cv2.VideoCapture(random_video_dir)
        if (random_video.isOpened() == False):
            print('Unable to open video file')
            exit(0)  # NOTE(review): exits with success status — consider raising instead
        _, random_face = random_video.read()
        random_face = cv2.resize(random_face, (256, 256))
        # load mp4 file
        # ================= raw VOX version ================================
        mp4_filename = fls_filename[:-4].split('_x_')
        mp4_id = mp4_filename[0].split('_')[-1]
        mp4_vname = mp4_filename[1]
        mp4_vid = mp4_filename[2]
        video_dir = os.path.join(self.mp4_dir, mp4_id, mp4_vname, mp4_vid + '.mp4')
        # print('============================\nvideo_dir : ' + video_dir, item)
        # ======================================================================
        video = cv2.VideoCapture(video_dir)
        if (video.isOpened() == False):
            print('Unable to open video file')
            exit(0)  # NOTE(review): exits with success status — consider raising instead
        length = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        # save video and landmark in parallel
        frames = []
        for j in range(length):
            ret, img_video = video.read()
            img_video = cv2.resize(img_video, (256, 256))
            # Channel-stack reference face and driving frame: H x W x 6.
            frame = np.concatenate((random_face, img_video), axis=2)
            frames.append(frame.transpose((2, 0, 1)))  # HWC -> CHW
        frames = np.stack(frames, axis=0).astype(np.float32) / 255.0 # N x 256 x 256 x 9
        image_in = frames[:, 0:3]
        image_out = frames[:, 3:6]
        return image_in, image_out

    def my_collate(self, batch):
        """Collate that drops None items (failed reads) before batching.

        Fix: in Python 3 ``filter`` returns a lazy iterator, but
        ``default_collate`` indexes its argument, so materialize a list first.
        """
        batch = list(filter(lambda x: x is not None, batch))
        return default_collate(batch)
class image_translation_preprocessed98_test_dataset(data.Dataset):
    """98-landmark test set over preprocessed VoxCeleb2 videos.

    Pairs every frame of the driving video with a fixed reference face taken
    from a different item's video.
    """

    def __init__(self, num_frames=16):
        # Dataset roots depend on the host machine (stargazer vs. gypsum cluster).
        if platform.release() == '4.4.0-83-generic':
            self.src_dir = r'/mnt/ntfs/Dataset/TalkingToon/VoxCeleb2_imagetranslation/raw_fl3d'
            self.mp4_dir = r'/mnt/ntfs/Dataset/VoxCeleb2/train_set/dev/mp4'
        else:
            # self.src_dir = r'/mnt/nfs/scratch1/yangzhou/VoxCeleb2_imagetranslation/raw_fl3d'
            # self.mp4_dir = r'/mnt/nfs/work1/kalo/yangzhou/VoxCeleb2/train_set/dev/mp4'
            self.src_dir = r'/mnt/nfs/scratch1/yangzhou/PreprocessedVox_imagetranslation/raw_fl3d'
            self.mp4_dir = r'/mnt/nfs/scratch1/yangzhou/PreprocessedVox_mp4'
        self.fls_filenames = glob.glob1(self.src_dir, '*')
        self.num_random_frames = num_frames + 1
        print(os.name, len(self.fls_filenames))

    def __len__(self):
        return len(self.fls_filenames)

    def __getitem__(self, item):
        """Return (image_in, image_out), each N x 3 x 256 x 256 float32 in [0, 1]."""
        fls_filename = self.fls_filenames[item]
        # load random face: first frame of a video 10 items back, so the
        # reference identity differs from the driving video.
        random_fls_filename = self.fls_filenames[max(item-10, 0)]
        # random_fls_filename = self.fls_filenames[max(item-1, 0)]
        random_video_dir = os.path.join(self.mp4_dir, random_fls_filename[10:-7] + '.mp4')
        random_video = cv2.VideoCapture(random_video_dir)
        if (random_video.isOpened() == False):
            print('Unable to open video file')
            exit(0)  # NOTE(review): exits with success status — consider raising instead
        _, random_face = random_video.read()
        # # ================= preprocessed VOX version ================================
        video_dir = os.path.join(self.mp4_dir, fls_filename[10:-7]+'.mp4')
        # ======================================================================
        video = cv2.VideoCapture(video_dir)
        if (video.isOpened() == False):
            print('Unable to open video file')
            exit(0)  # NOTE(review): exits with success status — consider raising instead
        length = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        # save video and landmark in parallel
        frames = []
        for j in range(length):
            ret, img_video = video.read()
            img_video = cv2.resize(img_video, (256, 256))
            # Channel-stack reference face and driving frame: H x W x 6.
            # (random_face is not resized here — preprocessed videos are
            # presumably already 256x256; confirm against the data.)
            frame = np.concatenate((random_face, img_video), axis=2)
            frames.append(frame.transpose((2, 0, 1)))  # HWC -> CHW
        frames = np.stack(frames, axis=0).astype(np.float32)/255.0 # N x 256 x 256 x 9
        image_in = frames[:, 0:3]
        image_out = frames[:, 3:6]
        return image_in, image_out

    def my_collate(self, batch):
        """Collate that drops None items (failed reads) before batching.

        Fix: in Python 3 ``filter`` returns a lazy iterator, but
        ``default_collate`` indexes its argument, so materialize a list first.
        """
        batch = list(filter(lambda x: x is not None, batch))
        return default_collate(batch)
class image_translation_raw98_with_audio_dataset(data.Dataset):
    """
    Online landmark extraction with AWings
    Landmark setting: 98 landmarks

    Also extracts the audio track via ffmpeg and returns a log-STFT window
    (resized to 256x256) centered on each sampled frame.
    """

    def __init__(self, num_frames=1):
        # Dataset roots depend on the host machine (stargazer vs. gypsum cluster).
        if platform.release() == '4.4.0-83-generic': # stargazer
            self.src_dir = r'/mnt/ntfs/Dataset/TalkingToon/VoxCeleb2_imagetranslation'
            self.mp4_dir = r'/mnt/ntfs/Dataset/VoxCeleb2/train_set/dev/mp4'
        else:
            self.src_dir = r'/mnt/nfs/scratch1/yangzhou/VoxCeleb2_compressed_imagetranslation'
            self.mp4_dir = r'/mnt/nfs/work1/kalo/yangzhou/VoxCeleb2/train_set/dev/mp4'
        # self.fls_filenames = glob.glob1(self.src_dir, '*')
        # Filenames come from a precomputed index (second column of the file).
        self.fls_filenames = np.loadtxt(os.path.join(self.src_dir, 'filename_index.txt'), dtype=str)[:, 1]
        self.num_random_frames = num_frames + 1
        print(os.name, self.fls_filenames.shape)

    def __len__(self):
        return self.fls_filenames.shape[0]

    def __getitem__(self, item):
        """
        Get landmark alignment outside in train_pass()

        Returns (image_in, image_out, audio_in): N x 3 x 256 x 256 frame
        pairs plus N x 1 x 256 x 256 log-STFT windows.
        """
        # Retry up to 5 times, falling through to the next index when a
        # video fails to open.
        for i in range(5):
            fls_filename = self.fls_filenames[(item+i)%self.fls_filenames.shape[0]]
            # load mp4 file
            # ================= raw VOX version ================================
            mp4_filename = fls_filename[:-4].split('_x_')
            mp4_id = mp4_filename[0].split('_')[-1]
            mp4_vname = mp4_filename[1]
            mp4_vid = mp4_filename[2]
            video_dir = os.path.join(self.mp4_dir, mp4_id, mp4_vname, mp4_vid + '.mp4')
            # print('============================\nvideo_dir : ' + video_dir, item)
            # ======================================================================
            video = cv2.VideoCapture(video_dir)
            if (video.isOpened() == False):
                print('Unable to open video file')
            else:
                break
        length = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        # save video and landmark in parallel
        frames = []
        # Sample away from both ends, then shift by 5 so every sampled frame
        # has a full +/- 5-frame audio context window.
        random_frame_indices = np.random.permutation(max(1, length-12))[0:self.num_random_frames]
        random_frame_indices = [idx + 5 for idx in random_frame_indices]
        for j in range(length):
            ret, img = video.read()
            if(j in random_frame_indices):
                img_video = cv2.resize(img, (256, 256))
                frames.append(img_video.transpose((2, 0, 1)))  # HWC -> CHW
        frames = np.stack(frames, axis=0).astype(np.float32)/255.0
        image_in = frames[1:, :, :]
        image_out = frames[0:-1, :, :] # N x 3 x 256 x 256
        # audio: extract a mono 16 kHz wav next to the mp4, then delete it.
        os.system('ffmpeg -y -loglevel error -i {} -vn -ar 16000 -ac 1 {}'.format(
            video_dir, video_dir.replace('.mp4', '.wav')
        ))
        sample_rate, samples = wav.read(video_dir.replace('.mp4', '.wav'))
        assert (sample_rate == 16000)
        if (len(samples.shape) > 1):
            samples = samples[:, 0] # pick mono
        # 1 frame = 1/25 * 16k = 640 samples => windowsize=320, overlap=160
        # 80 overlap => 200 / 1 sec, 8 / 1 frame
        f, t, Zxx = stft(samples, fs=sample_rate, nperseg=640, noverlap=560)
        # Log power spectrum, normalized by its max (epsilon avoids log(0)).
        stft_abs = np.log(np.abs(Zxx) ** 2 + 1e-10)
        stft_abs = stft_abs / np.max(stft_abs)
        os.remove(video_dir.replace('.mp4', '.wav'))
        # we want 0.2s before, 5 frames, 40 dims
        # and 0.2s after (may remove later)
        audio_in = []
        # Renamed from 'item' to avoid shadowing the method parameter.
        for frame_idx in random_frame_indices:
            # 8 STFT columns per video frame; take +/- 5 frames around frame_idx.
            sel_audio_clip = stft_abs[:, (frame_idx-5)*8:(frame_idx+5)*8]
            assert sel_audio_clip.shape[1] == 80
            audio_in.append(np.expand_dims(cv2.resize(sel_audio_clip, (256, 256)), axis=0))
        audio_in = np.stack(audio_in[0:-1], axis=0).astype(np.float32)
        # image_in = np.concatenate([image_in, audio_in], axis=1)
        return image_in, image_out, audio_in

    def my_collate(self, batch):
        """Collate that drops None items (failed reads) before batching.

        Fix: in Python 3 ``filter`` returns a lazy iterator, but
        ``default_collate`` indexes its argument, so materialize a list first.
        """
        batch = list(filter(lambda x: x is not None, batch))
        return default_collate(batch)
class image_translation_raw98_with_audio_test_dataset(data.Dataset):
    """
    Online landmark extraction with AWings
    Landmark setting: 98 landmarks

    Test-time variant: all frames (minus a 5-frame margin at each end) are
    paired with a fixed reference face and per-frame audio STFT windows.
    """

    def __init__(self, num_frames=1):
        # Dataset roots depend on the host machine (stargazer vs. gypsum cluster).
        if platform.release() == '4.4.0-83-generic': # stargazer
            self.src_dir = r'/mnt/ntfs/Dataset/TalkingToon/VoxCeleb2_imagetranslation'
            self.mp4_dir = r'/mnt/ntfs/Dataset/VoxCeleb2/train_set/dev/mp4'
        else:
            self.src_dir = r'/mnt/nfs/scratch1/yangzhou/VoxCeleb2_compressed_imagetranslation'
            self.mp4_dir = r'/mnt/nfs/work1/kalo/yangzhou/VoxCeleb2/train_set/dev/mp4'
        # self.fls_filenames = glob.glob1(self.src_dir, '*')
        # Filenames come from a precomputed index (second column of the file).
        self.fls_filenames = np.loadtxt(os.path.join(self.src_dir, 'filename_index.txt'), dtype=str)[:, 1]
        self.num_random_frames = num_frames + 1
        print(os.name, self.fls_filenames.shape)

    def __len__(self):
        return self.fls_filenames.shape[0]

    def __getitem__(self, item):
        """
        Get landmark alignment outside in train_pass()

        Returns (image_in, image_out, audio_in): N x 3 x 256 x 256 frame
        pairs plus N x 1 x 256 x 256 log-STFT windows.
        """
        # load random face: first frame of a video 10 items back, so the
        # reference identity differs from the driving video.
        random_fls_filename = self.fls_filenames[max(item - 10, 0)]
        mp4_filename = random_fls_filename[:-4].split('_x_')
        mp4_id = mp4_filename[0].split('_')[-1]
        mp4_vname = mp4_filename[1]
        mp4_vid = mp4_filename[2]
        random_video_dir = os.path.join(self.mp4_dir, mp4_id, mp4_vname, mp4_vid + '.mp4')
        print('============================\nvideo_dir : ' + random_video_dir, item)
        random_video = cv2.VideoCapture(random_video_dir)
        if (random_video.isOpened() == False):
            print('Unable to open video file')
            exit(0)  # NOTE(review): exits with success status — consider raising instead
        _, random_face = random_video.read()
        random_face = cv2.resize(random_face, (256, 256))
        fls_filename = self.fls_filenames[item]
        # load mp4 file
        # ================= raw VOX version ================================
        mp4_filename = fls_filename[:-4].split('_x_')
        mp4_id = mp4_filename[0].split('_')[-1]
        mp4_vname = mp4_filename[1]
        mp4_vid = mp4_filename[2]
        video_dir = os.path.join(self.mp4_dir, mp4_id, mp4_vname, mp4_vid + '.mp4')
        # print('============================\nvideo_dir : ' + video_dir, item)
        # ======================================================================
        video = cv2.VideoCapture(video_dir)
        if (video.isOpened() == False):
            print('Unable to open video file')
            # Fix: early-exit was missing here (every sibling class has it);
            # without it execution continued with an unopened capture and
            # crashed later on video.read()/np.stack.
            exit(0)  # NOTE(review): exits with success status — consider raising instead
        length = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        # save video and landmark in parallel
        frames = []
        # Skip 5 frames at each end so every frame has a full audio window.
        for j in range(5, length-5):
            ret, img_video = video.read()
            img_video = cv2.resize(img_video, (256, 256))
            # Channel-stack reference face and driving frame: H x W x 6.
            frame = np.concatenate((random_face, img_video), axis=2)
            frames.append(frame.transpose((2, 0, 1)))  # HWC -> CHW
        frames = np.stack(frames, axis=0).astype(np.float32) / 255.0 # N x 256 x 256 x 9
        image_in = frames[:, 0:3]
        image_out = frames[:, 3:6]
        # audio: extract a mono 16 kHz wav next to the mp4, then delete it.
        os.system('ffmpeg -y -loglevel error -i {} -vn -ar 16000 -ac 1 {}'.format(
            video_dir, video_dir.replace('.mp4', '.wav')
        ))
        sample_rate, samples = wav.read(video_dir.replace('.mp4', '.wav'))
        assert (sample_rate == 16000)
        if (len(samples.shape) > 1):
            samples = samples[:, 0] # pick mono
        # 1 frame = 1/25 * 16k = 640 samples => windowsize=320, overlap=160
        # 80 overlap => 200 / 1 sec, 8 / 1 frame
        f, t, Zxx = stft(samples, fs=sample_rate, nperseg=640, noverlap=560)
        # Log power spectrum, normalized by its max (epsilon avoids log(0)).
        stft_abs = np.log(np.abs(Zxx) ** 2 + 1e-10)
        stft_abs = stft_abs / np.max(stft_abs)
        os.remove(video_dir.replace('.mp4', '.wav'))
        # we want 0.2s before, 5 frames, 40 dims
        # and 0.2s after (may remove later)
        audio_in = []
        # Renamed from 'item' to avoid shadowing the method parameter.
        for frame_idx in range(5, length-5):
            # 8 STFT columns per video frame; take +/- 5 frames around frame_idx.
            sel_audio_clip = stft_abs[:, (frame_idx-5)*8:(frame_idx+5)*8]
            assert sel_audio_clip.shape[1] == 80
            audio_in.append(np.expand_dims(cv2.resize(sel_audio_clip, (256, 256)), axis=0))
        audio_in = np.stack(audio_in, axis=0).astype(np.float32)
        # image_in = np.concatenate([image_in, audio_in], axis=1)
        return image_in, image_out, audio_in

    def my_collate(self, batch):
        """Collate that drops None items (failed reads) before batching.

        Fix: in Python 3 ``filter`` returns a lazy iterator, but
        ``default_collate`` indexes its argument, so materialize a list first.
        """
        batch = list(filter(lambda x: x is not None, batch))
        return default_collate(batch)
if __name__ == '__main__':
    # Smoke test: iterate the raw dataset and print each batch's shapes.
    # Requires the hard-coded cluster data paths to exist.
    d = image_translation_raw_dataset()
    d_loader = torch.utils.data.DataLoader(d, batch_size=4, shuffle=True)
    print(len(d))
    for i, batch in enumerate(d_loader):
        print(i, batch[0].shape, batch[1].shape)