# NOTE(review): the three lines below are residue from a scraped web page
# (file size, a per-line hash list, line numbers); they are not Python and
# are kept only as comments so the module parses.
# File size: 4,971 Bytes
# 7b1bf28 52e114d 7b1bf28 38b22bf 7b1bf28 38b22bf 7b1bf28 52e114d 7b1bf28 52e114d 7b1bf28 52e114d 7b1bf28 52e114d 7b1bf28 52e114d 7b1bf28 52e114d 7b1bf28 52e114d 7b1bf28 |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import argparse
import os
import pathlib
import shutil
import subprocess
import sys

import gradio as gr
import torch
from PIL import Image
# One-time environment setup: fetch the upstream model code and the pretrained
# checkpoint, then work from inside the cloned repository so the relative paths
# used later (demo.py, config/, checkpoints/, assets/) resolve against it.
# Each step is safe to repeat when the app restarts.
if not os.path.isdir("Thin-Plate-Spline-Motion-Model"):
    os.system("git clone https://github.com/yoyo-nb/Thin-Plate-Spline-Motion-Model")
os.chdir("Thin-Plate-Spline-Motion-Model")
os.makedirs("checkpoints", exist_ok=True)
# Quote the URL: the `?` in it is a shell glob character.
# `wget -c` resumes a partial download instead of starting over.
os.system("wget -c 'https://cloud.tsinghua.edu.cn/f/da8d61d012014b12a9e4/?dl=1' -O checkpoints/vox.pth.tar")
# Page texts. main() renders `title` and DESCRIPTION at the top of the page and
# FOOTER below the three steps. ARTICLE is not referenced by main().
title = "# 图片动画"
DESCRIPTION = '''### <b>图片动画的Gradio实现</b>, CVPR 2022. <a href='https://arxiv.org/abs/2203.14367'>[Paper]</a><a href='https://github.com/yoyo-nb/Thin-Plate-Spline-Motion-Model'>[Github Code]</a>
<img id="overview" alt="overview" src="https://github.com/yoyo-nb/Thin-Plate-Spline-Motion-Model/raw/main/assets/vox.gif" />
'''
FOOTER = '<img id="visitor-badge" alt="visitor badge" src="https://visitor-badge.glitch.me/badge?page_id=gradio-blocks.Image-Animation-using-Thin-Plate-Spline-Motion-Model" />'
# Keep the markup balanced: an unterminated ``` fence here would swallow
# everything rendered after it as a code block.
ARTICLE = r"""
---
<h2 style="font-weight: 900; margin-bottom: 7px;">点击<a href='https://www.toolchest.cn' target='_blank'>返回智能工具箱</a>查看更多好玩的人工智能项目</h2>
"""
def get_style_image_path(style_name: str) -> str:
    """Return the relative path of the bundled asset registered for *style_name*.

    Only ``'source'`` (``source.png``) and ``'driving'`` (``driving.mp4``) are
    known. Any other name raises ``KeyError``.
    """
    asset_for = dict(source='source.png', driving='driving.mp4')
    asset_dir = 'assets'
    return '/'.join((asset_dir, asset_for[style_name]))
def get_style_image_markdown_text(style_name: str) -> str:
    """Return an HTML ``<img>`` tag (usable as Markdown) for *style_name*'s asset.

    The ``src`` comes from :func:`get_style_image_path`.
    NOTE(review): for ``'driving'`` that is an ``.mp4`` file, which an
    ``<img>`` tag will not display.
    """
    src = get_style_image_path(style_name)
    return '<img id="style-image" src="' + src + '" alt="style image">'
def update_style_image(style_name: str) -> dict:
    """Return a Gradio Markdown update showing the asset for *style_name*.

    Not referenced by main() in this file.
    """
    return gr.Markdown.update(value=get_style_image_markdown_text(style_name))
def set_example_image(example: list) -> dict:
    """Load a clicked example into the input-image component.

    *example* is one ``gr.Dataset`` sample. Its first entry is the image path.
    """
    chosen_image = example[0]
    return gr.Image.update(value=chosen_image)
def set_example_video(example: list) -> dict:
    """Load a clicked example into the driving-video component.

    *example* is one ``gr.Dataset`` sample. Its first entry is the video path.
    """
    chosen_video = example[0]
    return gr.Video.update(value=chosen_video)
