"""Gradio front-end for the Thin-Plate-Spline Motion Model (image animation).

On import, this module clones the upstream repository, changes into it, and
downloads the ``vox`` checkpoint; ``main()`` then builds and launches the UI.
"""
import argparse  # noqa: F401  (kept from the original file)
import contextlib
import os
import pathlib
import shutil  # noqa: F401  (kept from the original file)
import subprocess
import sys

import gradio as gr
import torch  # noqa: F401  (kept from the original file)
from PIL import Image  # noqa: F401  (kept from the original file)

REPO_URL = "https://github.com/yoyo-nb/Thin-Plate-Spline-Motion-Model"
REPO_DIR = "Thin-Plate-Spline-Motion-Model"
CHECKPOINT_URL = "https://cloud.tsinghua.edu.cn/f/da8d61d012014b12a9e4/?dl=1"
# Paths below are relative to REPO_DIR, because of the os.chdir() that follows.
CHECKPOINT_PATH = "./checkpoints/vox.pth.tar"
TEMP_DIR = "temp"
SOURCE_IMAGE_PATH = "temp/image.jpg"
RESULT_PATH = "./temp/result.mp4"

# One-time setup at import time. The commands are passed as argv lists (no
# shell), so the "?" in the checkpoint URL cannot be glob-expanded. Failures
# of git/wget are not fatal here, mirroring the original best-effort setup.
if not os.path.isdir(REPO_DIR):
    subprocess.run(["git", "clone", REPO_URL], check=False)
os.chdir(REPO_DIR)
os.makedirs("checkpoints", exist_ok=True)
# "-c" resumes a partial download instead of starting over.
subprocess.run(
    ["wget", "-c", CHECKPOINT_URL, "-O", "checkpoints/vox.pth.tar"],
    check=False,
)

title = "# 图片动画"
DESCRIPTION = '''### 图片动画的Gradio实现, CVPR 2022. [Paper][Github Code] overview '''
FOOTER = 'visitor badge'
ARTICLE = r""" ---

点击返回智能工具箱查看更多好玩的人工智能项目

``` """

# Step instructions shown in the UI. The line breaks matter: without them the
# Markdown "- item" lines would be swallowed into the "##" heading.
STEP1_MARKDOWN = '''## 第1步 (上传人脸图片)
- 拖一张含人脸的图片到 **输入图片**.
- 如果图片中有多张人脸, 使用右上角的编辑按钮裁剪图片.
'''
STEP2_MARKDOWN = '''## 第2步 (选择动态视频)
- **为人脸图片选择目标视频**.
'''
STEP3_MARKDOWN = '''## 第3步 (基于视频生成动态图片)
- 点击 **开始** 按钮. (注意: 由于是在CPU上运行, 生成最终结果需要花费大约3分钟.)
'''


def get_style_image_path(style_name: str) -> str:
    """Return the asset path for ``style_name`` ('source' or 'driving').

    Raises:
        KeyError: if ``style_name`` is not one of the known keys.
    """
    base_path = 'assets'
    filenames = {
        'source': 'source.png',
        'driving': 'driving.mp4',
    }
    return f'{base_path}/{filenames[style_name]}'


def get_style_image_markdown_text(style_name: str) -> str:
    """Return an HTML ``<img>`` snippet that displays the asset for ``style_name``."""
    url = get_style_image_path(style_name)
    return f'<img src="{url}" alt="style image">'


def update_style_image(style_name: str) -> dict:
    """Return a ``gr.Markdown`` update whose value shows the asset for ``style_name``."""
    text = get_style_image_markdown_text(style_name)
    return gr.Markdown.update(value=text)


def set_example_image(example: list) -> dict:
    """Copy the clicked ``gr.Dataset`` sample (``[path]``) into the image input."""
    return gr.Image.update(value=example[0])


def set_example_video(example: list) -> dict:
    """Copy the clicked ``gr.Dataset`` sample (``[path]``) into the video input."""
    return gr.Video.update(value=example[0])


def inference(img, vid):
    """Animate the face in ``img`` with the motion of the driving video ``vid``.

    Runs the repository's ``demo.py`` on CPU with the ``vox-256`` config.

    Args:
        img: source face image as produced by ``gr.Image(type="pil")``.
        vid: filesystem path of the driving video as produced by ``gr.Video``.

    Returns:
        The path of the generated video (``./temp/result.mp4``).

    Raises:
        ValueError: if the image or the video is missing.
        RuntimeError: if ``demo.py`` fails or produces no output file.
    """
    if img is None or not vid:
        raise ValueError("请先上传人脸图片并选择目标视频。")

    os.makedirs(TEMP_DIR, exist_ok=True)
    # JPEG cannot store alpha or palette images; normalise to RGB first.
    img.convert("RGB").save(SOURCE_IMAGE_PATH, "JPEG")

    # Remove the previous result so a failed run can never return stale output.
    with contextlib.suppress(FileNotFoundError):
        os.remove(RESULT_PATH)

    # argv list (no shell): the video path may contain spaces or shell
    # metacharacters because it comes from a user upload.
    cmd = [
        sys.executable, "demo.py",
        "--config", "config/vox-256.yaml",
        "--checkpoint", CHECKPOINT_PATH,
        "--source_image", SOURCE_IMAGE_PATH,
        "--driving_video", str(vid),
        "--result_video", RESULT_PATH,
        "--cpu",
    ]
    completed = subprocess.run(cmd, check=False)
    if completed.returncode != 0 or not os.path.exists(RESULT_PATH):
        raise RuntimeError(
            f"demo.py failed (exit code {completed.returncode}); no result video was produced."
        )
    return RESULT_PATH


def main():
    """Build the three-step Gradio UI (image, driving video, generate) and launch it."""
    with gr.Blocks(theme="huggingface", css='style.css') as demo:
        gr.Markdown(title)
        gr.Markdown(DESCRIPTION)

        # Step 1: source face image plus clickable *.png examples from assets/.
        with gr.Box():
            gr.Markdown(STEP1_MARKDOWN)
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        input_image = gr.Image(label='输入图片', type="pil")
                    with gr.Row():
                        image_paths = sorted(pathlib.Path('assets').glob('*.png'))
                        example_images = gr.Dataset(
                            components=[input_image],
                            samples=[[path.as_posix()] for path in image_paths],
                        )

        # Step 2: driving video plus clickable *.mp4 examples from assets/.
        with gr.Box():
            gr.Markdown(STEP2_MARKDOWN)
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        driving_video = gr.Video(label='目标视频', format="mp4")
                    with gr.Row():
                        video_paths = sorted(pathlib.Path('assets').glob('*.mp4'))
                        example_video = gr.Dataset(
                            components=[driving_video],
                            samples=[[path.as_posix()] for path in video_paths],
                        )

        # Step 3: run the model and show the output video.
        with gr.Box():
            gr.Markdown(STEP3_MARKDOWN)
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        generate_button = gr.Button('开始')
                with gr.Column():
                    result = gr.Video(type="file", label="输出")

        gr.Markdown(FOOTER)

        generate_button.click(fn=inference,
                              inputs=[input_image, driving_video],
                              outputs=result)
        example_images.click(fn=set_example_image,
                             inputs=example_images,
                             outputs=example_images.components)
        example_video.click(fn=set_example_video,
                            inputs=example_video,
                            outputs=example_video.components)

    demo.launch(enable_queue=True, debug=True)


if __name__ == '__main__':
    main()