Delik committed on
Commit
e2a3041
·
verified ·
1 Parent(s): a4964be

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -403
app.py DELETED
@@ -1,403 +0,0 @@
1
-
2
- import argparse
3
- from datetime import datetime
4
- from pathlib import Path
5
- import numpy as np
6
- import torch
7
- from PIL import Image
8
- import gradio as gr
9
- import shutil
10
- import librosa
11
- import python_speech_features
12
- import time
13
- from LIA_Model import LIA_Model
14
- import os
15
- from tqdm import tqdm
16
- import argparse
17
- import numpy as np
18
- from torchvision import transforms
19
- from templates import *
20
- import argparse
21
- import shutil
22
- from moviepy.editor import *
23
- import librosa
24
- import python_speech_features
25
- import importlib.util
26
- import time
27
- import os
28
- import time
29
- import numpy as np
30
-
31
-
32
-
# Disable Gradio analytics to avoid network-related issues.
# Gradio decides whether to send analytics from the GRADIO_ANALYTICS_ENABLED
# environment variable; assigning an attribute on the module is not a
# documented switch. setdefault keeps any value the operator already exported.
os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "False")
gr.analytics_enabled = False  # kept for backward compatibility
35
-
36
-
def check_package_installed(package_name):
    """Report whether ``package_name`` can be located, without importing it.

    Prints a one-line status message either way.

    Args:
        package_name: Importable module name, e.g. ``"transformers"``.

    Returns:
        bool: ``True`` when an import spec is found, ``False`` otherwise.
    """
    try:
        package_spec = importlib.util.find_spec(package_name)
    except (ImportError, ValueError):
        # find_spec raises ModuleNotFoundError for a dotted name whose parent
        # package is missing, and ValueError for an already-loaded module
        # without a spec; both mean "not usable" for this check.
        package_spec = None

    if package_spec is None:
        print(f"{package_name} is not installed.")
        return False

    print(f"{package_name} is installed.")
    return True
45
-
def frames_to_video(input_path, audio_path, output_path, fps=25):
    """Assemble the images in a directory into a video with an audio track.

    Every entry of ``input_path`` is used as one frame, in sorted filename
    order, each shown for ``1 / fps`` seconds.

    Args:
        input_path: Directory containing the frame images.
        audio_path: Audio file attached to the resulting video.
        output_path: Destination video file (H.264 video, AAC audio).
        fps: Frames per second of the output video.

    Raises:
        ValueError: If ``input_path`` contains no entries.
    """
    image_files = [os.path.join(input_path, name) for name in sorted(os.listdir(input_path))]
    if not image_files:
        raise ValueError(f'No frames found in {input_path}')

    clips = [ImageClip(path).set_duration(1 / fps) for path in image_files]
    video = concatenate_videoclips(clips, method="compose")
    audio = AudioFileClip(audio_path)
    try:
        final_video = video.set_audio(audio)
        final_video.write_videofile(output_path, fps=fps, codec='libx264', audio_codec='aac')
    finally:
        # Release the audio reader and clip resources even if encoding fails.
        audio.close()
        video.close()
54
-
def load_image(filename, size):
    """Load an image as a channel-first float array scaled to [0, 1].

    Args:
        filename: Path of the image file.
        size: Target edge length; the image is resized to ``size x size``.

    Returns:
        numpy.ndarray: Array of shape ``(3, size, size)`` with values in
        ``[0, 1]``.
    """
    # Use a context manager so the underlying file handle is closed promptly
    # instead of lingering until garbage collection.
    with Image.open(filename) as source:
        img = source.convert('RGB')
    img = img.resize((size, size))
    img = np.asarray(img)
    img = np.transpose(img, (2, 0, 1))  # H x W x 3 -> 3 x H x W
    return img / 255.0
