svjack committed on
Commit
d5d4264
·
verified ·
1 Parent(s): 7e94078

Update gradio_app_with_frames.py

Browse files
Files changed (1) hide show
  1. gradio_app_with_frames.py +132 -72
gradio_app_with_frames.py CHANGED
@@ -1,93 +1,158 @@
1
  import os
2
- import sys
3
  import shutil
4
  import uuid
5
  import subprocess
6
  import gradio as gr
7
- import shutil
 
8
  from glob import glob
9
- from huggingface_hub import snapshot_download, hf_hub_download
10
- from moviepy.editor import VideoFileClip # Import MoviePy
11
-
12
- # Download models
13
- os.makedirs("pretrained_weights", exist_ok=True)
14
 
15
- # List of subdirectories to create inside "checkpoints"
16
- subfolders = [
17
- "stable-video-diffusion-img2vid-xt"
18
- ]
19
 
20
- # Create each subdirectory
21
- for subfolder in subfolders:
22
- os.makedirs(os.path.join("pretrained_weights", subfolder), exist_ok=True)
23
 
24
- snapshot_download(
25
- repo_id="stabilityai/stable-video-diffusion-img2vid",
26
- local_dir="./pretrained_weights/stable-video-diffusion-img2vid-xt"
27
- )
 
 
 
28
 
