File size: 4,971 Bytes
7b1bf28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52e114d
 
7b1bf28
 
 
 
38b22bf
 
 
7b1bf28
38b22bf
 
7b1bf28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52e114d
 
 
7b1bf28
 
 
 
52e114d
7b1bf28
 
 
 
 
 
 
 
 
52e114d
 
7b1bf28
 
 
 
52e114d
7b1bf28
 
 
 
 
 
 
 
 
52e114d
 
7b1bf28
 
 
 
52e114d
7b1bf28
 
52e114d
7b1bf28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import gradio as gr
import os
import shutil
import torch
from PIL import Image
import argparse
import pathlib

# Import-time setup: fetch the upstream repository and make it the working
# directory, so every relative path used later in this file ('assets',
# 'config/vox-256.yaml', 'demo.py', 'checkpoints/...', 'temp/...') is resolved
# against the cloned checkout.
# NOTE(review): the exit status of each os.system call is ignored; only
# os.chdir raises if the clone directory does not exist.
os.system("git clone https://github.com/yoyo-nb/Thin-Plate-Spline-Motion-Model")
os.chdir("Thin-Plate-Spline-Motion-Model")
os.system("mkdir checkpoints")
# Download the checkpoint to checkpoints/vox.pth.tar; `-c` lets wget continue
# a partially downloaded file instead of starting over.
os.system("wget -c https://cloud.tsinghua.edu.cn/f/da8d61d012014b12a9e4/?dl=1 -O checkpoints/vox.pth.tar")



# Markdown heading rendered at the top of the page.
title = "# 图片动画"

# Introductory Markdown/HTML shown below the heading: a one-line summary with
# links to the paper and the upstream code, followed by an overview GIF.
# The original text carried a closing </b> with no matching opening tag; the
# orphaned tag has been removed so the markup is well-formed.
DESCRIPTION = '''### 图片动画的Gradio实现, CVPR 2022. <a href='https://arxiv.org/abs/2203.14367'>[Paper]</a><a href='https://github.com/yoyo-nb/Thin-Plate-Spline-Motion-Model'>[Github Code]</a>

<img id="overview" alt="overview" src="https://github.com/yoyo-nb/Thin-Plate-Spline-Motion-Model/raw/main/assets/vox.gif" />
'''

# Visitor-counter badge rendered at the bottom of the page.
FOOTER = '<img id="visitor-badge" alt="visitor badge" src="https://visitor-badge.glitch.me/badge?page_id=gradio-blocks.Image-Animation-using-Thin-Plate-Spline-Motion-Model" />'

# Closing section linking back to the tool collection site.
# The original ended with a lone Markdown code fence that was never closed; it
# has been dropped so the text cannot swallow following content into a code
# block when rendered.
ARTICLE = r"""
---
<h2 style="font-weight: 900; margin-bottom: 7px;">点击<a href='https://www.toolchest.cn' target='_blank'>返回智能工具箱</a>查看更多好玩的人工智能项目</h2>

"""

def get_style_image_path(style_name: str) -> str:
    """Return the relative path of the sample asset registered for *style_name*.

    The known names are ``'source'`` and ``'driving'``; any other name raises
    ``KeyError``.
    """
    asset_by_name = dict(source='source.png', driving='driving.mp4')
    asset_file = asset_by_name[style_name]
    return '/'.join(('assets', asset_file))


def get_style_image_markdown_text(style_name: str) -> str:
    """Return an ``<img>`` HTML snippet whose ``src`` is the asset path for *style_name*."""
    return f'<img id="style-image" src="{get_style_image_path(style_name)}" alt="style image">'


def update_style_image(style_name: str) -> dict:
    """Return a ``gr.Markdown`` update carrying the ``<img>`` snippet for *style_name*."""
    return gr.Markdown.update(
        value=get_style_image_markdown_text(style_name),
    )


def set_example_image(example: list) -> dict:
    """Return a ``gr.Image`` update whose value is the first entry of *example*."""
    chosen_image = example[0]
    return gr.Image.update(value=chosen_image)

def set_example_video(example: list) -> dict:
    """Return a ``gr.Video`` update whose value is the first entry of *example*."""
    chosen_video = example[0]
    return gr.Video.update(value=chosen_video)