61
-
def img_preprocessing(img_path, size):
    """Load an image and return it as a batched tensor scaled to [-1, 1].

    Args:
        img_path: Path of the image file, forwarded to ``load_image``.
        size: Target edge length, forwarded to ``load_image``.

    Returns:
        torch.Tensor: Float tensor with a leading batch dimension of 1.
    """
    pixels = load_image(img_path, size)  # values in [0, 1]
    batch = torch.from_numpy(pixels).unsqueeze(0).float()
    return (batch - 0.5) * 2.0  # rescale [0, 1] -> [-1, 1]
67
-
def saved_image(img_tensor, img_path):
    """Write an image tensor to ``img_path``.

    Args:
        img_tensor: Image tensor with a leading batch dimension.
        img_path: Destination file path.
    """
    # squeeze(0) removes the batch dimension before the PIL conversion
    single_image = img_tensor.detach().cpu().squeeze(0)
    transforms.ToPILImage()(single_image).save(img_path)
72
-
def main(args):
    """Generate a talking-face video from one reference image and one audio file.

    Pipeline: encode the image with the stage-1 LIA model, extract audio
    features (MFCC or HuBERT), run the stage-2 diffusion model to obtain
    per-frame motion codes, render every frame with LIA, mux the frames with
    the audio and, optionally, run GFPGAN super-resolution.

    NOTE(review): block nesting was reconstructed from a listing without
    indentation -- confirm against the original file.

    Args:
        args: Namespace-like object. Fields read here: ``result_path``,
            ``test_image_path``, ``test_audio_path``, ``test_hubert_path``,
            ``stage1_checkpoint_path``, ``stage2_checkpoint_path``,
            ``infer_type``, ``seed``, ``decoder_layers``, ``motion_dim``,
            ``image_size``, ``device``, ``pose_driven_path``, ``pose_yaw``,
            ``pose_pitch``, ``pose_roll``, ``face_location``, ``face_scale``,
            ``step_T``, ``control_flag`` and ``face_sr``.

    Returns:
        tuple: ``(video_256_path, video_512_path)``. The second element equals
        the first when no super-resolved video was produced.

    Raises:
        ValueError: Unknown ``infer_type`` or a pose array that is not (T, 3).
        FileNotFoundError: Missing image, audio or HuBERT weights.
        ImportError: HuBERT features must be extracted but ``transformers``
            is not installed.
    """
    import tempfile  # local import: only used for the scratch frame directory

    # infer_type -> (face_location, face_scale, mfcc)
    infer_type_flags = {
        'mfcc_full_control': (True, True, True),
        'mfcc_pose_only': (False, False, True),
        'hubert_pose_only': (False, False, False),
        'hubert_audio_only': (False, False, False),
        'hubert_full_control': (True, True, False),
    }

    # Validate inputs before loading any model. Errors are raised rather than
    # calling exit(0): exit(0) reports success to the shell and raises
    # SystemExit, which an `except Exception` handler in a caller cannot catch.
    if args.infer_type not in infer_type_flags:
        raise ValueError(f'Type NOT Found! Unsupported infer_type: {args.infer_type!r}')
    if not os.path.exists(args.test_image_path):
        raise FileNotFoundError(f'{args.test_image_path} does not exist!')
    if not os.path.exists(args.test_audio_path):
        raise FileNotFoundError(f'{args.test_audio_path} does not exist!')

    os.makedirs(args.result_path, exist_ok=True)
    test_image_name = os.path.splitext(os.path.basename(args.test_image_path))[0]
    audio_name = os.path.splitext(os.path.basename(args.test_audio_path))[0]
    predicted_video_256_path = os.path.join(args.result_path, f'{test_image_name}-{audio_name}.mp4')
    predicted_video_512_path = os.path.join(args.result_path, f'{test_image_name}-{audio_name}_SR.mp4')

    #======Loading Stage 1 model=========
    lia = LIA_Model(motion_dim=args.motion_dim, fusion_type='weighted_sum')
    lia.load_lightning_model(args.stage1_checkpoint_path)
    lia.to(args.device)
    #============================

    conf = ffhq256_autoenc()
    conf.seed = args.seed
    conf.decoder_layers = args.decoder_layers
    conf.infer_type = args.infer_type
    conf.motion_dim = args.motion_dim
    conf.face_location, conf.face_scale, conf.mfcc = infer_type_flags[args.infer_type]

    img_source = img_preprocessing(args.test_image_path, args.image_size).to(args.device)
    one_shot_lia_start, one_shot_lia_direction, feats = lia.get_start_direction_code(
        img_source, img_source, img_source, img_source)

    #======Loading Stage 2 model=========
    model = LitModel(conf)
    state = torch.load(args.stage2_checkpoint_path, map_location='cpu')
    model.load_state_dict(state, strict=True)
    model.ema_model.eval()
    model.ema_model.to(args.device)
    #=================================

    #======Audio Input=========
    if conf.infer_type.startswith('mfcc'):
        # MFCC features
        wav, sr = librosa.load(args.test_audio_path, sr=16000)
        input_values = python_speech_features.mfcc(signal=wav, samplerate=sr, numcep=13, winlen=0.025, winstep=0.01)
        d_mfcc_feat = python_speech_features.base.delta(input_values, 1)
        d_mfcc_feat2 = python_speech_features.base.delta(input_values, 2)
        audio_driven_obj = np.hstack((input_values, d_mfcc_feat, d_mfcc_feat2))

        # The video frame is fixed to 25 hz and the audio is fixed to 100 hz
        frame_start, frame_end = 0, int(audio_driven_obj.shape[0] / 4)
        audio_start, audio_end = int(frame_start * 4), int(frame_end * 4)

        audio_driven = torch.Tensor(audio_driven_obj[audio_start:audio_end, :]).unsqueeze(0).float().to(args.device)

    elif conf.infer_type.startswith('hubert'):
        # Hubert features
        if not os.path.exists(args.test_hubert_path):
            if not check_package_installed('transformers'):
                raise ImportError('Please install transformers module first.')
            hubert_model_path = './ckpts/chinese-hubert-large'
            if not os.path.exists(hubert_model_path):
                raise FileNotFoundError('Please download the hubert weight into the ckpts path first.')
            print('You did not extract the audio features in advance, extracting online now, which will increase processing delay')

            start_time = time.time()

            # load hubert model
            from transformers import Wav2Vec2FeatureExtractor, HubertModel
            audio_model = HubertModel.from_pretrained(hubert_model_path).to(args.device)
            feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(hubert_model_path)
            audio_model.feature_extractor._freeze_parameters()
            audio_model.eval()

            # hubert model forward pass
            audio, _ = librosa.load(args.test_audio_path, sr=16000)
            input_values = feature_extractor(audio, sampling_rate=16000, padding=True, do_normalize=True, return_tensors="pt").input_values
            input_values = input_values.to(args.device)
            with torch.no_grad():
                outputs = audio_model(input_values, output_hidden_states=True)
                ws_feats = [hidden.detach().cpu().numpy() for hidden in outputs.hidden_states]
            ws_feat_obj = np.array(ws_feats)
            ws_feat_obj = np.squeeze(ws_feat_obj, 1)
            ws_feat_obj = np.pad(ws_feat_obj, ((0, 0), (0, 1), (0, 0)), 'edge')  # align the audio length with video frame

            execution_time = time.time() - start_time
            print(f"Extraction Audio Feature: {execution_time:.2f} Seconds")

            audio_driven_obj = ws_feat_obj
        else:
            print(f'Using audio feature from path: {args.test_hubert_path}')
            audio_driven_obj = np.load(args.test_hubert_path)

        # The video frame is fixed to 25 hz and the audio is fixed to 50 hz
        frame_start, frame_end = 0, int(audio_driven_obj.shape[1] / 2)
        audio_start, audio_end = int(frame_start * 2), int(frame_end * 2)

        audio_driven = torch.Tensor(audio_driven_obj[:, audio_start:audio_end, :]).unsqueeze(0).float().to(args.device)
    #============================

    # Diffusion Noise
    noisyT = torch.randn((1, frame_end, args.motion_dim)).to(args.device)

    #======Inputs for Attribute Control=========
    if os.path.exists(args.pose_driven_path):
        pose_obj = np.load(args.pose_driven_path)

        if len(pose_obj.shape) != 2 or pose_obj.shape[1] != 3:
            raise ValueError('please check your pose information. The shape must be like (T, 3).')

        if pose_obj.shape[0] >= frame_end:
            pose_obj = pose_obj[:frame_end, :]
        else:
            # Repeat the last pose so the sequence covers every frame.
            padding = np.tile(pose_obj[-1, :], (frame_end - pose_obj.shape[0], 1))
            pose_obj = np.vstack((pose_obj, padding))

        pose_signal = torch.Tensor(pose_obj).unsqueeze(0).to(args.device) / 90  # 90 is for normalization here
    else:
        yaw_signal = torch.zeros(1, frame_end, 1).to(args.device) + args.pose_yaw
        pitch_signal = torch.zeros(1, frame_end, 1).to(args.device) + args.pose_pitch
        roll_signal = torch.zeros(1, frame_end, 1).to(args.device) + args.pose_roll
        pose_signal = torch.cat((yaw_signal, pitch_signal, roll_signal), dim=-1)

    pose_signal = torch.clamp(pose_signal, -1, 1)

    face_location_signal = torch.zeros(1, frame_end, 1).to(args.device) + args.face_location
    face_scale_signal = torch.zeros(1, frame_end, 1).to(args.device) + args.face_scale
    #===========================================

    start_time = time.time()

    #======Diffusion Denoising Process=========
    generated_directions = model.render(
        one_shot_lia_start, one_shot_lia_direction, audio_driven,
        face_location_signal, face_scale_signal, pose_signal,
        noisyT, args.step_T, control_flag=args.control_flag)
    #=========================================

    execution_time = time.time() - start_time
    print(f"Motion Diffusion Model: {execution_time:.2f} Seconds")

    generated_directions = generated_directions.detach().cpu().numpy()

    # A private scratch directory per call: a fixed shared "frames" folder
    # would mix in stale frames left by an earlier failed run (or by a
    # concurrent request). It is removed even when rendering fails.
    frames_result_saved_path = tempfile.mkdtemp(prefix='frames_', dir=args.result_path)
    try:
        start_time = time.time()
        #======Rendering images frame-by-frame=========
        for pred_index in tqdm(range(generated_directions.shape[1])):
            ori_img_recon = lia.render(one_shot_lia_start, torch.Tensor(generated_directions[:, pred_index, :]).to(args.device), feats)
            ori_img_recon = ori_img_recon.clamp(-1, 1)
            wav_pred = (ori_img_recon.detach() + 1) / 2  # [-1, 1] -> [0, 1]
            saved_image(wav_pred, os.path.join(frames_result_saved_path, "%06d.png" % (pred_index)))
        #==============================================

        execution_time = time.time() - start_time
        print(f"Renderer Model: {execution_time:.2f} Seconds")

        frames_to_video(frames_result_saved_path, args.test_audio_path, predicted_video_256_path)
    finally:
        shutil.rmtree(frames_result_saved_path, ignore_errors=True)

    # Enhancer
    sr_generated = False
    if args.face_sr and check_package_installed('gfpgan'):
        from face_sr.face_enhancer import enhancer_list
        import imageio

        sr_tmp_path = predicted_video_512_path + '.tmp.mp4'
        try:
            # Super-resolution
            imageio.mimsave(sr_tmp_path, enhancer_list(predicted_video_256_path, method='gfpgan', bg_upsampler=None), fps=float(25))

            # Merge audio and video
            video_clip = VideoFileClip(sr_tmp_path)
            try:
                audio_clip = AudioFileClip(predicted_video_256_path)
                try:
                    final_clip = video_clip.set_audio(audio_clip)
                    final_clip.write_videofile(predicted_video_512_path, codec='libx264', audio_codec='aac')
                finally:
                    audio_clip.close()
            finally:
                # Close the reader before deleting the temporary file.
                video_clip.close()
        finally:
            if os.path.exists(sr_tmp_path):
                os.remove(sr_tmp_path)
        sr_generated = True

    # Only advertise the 512 video when it was actually written; face_sr may
    # be requested while gfpgan is not installed.
    if sr_generated:
        return predicted_video_256_path, predicted_video_512_path
    return predicted_video_256_path, predicted_video_256_path
