import math
import tempfile
import warnings
from pathlib import Path
import cv2
import librosa
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
from pydantic import BaseModel
from .diff_talking_head import DiffTalkingHead
from .utils import NullableArgs, coef_dict_to_vertices, get_coef_dict
from .utils.media import combine_video_and_audio, convert_video, reencode_audio
warnings.filterwarnings('ignore', message='PySoundFile failed. Trying audioread instead.')
class DiffPoseTalkConfig(BaseModel):
no_context_audio_feat: bool = False
model_path: str = "pretrained_models/diffposetalk/iter_0110000.pt" # DPT/head-SA-hubert-WM
coef_stats: str = "pretrained_models/diffposetalk/stats_train.npz"
style_path: str = "pretrained_models/diffposetalk/style/L4H4-T0.1-BS32/iter_0034000/normal.npy"
dynamic_threshold_ratio: float = 0.99
dynamic_threshold_min: float = 1.0
dynamic_threshold_max: float = 4.0
scale_audio: float = 1.15
scale_style: float = 3.0
class DiffPoseTalk:
def __init__(self, config: DiffPoseTalkConfig = DiffPoseTalkConfig(), device="cuda"):
self.cfg = config
self.device = device
self.no_context_audio_feat = self.cfg.no_context_audio_feat
model_data = torch.load(self.cfg.model_path, map_location=self.device)
self.model_args = NullableArgs(model_data['args'])
self.model = DiffTalkingHead(self.model_args, self.device)
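        # The positional-encoding buffer is removed from the checkpoint and the state dict is
        # loaded non-strictly, so the model keeps the encoding built for the current configuration.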
model_data['model'].pop('denoising_net.TE.pe')
self.model.load_state_dict(model_data['model'], strict=False)
self.model.to(self.device)
self.model.eval()
self.use_indicator = self.model_args.use_indicator
self.rot_repr = self.model_args.rot_repr
self.predict_head_pose = not self.model_args.no_head_pose
if self.model.use_style:
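            # Keep every other component of the last three parts of the style-encoder checkpoint
            # path (suffix stripped), e.g. ".../<exp_name>/<subdir>/<iter>.pt" -> "<exp_name>/<iter>";
            # used below to resolve relative style-feature paths.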
style_dir = Path(self.model_args.style_enc_ckpt)
style_dir = Path(*style_dir.with_suffix('').parts[-3::2])
self.style_dir = style_dir
# sequence
self.n_motions = self.model_args.n_motions
self.n_prev_motions = self.model_args.n_prev_motions
self.fps = self.model_args.fps
        self.audio_unit = 16000. / self.fps  # number of audio samples per video frame
self.n_audio_samples = round(self.audio_unit * self.n_motions)
self.pad_mode = self.model_args.pad_mode
self.coef_stats = dict(np.load(self.cfg.coef_stats))
self.coef_stats = {k: torch.from_numpy(v).to(self.device) for k, v in self.coef_stats.items()}
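        # Dynamic thresholding parameters for the diffusion sampler: (quantile ratio, min, max);
        # a non-positive ratio disables it.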
if self.cfg.dynamic_threshold_ratio > 0:
self.dynamic_threshold = (self.cfg.dynamic_threshold_ratio, self.cfg.dynamic_threshold_min,
self.cfg.dynamic_threshold_max)
else:
self.dynamic_threshold = None
def infer_from_file(self, audio_path, shape_coef):
n_repetitions = 1
cfg_mode = None
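        # One classifier-free guidance scale per guiding condition, in the model's own order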
cfg_cond = self.model.guiding_conditions
cfg_scale = []
for cond in cfg_cond:
if cond == 'audio':
cfg_scale.append(self.cfg.scale_audio)
elif cond == 'style':
cfg_scale.append(self.cfg.scale_style)
coef_dict = self.infer_coeffs(audio_path, shape_coef, self.cfg.style_path, n_repetitions,
cfg_mode, cfg_cond, cfg_scale, include_shape=True)
return coef_dict
@torch.no_grad()
def infer_coeffs(self, audio, shape_coef, style_feat=None, n_repetitions=1,
cfg_mode=None, cfg_cond=None, cfg_scale=1.15, include_shape=False):
# Returns dict[str, (n_repetitions, L, *)]
# Step 1: Preprocessing
# Preprocess audio
if isinstance(audio, (str, Path)):
audio, _ = librosa.load(audio, sr=16000, mono=True)
if isinstance(audio, np.ndarray):
audio = torch.from_numpy(audio).to(self.device)
        assert audio.ndim == 1, 'Audio must be a 1D tensor.'
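        # Normalize the waveform to zero mean and unit variance (epsilon avoids division by zero)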
audio_mean, audio_std = torch.mean(audio), torch.std(audio)
audio = (audio - audio_mean) / (audio_std + 1e-5)
# Preprocess shape coefficient
if isinstance(shape_coef, (str, Path)):
shape_coef = np.load(shape_coef)
if not isinstance(shape_coef, np.ndarray):
shape_coef = shape_coef['shape']
if isinstance(shape_coef, np.ndarray):
shape_coef = torch.from_numpy(shape_coef).float().to(self.device)
        assert shape_coef.ndim <= 2, 'Shape coefficient must be a 1D or 2D tensor.'
if shape_coef.ndim > 1:
# use the first frame as the shape coefficient
shape_coef = shape_coef[0]
original_shape_coef = shape_coef.clone()
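        # Normalize with the training-set statistics; the unnormalized copy above is kept for the output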
if self.coef_stats is not None:
shape_coef = (shape_coef - self.coef_stats['shape_mean']) / self.coef_stats['shape_std']
shape_coef = shape_coef.unsqueeze(0).expand(n_repetitions, -1)
# Preprocess style feature if given
if style_feat is not None:
assert self.model.use_style
if isinstance(style_feat, (str, Path)):
style_feat = Path(style_feat)
if not style_feat.exists() and not style_feat.is_absolute():
style_feat = style_feat.parent / self.style_dir / style_feat.name
style_feat = np.load(style_feat)
if not isinstance(style_feat, np.ndarray):
style_feat = style_feat['style']
if isinstance(style_feat, np.ndarray):
style_feat = torch.from_numpy(style_feat).float().to(self.device)
            assert style_feat.ndim == 1, 'Style feature must be a 1D tensor.'
style_feat = style_feat.unsqueeze(0).expand(n_repetitions, -1)
# Step 2: Predict motion coef
        # Divide the audio into synthesis windows of `self.n_motions` frames and synthesize each in turn
clip_len = int(len(audio) / 16000 * self.fps)
stride = self.n_motions
if clip_len <= self.n_motions:
n_subdivision = 1
else:
n_subdivision = math.ceil(clip_len / stride)
# Prepare audio input
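        # Pad the audio so it covers an integer number of synthesis windows; remember how many
        # frames of padding were added so they can be masked and trimmed later.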
n_padding_audio_samples = self.n_audio_samples * n_subdivision - len(audio)
n_padding_frames = math.ceil(n_padding_audio_samples / self.audio_unit)
if n_padding_audio_samples > 0:
if self.pad_mode == 'zero':
padding_value = 0
elif self.pad_mode == 'replicate':
padding_value = audio[-1]
else:
raise ValueError(f'Unknown pad mode: {self.pad_mode}')
audio = F.pad(audio, (0, n_padding_audio_samples), value=padding_value)
if not self.no_context_audio_feat:
audio_feat = self.model.extract_audio_feature(audio.unsqueeze(0), self.n_motions * n_subdivision)
        # Generate `self.n_motions` new frames at a time, and use the last `self.n_prev_motions`
        # frames from the previous window as the initial motion condition
coef_list = []
for i in range(0, n_subdivision):
start_idx = i * stride
end_idx = start_idx + self.n_motions
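            # The indicator marks which frames in this window correspond to real (non-padded) audio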
indicator = torch.ones((n_repetitions, self.n_motions)).to(self.device) if self.use_indicator else None
if indicator is not None and i == n_subdivision - 1 and n_padding_frames > 0:
indicator[:, -n_padding_frames:] = 0
if not self.no_context_audio_feat:
audio_in = audio_feat[:, start_idx:end_idx].expand(n_repetitions, -1, -1)
else:
audio_in = audio[round(start_idx * self.audio_unit):round(end_idx * self.audio_unit)].unsqueeze(0)
# generate motion coefficients
if i == 0:
# -> (N, L, d_motion=n_code_per_frame * code_dim)
motion_feat, noise, prev_audio_feat = self.model.sample(audio_in, shape_coef, style_feat,
indicator=indicator, cfg_mode=cfg_mode,
cfg_cond=cfg_cond, cfg_scale=cfg_scale,
dynamic_threshold=self.dynamic_threshold)
else:
motion_feat, noise, prev_audio_feat = self.model.sample(audio_in, shape_coef, style_feat,
prev_motion_feat, prev_audio_feat, noise,
indicator=indicator, cfg_mode=cfg_mode,
cfg_cond=cfg_cond, cfg_scale=cfg_scale,
dynamic_threshold=self.dynamic_threshold)
prev_motion_feat = motion_feat[:, -self.n_prev_motions:].clone()
prev_audio_feat = prev_audio_feat[:, -self.n_prev_motions:]
motion_coef = motion_feat
if i == n_subdivision - 1 and n_padding_frames > 0:
motion_coef = motion_coef[:, :-n_padding_frames] # delete padded frames
coef_list.append(motion_coef)
motion_coef = torch.cat(coef_list, dim=1)
        # Step 3: convert the flat motion coefficients back into a named coefficient dict
coef_dict = get_coef_dict(motion_coef, None, self.coef_stats, self.predict_head_pose, self.rot_repr)
if include_shape:
coef_dict['shape'] = original_shape_coef[None, None].expand(n_repetitions, motion_coef.shape[1], -1)
return self.coef_to_a1_format(coef_dict)
def coef_to_a1_format(self, coef_dict):
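        # Split the batched coefficient dict into a list of per-frame dicts with FLAME-style
        # parameter names (expression, jaw, eye pose, global pose, eyelids) in the per-frame
        # format expected downstream.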
n_frames = coef_dict['exp'].shape[1]
new_coef_dict = []
for i in range(n_frames):
new_coef_dict.append({
"expression_params": coef_dict["exp"][0, i:i+1],
"jaw_params": coef_dict["pose"][0, i:i+1, 3:],
"eye_pose_params": torch.zeros(1, 6).type_as(coef_dict["pose"]),
"pose_params": coef_dict["pose"][0, i:i+1, :3],
"eyelid_params": None
})
return new_coef_dict
@staticmethod
def _pad_coef(coef, n_frames, elem_ndim=1):
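        # Pad `coef` to `n_frames` frames by repeating the last frame, or truncate if it is longer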
if coef.ndim == elem_ndim:
coef = coef[None]
elem_shape = coef.shape[1:]
if coef.shape[0] >= n_frames:
new_coef = coef[:n_frames]
else:
# repeat the last coef frame
new_coef = torch.cat([coef, coef[[-1]].expand(n_frames - coef.shape[0], *elem_shape)], dim=0)
return new_coef # (n_frames, *elem_shape)
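

# Minimal usage sketch (not part of the upstream DiffPoseTalk code): it assumes the default
# pretrained checkpoints referenced in `DiffPoseTalkConfig` exist on disk, and that
# "audio.wav" and "shape.npy" are placeholder inputs you provide yourself. Because of the
# relative imports above, run it as a module, e.g. `python -m <package>.<this_module>`.
if __name__ == "__main__":
    dpt = DiffPoseTalk(DiffPoseTalkConfig(), device="cuda")
    frames = dpt.infer_from_file("audio.wav", "shape.npy")  # list of per-frame parameter dicts
    print(f"Generated {len(frames)} frames; expression shape {tuple(frames[0]['expression_params'].shape)}")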