def inference(img, vid):
    """Animate a source image with the motion of a driving video.

    The image is written to ``temp/image.jpg`` and ``demo.py`` is run on CPU
    with the vox-256 config and checkpoint (all paths relative to the current
    working directory).

    Args:
        img: Source image; must expose PIL's ``mode``, ``convert`` and ``save``.
        vid: Filesystem path of the driving video.

    Returns:
        The relative path ``'./temp/result.mp4'`` of the generated video.

    Raises:
        ValueError: If the image or the video is missing.
        subprocess.CalledProcessError: If ``demo.py`` exits with a non-zero
            status, so a stale result from an earlier run is never returned.
    """
    # Imported locally so this function does not depend on the file's import
    # block being changed.
    import subprocess

    if img is None or vid is None:
        raise ValueError("Both a source image and a driving video are required.")

    os.makedirs('temp', exist_ok=True)

    # JPEG cannot store an alpha channel or a palette; saving e.g. an RGBA PNG
    # upload directly would raise, so normalise to RGB first.
    if img.mode != 'RGB':
        img = img.convert('RGB')
    img.save("temp/image.jpg", "JPEG")

    result_path = './temp/result.mp4'
    # Pass an argument list (no shell) so a video path containing spaces or
    # shell metacharacters is handed to demo.py verbatim.
    # NOTE(review): the temp/ file names are fixed, so concurrent requests
    # would overwrite each other's files.
    subprocess.run(
        [
            "python", "demo.py",
            "--config", "config/vox-256.yaml",
            "--checkpoint", "./checkpoints/vox.pth.tar",
            "--source_image", "temp/image.jpg",
            "--driving_video", str(vid),
            "--result_video", result_path,
            "--cpu",
        ],
        check=True,
    )
    return result_path
  


def main() -> None:
    """Build the three-step Gradio Blocks page, wire its events, and launch it.

    The page is laid out as: heading and description, step 1 (source image),
    step 2 (driving video), step 3 (run button and output video), footer.
    """
    with gr.Blocks(theme="huggingface", css='style.css') as demo:
        # Page heading and introduction.
        gr.Markdown(title)
        gr.Markdown(DESCRIPTION)

        # Step 1: upload the source face image.
        with gr.Box():
            gr.Markdown('''## 第1步 (上传人脸图片)
- 拖一张含人脸的图片到 **输入图片**.
    - 如果图片中有多张人脸, 使用右上角的编辑按钮裁剪图片.
''')
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        # type="pil": `inference` calls .save() on the value
                        # it receives from this component.
                        input_image = gr.Image(label='输入图片',
                                               type="pil")
                        
            with gr.Row():
                # One clickable sample per PNG found under ./assets
                # (relative to the current working directory).
                paths = sorted(pathlib.Path('assets').glob('*.png'))
                example_images = gr.Dataset(components=[input_image],
                                            samples=[[path.as_posix()]
                                                     for path in paths])

        # Step 2: choose the driving video.
        with gr.Box():
            gr.Markdown('''## 第2步 (选择动态视频)
-  **为人脸图片选择目标视频**.
''')
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        driving_video = gr.Video(label='目标视频',
                                               format="mp4")

            with gr.Row():
                # `paths` is rebound here: one clickable sample per MP4 under ./assets.
                paths = sorted(pathlib.Path('assets').glob('*.mp4'))
                example_video = gr.Dataset(components=[driving_video],
                                            samples=[[path.as_posix()]
                                                     for path in paths])

        # Step 3: run the model; the step text tells the user that the job
        # runs on CPU and takes roughly three minutes.
        with gr.Box():
            gr.Markdown('''## 第3步 (基于视频生成动态图片)
- 点击 **开始** 按钮. (注意: 由于是在CPU上运行, 生成最终结果需要花费大约3分钟.)
''')
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        generate_button = gr.Button('开始')

                with gr.Column():
                    result = gr.Video(type="file", label="输出")
        gr.Markdown(FOOTER)
        # The button feeds the uploaded image and video to `inference` and
        # shows the returned video path in `result`.
        generate_button.click(fn=inference,
                              inputs=[
                                  input_image,
                                  driving_video
                              ],
                              outputs=result)
        # Clicking a sample copies it into the component the dataset was
        # built from (input_image / driving_video respectively).
        example_images.click(fn=set_example_image,
                             inputs=example_images,
                             outputs=example_images.components)
        example_video.click(fn=set_example_video,
                             inputs=example_video,
                             outputs=example_video.components)

    # Launched after the Blocks context has been closed.
    demo.launch(
        enable_queue=True,
        debug=True
    )

# Build and launch the UI only when this file is executed as a script.
if __name__ == '__main__':
    main()