278
-
def generate_video(uploaded_img, uploaded_audio, infer_type,
                   pose_yaw, pose_pitch, pose_roll, face_location, face_scale, step_T, device, face_sr, seed):
    """Gradio click handler: run the generation pipeline for one request.

    Every return path yields exactly three values, matching the three output
    components wired to this handler (256 video, 512 video, status message).

    Args:
        uploaded_img: File path of the reference image, or ``None``.
        uploaded_audio: File path of the driving audio, or ``None``.
        infer_type: Inference mode; selects the stage-2 checkpoint.
        pose_yaw, pose_pitch, pose_roll: Head pose controls.
        face_location, face_scale: Face placement controls.
        step_T: Number of diffusion steps.
        device: Torch device string.
        face_sr: Whether super-resolution is requested.
        seed: Random seed from the UI.

    Returns:
        tuple: ``(video_256, video_512, message)``; the video entries are
        ``None`` when unavailable.
    """
    if uploaded_img is None or uploaded_audio is None:
        return None, None, gr.Markdown("Error: Input image or audio file is empty. Please check and upload both files.")

    model_mapping = {
        "mfcc_pose_only": "./ckpts/stage2_pose_only_mfcc.ckpt",
        "mfcc_full_control": "./ckpts/stage2_more_controllable_mfcc.ckpt",
        "hubert_audio_only": "./ckpts/stage2_audio_only_hubert.ckpt",
        "hubert_pose_only": "./ckpts/stage2_pose_only_hubert.ckpt",
        "hubert_full_control": "./ckpts/stage2_full_control_hubert.ckpt",
    }

    stage2_checkpoint_path = model_mapping.get(infer_type, "default_checkpoint.ckpt")
    try:
        args = argparse.Namespace(
            infer_type=infer_type,
            test_image_path=uploaded_img,
            test_audio_path=uploaded_audio,
            test_hubert_path='',
            result_path='./outputs/',
            stage1_checkpoint_path='./ckpts/stage1.ckpt',
            stage2_checkpoint_path=stage2_checkpoint_path,
            # A numeric input box may deliver a float; seeds are integers.
            seed=int(seed),
            control_flag=True,
            pose_yaw=pose_yaw,
            pose_pitch=pose_pitch,
            pose_roll=pose_roll,
            face_location=face_location,
            pose_driven_path='not_supported_in_this_mode',
            face_scale=face_scale,
            step_T=step_T,
            image_size=256,
            device=device,
            motion_dim=20,
            decoder_layers=2,
            face_sr=face_sr,
        )

        # Run the main function
        output_256_video_path, output_512_video_path = main(args)

        # Check if the output video file exists
        if not os.path.exists(output_256_video_path):
            return None, None, gr.Markdown("Error: Video generation failed. Please check your inputs and try again.")
        if output_256_video_path == output_512_video_path:
            return gr.Video(value=output_256_video_path), None, gr.Markdown("Video (256*256 only) generated successfully!")
        return gr.Video(value=output_256_video_path), gr.Video(value=output_512_video_path), gr.Markdown("Video generated successfully!")

    except SystemExit:
        # SystemExit is not an Exception subclass; without this clause an
        # exit() call made during generation would escape the handler
        # instead of producing a message for the user.
        return None, None, gr.Markdown("Error: Video generation was aborted. Please check your inputs and the server log.")
    except Exception as e:
        return None, None, gr.Markdown(f"Error: An unexpected error occurred - {str(e)}")