def inference(img, vid):
    """Animate the face in *img* with the motion of the driving video *vid*.

    Saves the image to ``temp/image.jpg`` and runs the upstream ``demo.py``
    script in CPU mode. Returns the path of the rendered video.

    Args:
        img: the source image as a PIL image (the input component is declared
            with ``type="pil"``).
        vid: the driving video as handed over by the ``gr.Video`` component;
            presumably a filesystem path — confirm for the Gradio version used.

    Returns:
        ``'./temp/result.mp4'``.

    Raises:
        ValueError: if the image or the video is missing.
        RuntimeError: if ``demo.py`` fails or produces no output file.
    """
    if img is None or not vid:
        raise ValueError("Please provide both a source image and a driving video.")
    os.makedirs('temp', exist_ok=True)
    source_path = 'temp/image.jpg'
    result_path = './temp/result.mp4'
    # JPEG cannot store alpha or palette modes (e.g. RGBA PNG uploads), so
    # normalise to RGB before saving.
    img.convert("RGB").save(source_path, "JPEG")
    # Drop any earlier output so a failed run cannot return an older result
    # left over from a previous request.
    if os.path.exists(result_path):
        os.remove(result_path)
    # Pass an argument list (no shell). The uploaded video path is
    # user-controlled and may contain spaces or shell metacharacters.
    completed = subprocess.run([
        sys.executable, 'demo.py',
        '--config', 'config/vox-256.yaml',
        '--checkpoint', './checkpoints/vox.pth.tar',
        '--source_image', source_path,
        '--driving_video', str(vid),
        '--result_video', result_path,
        '--cpu',
    ])
    if completed.returncode != 0 or not os.path.exists(result_path):
        raise RuntimeError(
            f"demo.py failed (exit code {completed.returncode}); no result produced"
        )
    return result_path
def main():
    """Build the three-step Gradio Blocks UI and launch the web server.

    Step 1 takes a face image, with clickable examples from ``assets/*.png``.
    Step 2 takes a driving video, with clickable examples from ``assets/*.mp4``.
    Step 3 runs :func:`inference` on both and shows the resulting video.
    """
    # NOTE(review): the source paste lost all indentation; the nesting below
    # is a reconstruction (Box > Row > Column > Row) — confirm against upstream.
    with gr.Blocks(theme="huggingface", css='style.css') as demo:
        gr.Markdown(title)
        gr.Markdown(DESCRIPTION)
        # Step 1: source face image plus example images.
        with gr.Box():
            gr.Markdown('''## 第1步 (上传人脸图片)
- 拖一张含人脸的图片到 **输入图片**.
- 如果图片中有多张人脸, 使用右上角的编辑按钮裁剪图片.
''')
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        # type="pil": inference() receives a PIL image.
                        input_image = gr.Image(label='输入图片',
                                               type="pil")
                    with gr.Row():
                        # One single-entry sample per PNG found under assets/.
                        paths = sorted(pathlib.Path('assets').glob('*.png'))
                        example_images = gr.Dataset(components=[input_image],
                                                    samples=[[path.as_posix()]
                                                             for path in paths])
        # Step 2: driving video plus example videos.
        with gr.Box():
            gr.Markdown('''## 第2步 (选择动态视频)
- **为人脸图片选择目标视频**.
''')
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        driving_video = gr.Video(label='目标视频',
                                                 format="mp4")
                    with gr.Row():
                        # One single-entry sample per MP4 found under assets/.
                        paths = sorted(pathlib.Path('assets').glob('*.mp4'))
                        example_video = gr.Dataset(components=[driving_video],
                                                   samples=[[path.as_posix()]
                                                            for path in paths])
        # Step 3: run the model and show the generated video.
        with gr.Box():
            gr.Markdown('''## 第3步 (基于视频生成动态图片)
- 点击 **开始** 按钮. (注意: 由于是在CPU上运行, 生成最终结果需要花费大约3分钟.)
''')
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        generate_button = gr.Button('开始')
                with gr.Column():
                    result = gr.Video(type="file", label="输出")
        gr.Markdown(FOOTER)
        # Event wiring (registered inside the Blocks context).
        generate_button.click(fn=inference,
                              inputs=[
                                  input_image,
                                  driving_video
                              ],
                              outputs=result)
        # Clicking an example row passes that sample to the set_example_*
        # helper, which loads its first entry (the file path) into the input.
        example_images.click(fn=set_example_image,
                             inputs=example_images,
                             outputs=example_images.components)
        example_video.click(fn=set_example_video,
                            inputs=example_video,
                            outputs=example_video.components)
    # Queueing is presumably enabled because one CPU run takes minutes (see the
    # step-3 note) — TODO confirm intent.
    demo.launch(
        enable_queue=True,
        debug=True
    )
# Script entry point: build and launch the UI only when run directly.
if __name__ == '__main__':
    main()