29
- snapshot_download(
30
- repo_id="Yhmeng1106/anidoc",
31
- local_dir="./pretrained_weights"
32
- )
33
-
34
- hf_hub_download(
35
- repo_id="facebook/cotracker",
36
- filename="cotracker2.pth",
37
- local_dir="./pretrained_weights"
38
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  def generate(control_sequence, ref_image):
41
- control_image = control_sequence # "data_test/sample4.mp4"
42
- ref_image = ref_image # "data_test/sample4.png"
43
- unique_id = str(uuid.uuid4())
44
- output_dir = f"results_{unique_id}"
45
-
46
  try:
47
- # Use MoviePy to get the number of frames in the control_sequence video
48
- video_clip = VideoFileClip(control_image)
49
- num_frames = int(video_clip.fps * video_clip.duration) # Calculate total frames
50
- video_clip.close() # Close the video clip to free resources
51
-
52
- # Run the inference command
53
- subprocess.run(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  [
55
- "python", "scripts_infer/anidoc_inference.py",
56
  "--all_sketch",
57
  "--matching",
58
  "--tracking",
59
- "--control_image", f"{control_image}",
60
- "--ref_image", f"{ref_image}",
61
- "--output_dir", f"{output_dir}",
62
  "--max_point", "10",
63
- "--num_frames", str(num_frames) # Pass the calculated num_frames
64
  ],
65
- check=True
 
 
66
  )
 
 
 
 
 
67
 
68
- # Search for the mp4 file in a subfolder of output_dir
69
- output_video = glob(os.path.join(output_dir, "*.mp4"))
70
- print(output_video)
71
 
72
  if output_video:
73
- output_video_path = output_video[0] # Get the first match
 
74
  else:
75
- output_video_path = None
 
 
 
 
 
 
 
 
 
 
76
 
77
- print(output_video_path)
78
  return output_video_path
79
-
80
  except subprocess.CalledProcessError as e:
81
- raise gr.Error(f"Error during inference: {str(e)}")
 
82
  except Exception as e:
83
- raise gr.Error(f"Error processing video: {str(e)}")
84
 
85
- css = """
86
  div#col-container{
87
  margin: 0 auto;
88
  max-width: 982px;
89
  }
90
  """
 
91
  with gr.Blocks(css=css) as demo:
92
  with gr.Column(elem_id="col-container"):
93
  gr.Markdown("# AniDoc: Animation Creation Made Easier")
@@ -103,12 +168,6 @@ with gr.Blocks(css=css) as demo:
103
  <a href="https://arxiv.org/pdf/2412.14173">
104
  <img src='https://img.shields.io/badge/ArXiv-Paper-red'>
105
  </a>
106
- <a href="https://huggingface.co/spaces/fffiloni/AniDoc?duplicate=true">
107
- <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
108
- </a>
109
- <a href="https://huggingface.co/fffiloni">
110
- <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-sm-dark.svg" alt="Follow me on HF">
111
- </a>
112
  </div>
113
  """)
114
  with gr.Row():
@@ -120,19 +179,20 @@ with gr.Blocks(css=css) as demo:
120
  video_result = gr.Video(label="Result")
121
 
122
  gr.Examples(
123
- examples=[
 
124
  ["data_test/sample1.mp4", "data_test/sample1.png"],
125
  ["data_test/sample2.mp4", "data_test/sample2.png"],
126
  ["data_test/sample3.mp4", "data_test/sample3.png"],
127
  ["data_test/sample4.mp4", "data_test/sample4.png"]
128
  ],
129
- inputs=[control_sequence, ref_image]
130
- )
131
-
132
- submit_btn.click(
133
- fn=generate,
134
- inputs=[control_sequence, ref_image],
135
- outputs=[video_result]
136
- )
137
 
138
- demo.queue().launch(show_api=False, show_error=True, share=True)
 
1
  import os
 
2
  import shutil
3
  import uuid
4
  import subprocess
5
  import gradio as gr
6
+ import cv2
7
+ import sys
8
  from glob import glob
9
+ from pathlib import Path
 
 
 
 
10
 
11
# Interpreter used to launch the helper scripts via subprocess.
# NOTE(review): sys.executable (commented out below) would pin the exact
# current interpreter/venv; the hard-coded "python" relies on PATH lookup
# instead — confirm the hard-coding is intentional for this deployment.
#PYTHON_EXECUTABLE = sys.executable
PYTHON_EXECUTABLE = "python"
 
14
 
15
def normalize_path(path: str) -> str:
    """Normalize *path*: make it absolute and use forward slashes.

    Resolves the path against the current working directory and converts
    Windows-style backslash separators to '/'.
    """
    resolved = Path(path).resolve()
    return str(resolved).replace('\\', '/')
18
 
19
def check_video_frames(video_path: str) -> int:
    """Return the number of frames in the video at *video_path*.

    Raises:
        ValueError: if the file cannot be opened as a video.  Without this
            check, cv2.VideoCapture fails silently and CAP_PROP_FRAME_COUNT
            yields 0 (or -1), which would be mistaken for a real frame count.
    """
    video_path = normalize_path(video_path)
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video file: {video_path}")
        return int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        # Always release the capture, even if the property read raises.
        cap.release()
26
 
27
def preprocess_video(video_path: str) -> str:
    """Re-encode *video_path* to a 14-frame video via an external script.

    Runs `process_video_to_14frames.py` in a subprocess, writing into a
    fresh `outputs/processed_<uuid>` directory, and returns the normalized
    path of the first .mp4 it produced.

    Raises:
        gr.Error: if the subprocess fails or produces no output video.
    """
    try:
        video_path = normalize_path(video_path)
        # Unique per-call output directory so concurrent runs don't collide.
        unique_id = str(uuid.uuid4())
        temp_dir = "outputs"
        output_dir = os.path.join(temp_dir, f"processed_{unique_id}")
        output_dir = normalize_path(output_dir)
        os.makedirs(output_dir, exist_ok=True)

        print(f"Processing video: {video_path}")
        print(f"Output directory: {output_dir}")

        # Delegate the actual frame resampling to the helper script.
        result = subprocess.run(
            [
                PYTHON_EXECUTABLE, "process_video_to_14frames.py",
                "--input", video_path,
                "--output", output_dir
            ],
            check=True,
            capture_output=True,
            text=True
        )

        if result.stdout:
            print(f"Preprocessing stdout: {result.stdout}")
        if result.stderr:
            print(f"Preprocessing stderr: {result.stderr}")

        # Locate the processed video produced by the helper script.
        processed_videos = glob(os.path.join(output_dir, "*.mp4"))
        if not processed_videos:
            raise gr.Error("Failed to process video: No output video found")
        return normalize_path(processed_videos[0])
    except subprocess.CalledProcessError as e:
        print(f"Preprocessing stderr: {e.stderr}")
        raise gr.Error(f"Failed to preprocess video: {e.stderr}")
    except gr.Error:
        # Bug fix: let our own user-facing errors propagate unchanged instead
        # of being re-wrapped (and garbled) by the generic handler below.
        raise
    except Exception as e:
        raise gr.Error(f"Error during video preprocessing: {str(e)}")
67
 
68
def generate(control_sequence, ref_image):
    """Run AniDoc inference for a control video and a reference image.

    Validates the inputs, preprocesses the control video to 14 frames when
    needed, invokes `scripts_infer/anidoc_inference.py` in a subprocess, and
    returns the path of the generated .mp4.

    Args:
        control_sequence: path to the control/sketch video.
        ref_image: path to the reference image.

    Returns:
        Normalized path to the generated result video.

    Raises:
        gr.Error: on missing inputs, inference failure, or missing output.
    """
    try:
        # Validate that both input files exist.
        control_sequence = normalize_path(control_sequence)
        ref_image = normalize_path(ref_image)

        if not os.path.exists(control_sequence):
            raise gr.Error(f"Control sequence file not found: {control_sequence}")
        if not os.path.exists(ref_image):
            raise gr.Error(f"Reference image file not found: {ref_image}")

        # Create a unique per-call result directory under outputs/.
        output_dir = "outputs"
        os.makedirs(output_dir, exist_ok=True)
        unique_id = str(uuid.uuid4())
        result_dir = os.path.join(output_dir, f"results_{unique_id}")
        result_dir = normalize_path(result_dir)
        os.makedirs(result_dir, exist_ok=True)

        print(f"Input control sequence: {control_sequence}")
        print(f"Input reference image: {ref_image}")
        print(f"Output directory: {result_dir}")

        # The inference script expects exactly 14 frames; resample if needed.
        frame_count = check_video_frames(control_sequence)
        if frame_count != 14:
            print(f"Video has {frame_count} frames, preprocessing to 14 frames...")
            control_sequence = preprocess_video(control_sequence)
            print(f"Preprocessed video saved to: {control_sequence}")

        # Run the inference command.
        print(f"Running inference...")
        result = subprocess.run(
            [
                PYTHON_EXECUTABLE, "scripts_infer/anidoc_inference.py",
                "--all_sketch",
                "--matching",
                "--tracking",
                "--control_image", control_sequence,
                "--ref_image", ref_image,
                "--output_dir", result_dir,
                "--max_point", "10",
            ],
            check=True,
            capture_output=True,
            text=True
        )

        if result.stdout:
            print(f"Inference stdout: {result.stdout}")
        if result.stderr:
            print(f"Inference stderr: {result.stderr}")

        # Search the result directory for the generated video.
        output_video = glob(os.path.join(result_dir, "*.mp4"))
        print(f"Found output videos: {output_video}")

        if output_video:
            output_video_path = normalize_path(output_video[0])
            print(f"Returning output video: {output_video_path}")
        else:
            raise gr.Error("No output video generated")

        # Best-effort cleanup of preprocessing temp directories.
        # NOTE(review): this removes ALL outputs/processed_* dirs, including
        # those of concurrent requests — confirm single-request queueing.
        temp_dirs = glob("outputs/processed_*")
        for temp_dir in temp_dirs:
            if os.path.isdir(temp_dir):
                try:
                    shutil.rmtree(temp_dir)
                    print(f"Cleaned up temp directory: {temp_dir}")
                except Exception as e:
                    print(f"Warning: Failed to clean up temp directory {temp_dir}: {str(e)}")

        return output_video_path

    except subprocess.CalledProcessError as e:
        print(f"Inference stderr: {e.stderr}")
        raise gr.Error(f"Error during inference: {e.stderr}")
    except gr.Error:
        # Bug fix: propagate the specific user-facing errors raised above
        # (missing files, no output video) instead of re-wrapping them into
        # a generic "Error: ..." message via the handler below.
        raise
    except Exception as e:
        raise gr.Error(f"Error: {str(e)}")
148
 
149
# Gradio CSS: center the main column and cap its width.
css="""
div#col-container{
    margin: 0 auto;
    max-width: 982px;
}
"""
155
+
156
  with gr.Blocks(css=css) as demo:
157
  with gr.Column(elem_id="col-container"):
158
  gr.Markdown("# AniDoc: Animation Creation Made Easier")
 
168
  <a href="https://arxiv.org/pdf/2412.14173">
169
  <img src='https://img.shields.io/badge/ArXiv-Paper-red'>
170
  </a>
 
 
 
 
 
 
171
  </div>
172
  """)
173
  with gr.Row():
 
179
  video_result = gr.Video(label="Result")
180
 
181
  gr.Examples(
182
+ examples = [
183
+ ["data_test/sample5.mp4", "data_test/sample5.png"],
184
  ["data_test/sample1.mp4", "data_test/sample1.png"],
185
  ["data_test/sample2.mp4", "data_test/sample2.png"],
186
  ["data_test/sample3.mp4", "data_test/sample3.png"],
187
  ["data_test/sample4.mp4", "data_test/sample4.png"]
188
  ],
189
+ inputs = [control_sequence, ref_image]
190
+ )
191
+
192
+ submit_btn.click(
193
+ fn = generate,
194
+ inputs = [control_sequence, ref_image],
195
+ outputs = [video_result]
196
+ )
197
 
198
+ demo.queue().launch(inbrowser=True,show_api=False, show_error=True, share = True)