""" | |
Copyright (c) Meta Platforms, Inc. and affiliates. | |
All rights reserved. | |
This source code is licensed under the license found in the | |
LICENSE file in the root directory of this source tree. | |
""" | |
import copy
import glob
import os
import re
import subprocess
from collections import OrderedDict
from typing import Dict, List, Optional

import mediapy
import numpy as np
import torch
import torch as th
import torchaudio
from attrdict import AttrDict
from omegaconf import OmegaConf
from tqdm import tqdm

from utils.model_util import get_person_num
from visualize.ca_body.utils.image import linear2displayBatch
from visualize.ca_body.utils.train import load_checkpoint, load_from_config

ffmpeg_header = "ffmpeg -y "  # -hide_banner -loglevel error

def filter_params(params, ignore_names):
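    """Drop entries from a parameter dict whose keys match any of the given regex patterns."""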
    return OrderedDict(
        [
            (k, v)
            for k, v in params.items()
            if not any(re.match(n, k) is not None for n in ignore_names)
        ]
    )

def call_ffmpeg(command: str) -> None:
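    """Run an ffmpeg command in a subshell and fail loudly on a non-zero exit code."""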
    print(command, "-" * 100)
    exit_code = subprocess.call(command, shell=True)
    if exit_code != 0:
        raise RuntimeError(f"ffmpeg failed with exit code {exit_code}: {command}")

class BodyRenderer(th.nn.Module):
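    """Renders a full-body avatar from LBS body pose and face codes using a
    pretrained codec-avatar decoder."""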
    def __init__(
        self,
        config_base: str,
        render_rgb: bool,
    ):
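        """
        Args:
            config_base: directory containing body_dec.ckpt, config.yml, and
                static_assets.pt for a given person.
            render_rgb: if False, RGB rendering is disabled in the decoder.
        """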
        super().__init__()
        self.config_base = config_base
        ckpt_path = f"{config_base}/body_dec.ckpt"
        config_path = f"{config_base}/config.yml"
        assets_path = f"{config_base}/static_assets.pt"
        # config
        config = OmegaConf.load(config_path)
        gpu = config.get("gpu", 0)
        self.device = th.device(f"cuda:{gpu}")
        # assets
        static_assets = AttrDict(torch.load(assets_path))
        # build model
        self.model = load_from_config(config.model, assets=static_assets).to(
            self.device
        )
        self.model.cal_enabled = False
        self.model.pixel_cal_enabled = False
        self.model.learn_blur_enabled = False
        self.render_rgb = render_rgb
        if not self.render_rgb:
            self.model.rendering_enabled = None
        # load model checkpoints
        print("loading...", ckpt_path)
        load_checkpoint(
            ckpt_path,
            modules={"model": self.model},
            ignore_names={"model": ["lbs_fn.*"]},
        )
        self.model.eval()
        self.model.to(self.device)
        # load default parameters for renderer
        person = get_person_num(config_path)
        self.default_inputs = th.load(f"assets/render_defaults_{person}.pth")

    def _write_video_stream(
        self, motion: np.ndarray, face: np.ndarray, save_name: str
    ) -> None:
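        """Render the given pose/face sequences and write them to a video at 30 fps."""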
        out = self._render_loop(motion, face)
        mediapy.write_video(save_name, out, fps=30)

    def _render_loop(
        self, body_pose: np.ndarray, face: np.ndarray
    ) -> List[np.ndarray]:
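        """Render the sequence frame by frame and return a list of HxWx3 uint8 images,
        with the two camera views concatenated side by side."""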
        all_rgb = []
        default_inputs_copy = copy.deepcopy(self.default_inputs)
        B = default_inputs_copy["K"].shape[0]
        # convert the face codes to a tensor once, outside the per-frame loop
        face_codes = (
            th.from_numpy(face).float().to(self.device)
            if not th.is_tensor(face)
            else face
        )
        for b in tqdm(range(len(body_pose))):
            # replicate the current frame's pose across the B camera views
            default_inputs_copy["lbs_motion"] = th.tensor(
                body_pose[b : b + 1, :], device=self.device, dtype=th.float
            ).tile(B, 1)
            # run linear blend skinning to get the posed geometry
            geom = (
                self.model.lbs_fn.lbs_fn(
                    default_inputs_copy["lbs_motion"],
                    self.model.lbs_fn.lbs_scale.unsqueeze(0).tile(B, 1),
                    self.model.lbs_fn.lbs_template_verts.unsqueeze(0).tile(B, 1, 1),
                )
                * self.model.lbs_fn.global_scaling
            )
            default_inputs_copy["geom"] = geom
            default_inputs_copy["face_embs"] = th.tile(
                face_codes[b : b + 1, ...], (2, 1)
            )
            preds = self.model(**default_inputs_copy)
            # place the two rendered views side by side as an (H, 2W, 3) image
            rgb = linear2displayBatch(preds["rgb"])
            rgb = th.cat((rgb[0], rgb[1]), dim=-1).permute(1, 2, 0)
            rgb = rgb.clip(0, 255).to(th.uint8)
            all_rgb.append(rgb.contiguous().detach().cpu().numpy())
        return all_rgb

    def render_full_video(
        self,
        data_block: Dict[str, np.ndarray],
        animation_save_path: str,
        audio_sr: Optional[int] = None,
        render_gt: bool = False,
    ) -> None:
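        """Render the predicted (or, if render_gt is True, ground-truth) motion in
        data_block, mux it with the audio track at audio_sr, and save the result to
        {animation_save_path}_{tag}.mp4."""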
        tag = os.path.basename(os.path.dirname(animation_save_path))
        save_name = os.path.splitext(os.path.basename(animation_save_path))[0]
        save_name = f"{tag}_{save_name}"
        # write the audio track to a temporary wav file
        torchaudio.save(
            f"/tmp/audio_{save_name}.wav",
            torch.tensor(data_block["audio"]),
            audio_sr,
        )
        if render_gt:
            tag = "gt"
            self._write_video_stream(
                data_block["gt_body"],
                data_block["gt_face"],
                f"/tmp/{tag}_{save_name}.mp4",
            )
        else:
            tag = "pred"
            self._write_video_stream(
                data_block["body_motion"],
                data_block["face_motion"],
                f"/tmp/{tag}_{save_name}.mp4",
            )
        # mux the rendered video with the audio track
        command = (
            f"{ffmpeg_header} -i /tmp/{tag}_{save_name}.mp4 "
            f"-i /tmp/audio_{save_name}.wav "
            f"-c:v copy -map 0:v:0 -map 1:a:0 -c:a aac -b:a 192k "
            f"-pix_fmt yuva420p {animation_save_path}_{tag}.mp4"
        )
        call_ffmpeg(command)
        # clean up the temporary files
        os.remove(f"/tmp/audio_{save_name}.wav")
        os.remove(f"/tmp/{tag}_{save_name}.mp4")
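
# Example usage (a minimal sketch; the config directory, sequence length T, and feature
# sizes D_body / D_face below are placeholders, not values shipped with this file):
#
#   renderer = BodyRenderer(config_base="<path/to/person_config_dir>", render_rgb=True)
#   data_block = {
#       "audio": audio_array,  # waveform in the layout expected by torchaudio.save
#       "body_motion": np.zeros((T, D_body), dtype=np.float32),  # per-frame LBS pose
#       "face_motion": np.zeros((T, D_face), dtype=np.float32),  # per-frame face codes
#   }
#   renderer.render_full_video(data_block, "out/demo", audio_sr=48_000, render_gt=False)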