341
-
# Initial values for the configuration widgets of the Gradio UI.
default_values = dict(
    pose_yaw=0,
    pose_pitch=0,
    pose_roll=0,
    face_location=0.5,
    face_scale=0.5,
    step_T=50,
    seed=0,
    device="cuda",
)
352
-
# Gradio UI definition: uploads and generated videos on top, the trigger
# button below them, then the configuration panel.
# NOTE(review): widget nesting was reconstructed from a listing without
# indentation -- confirm the layout against the original file.
with gr.Blocks() as demo:
    gr.Markdown('# AniTalker')
    gr.Markdown('![]()')

    with gr.Row():
        with gr.Column():
            uploaded_img = gr.Image(type="filepath", label="Reference Image")
            uploaded_audio = gr.Audio(type="filepath", label="Input Audio")
        with gr.Column():
            output_video_256 = gr.Video(label="Generated Video (256)")
            output_video_512 = gr.Video(label="Generated Video (512)")
            output_message = gr.Markdown()

    generate_button = gr.Button("Generate Video")

    with gr.Accordion("Configuration", open=True):
        infer_type = gr.Dropdown(
            label="Inference Type",
            choices=[
                'mfcc_pose_only',
                'mfcc_full_control',
                'hubert_audio_only',
                'hubert_pose_only',
            ],
            value='hubert_audio_only',
        )
        face_sr = gr.Checkbox(label="Enable Face Super-Resolution (512*512)", value=False)
        # TODO: optional Dlib face-crop checkbox (not wired up)
        seed = gr.Number(label="Seed", value=default_values["seed"])
        pose_yaw = gr.Slider(
            label="pose_yaw", minimum=-1, maximum=1, value=default_values["pose_yaw"])
        pose_pitch = gr.Slider(
            label="pose_pitch", minimum=-1, maximum=1, value=default_values["pose_pitch"])
        pose_roll = gr.Slider(
            label="pose_roll", minimum=-1, maximum=1, value=default_values["pose_roll"])
        face_location = gr.Slider(
            label="face_location", minimum=0, maximum=1, value=default_values["face_location"])
        face_scale = gr.Slider(
            label="face_scale", minimum=0, maximum=1, value=default_values["face_scale"])
        step_T = gr.Slider(
            label="step_T", minimum=1, maximum=100, step=1, value=default_values["step_T"])
        device = gr.Radio(label="Device", choices=["cuda", "cpu"], value=default_values["device"])

    # Order must match the positional parameters of generate_video.
    generation_inputs = [
        uploaded_img, uploaded_audio, infer_type,
        pose_yaw, pose_pitch, pose_roll, face_location, face_scale, step_T, device, face_sr, seed,
    ]
    generation_outputs = [output_video_256, output_video_512, output_message]
    generate_button.click(
        generate_video,
        inputs=generation_inputs,
        outputs=generation_outputs,
    )
396
-
if __name__ == '__main__':
    # Command-line entry point: start the Gradio server.
    # The description previously named a different project.
    parser = argparse.ArgumentParser(description='AniTalker')
    # Defaults are None so that, when a flag is omitted, Gradio keeps choosing
    # the host/port itself exactly as before; previously the parsed values
    # were never passed to launch() and the flags had no effect.
    parser.add_argument('--server_name', type=str, default=None,
                        help='Server name (default: chosen by Gradio)')
    parser.add_argument('--server_port', type=int, default=None,
                        help='Server port (default: chosen by Gradio)')
    args = parser.parse_args()

    demo.launch(server_name=args.server_name, server_port=args.server_port)