# Start-up bootstrap: install pytorch3d, download every checkpoint the demo
# needs into ``pretrained_models/``, then import the runtime dependencies.
import gc
import math
import os
import shutil
import subprocess
import sys
import tempfile

PYTORCH3D_WHEEL_INDEX = (
    "https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/"
    "py310_cu121_pyt221/download.html"
)


def _install_pytorch3d():
    """Install pytorch3d from the prebuilt wheel index using this interpreter's pip.

    The command is passed as an argument list (no shell) and run through
    ``sys.executable -m pip`` so the package is installed into the environment
    that is executing this script.  A failed install is reported but does not
    abort start-up, matching the previous best-effort behaviour.
    """
    result = subprocess.run(
        [sys.executable, "-m", "pip", "install", "pytorch3d", "-f", PYTORCH3D_WHEEL_INDEX],
        check=False,
    )
    if result.returncode != 0:
        print(
            f"WARNING: pytorch3d installation exited with code {result.returncode}",
            file=sys.stderr,
        )


def _move_replacing(src, dst):
    """Move ``src`` to ``dst``, replacing whatever already exists at ``dst``.

    ``shutil.move`` nests ``src`` inside ``dst`` when ``dst`` is an existing
    directory (or raises if the nested path already exists), so a previous
    copy is removed first to keep restarts idempotent.
    """
    if os.path.isdir(dst) and not os.path.islink(dst):
        shutil.rmtree(dst)
    elif os.path.lexists(dst):
        os.remove(dst)
    shutil.move(src, dst)


def _download_extra_model(model, base_dir):
    """Fetch ``extra_models/<model>`` from Skywork/SkyReels-A1 into ``<base_dir>/<model>``."""
    # Download to a temp folder first, because the snapshot keeps the
    # repository's nested ``extra_models/<model>`` layout.
    temp_dir = f"{base_dir}/{model}_temp"
    snapshot_download(
        repo_id="Skywork/SkyReels-A1",
        local_dir=temp_dir,
        allow_patterns=f"extra_models/{model}/**",
    )

    # Move the contents of the nested folder to the proper location.
    src_dir = f"{temp_dir}/extra_models/{model}"
    dst_dir = f"{base_dir}/{model}"
    os.makedirs(dst_dir, exist_ok=True)
    for item in os.listdir(src_dir):
        _move_replacing(f"{src_dir}/{item}", f"{dst_dir}/{item}")

    # Clean up the temp directory.
    shutil.rmtree(temp_dir)


# Must run before the imports further down.
_install_pytorch3d()

from huggingface_hub import snapshot_download

base_dir = "pretrained_models"
os.makedirs(base_dir, exist_ok=True)

snapshot_download(
    repo_id="multimodalart/diffposetalk",
    local_dir=f"{base_dir}/diffposetalk",
)

# Download FLAME, mediapipe, and smirk
for model in ["FLAME", "mediapipe", "smirk"]:
    _download_extra_model(model, base_dir)

# Download SkyReels-A1-5B
snapshot_download(
    repo_id="Skywork/SkyReels-A1",
    local_dir=f"{base_dir}/SkyReels-A1-5B",
)

import gradio as gr
import torch
import numpy as np
from PIL import Image
import cv2
import moviepy.editor as mp
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from diffusers.utils import export_to_video, load_image

# Import required modules from SkyReels
from skyreels_a1.models.transformer3d import CogVideoXTransformer3DModel
from skyreels_a1.skyreels_a1_i2v_pipeline import SkyReelsA1ImagePoseToVideoPipeline
from skyreels_a1.pre_process_lmk3d import FaceAnimationProcessor
from skyreels_a1.src.media_pipe.mp_utils import LMKExtractor
from skyreels_a1.src.media_pipe.draw_util_2d import FaceMeshVisualizer2d
from diffusers.models import AutoencoderKLCogVideoX
from transformers import SiglipImageProcessor, SiglipVisionModel

from diffposetalk.diffposetalk import DiffPoseTalk

# Helper
# functions from the original script

# Frame rate of the resampled driving sequence and of every video written here.
OUTPUT_FPS = 12


def _resample_to_output_fps(driving_frames, fps):
    """Return the frames of a ``fps`` sequence that fall on the 12 fps time grid.

    Frames are selected by index (nearest earlier frame); nothing is interpolated.
    """
    video_length = len(driving_frames)
    duration = video_length / fps
    target_times = np.arange(0, duration, 1 / OUTPUT_FPS)
    frame_indices = (target_times * fps).astype(np.int32)
    frame_indices = frame_indices[frame_indices < video_length]
    return [driving_frames[idx] for idx in frame_indices]


def parse_video(driving_frames, max_frame_num, fps=25):
    """Resample driving frames to 12 fps and fit them to a single clip.

    Args:
        driving_frames: Sequence of frames sampled at ``fps``.
        max_frame_num: Frame budget of the clip; at most ``max_frame_num - 1``
            resampled frames are kept.
        fps: Frame rate of ``driving_frames``.

    Returns:
        A list in which the first frame is duplicated, the kept frames' last
        element is dropped from the middle section, and the last frame is
        repeated ``max_frame_num - kept - 1`` times.

    Raises:
        ValueError: If ``driving_frames`` is empty.
    """
    # The original loop always kept at least one frame before checking the cap.
    keep = max(max_frame_num - 1, 1)
    frames = _resample_to_output_fps(driving_frames, fps)[:keep]
    if not frames:
        raise ValueError("driving_frames must contain at least one frame")

    tail_repeats = max_frame_num - len(frames) - 1
    return [frames[0]] * 2 + frames[1:len(frames) - 1] + [frames[-1]] * tail_repeats


def write_mp4(video_path, samples, fps=OUTPUT_FPS):
    """Encode ``samples`` (a sequence of frames) to ``video_path`` with libx264."""
    clip = mp.ImageSequenceClip(samples, fps=fps)
    clip.write_videofile(
        video_path,
        audio_codec="aac",
        codec="libx264",
        ffmpeg_params=["-crf", "18", "-preset", "slow"],
    )


def save_video_with_audio(video_path, audio_path, save_path):
    """Mux ``audio_path`` onto ``video_path`` and write the result to ``save_path``.

    Audio longer than the video is trimmed to the video's duration.  Both
    source clips are closed even when encoding fails.

    Returns:
        ``save_path``.
    """
    video_clip = mp.VideoFileClip(video_path)
    try:
        audio_clip = mp.AudioFileClip(audio_path)
        try:
            if audio_clip.duration > video_clip.duration:
                audio_clip = audio_clip.subclip(0, video_clip.duration)
            video_with_audio = video_clip.set_audio(audio_clip)
            video_with_audio.write_videofile(
                save_path, fps=OUTPUT_FPS, codec="libx264", audio_codec="aac"
            )
        finally:
            audio_clip.close()
    finally:
        video_clip.close()
    return save_path


def pad_video(driving_frames, fps=25, chunk_size=48):
    """Resample driving frames to 12 fps and pad to a multiple of ``chunk_size``.

    Args:
        driving_frames: Sequence of frames sampled at ``fps``.
        fps: Frame rate of ``driving_frames``.
        chunk_size: The returned list length is a multiple of this value.

    Returns:
        ``(frames, pad_length)`` where ``pad_length`` copies of the last frame
        were appended to ``frames``.

    Raises:
        ValueError: If ``driving_frames`` is empty.
    """
    frames = _resample_to_output_fps(driving_frames, fps)
    if not frames:
        raise ValueError("driving_frames must contain at least one frame")

    # Distance from len(frames) up to the next multiple of chunk_size.
    pad_length = -len(frames) % chunk_size
    frames.extend([frames[-1]] * pad_length)
    return frames, pad_length


def _remove_quietly(path):
    """Delete ``path`` if it exists; a failed cleanup must not mask the real result."""
    try:
        os.remove(path)
    except OSError:
        pass


# Global parameters
model_name = "pretrained_models/SkyReels-A1-5B/"
siglip_name = "pretrained_models/SkyReels-A1-5B/siglip-so400m-patch14-384"
weight_dtype = torch.bfloat16
max_frame_num = 49
sample_size = [480, 720]

# Preload all models in global context
print("Loading models...")

# Load LMK extractor and processors
lmk_extractor = LMKExtractor()
processor = FaceAnimationProcessor(checkpoint='pretrained_models/smirk/SMIRK_em1.pt')
vis = FaceMeshVisualizer2d(forehead_edge=False, draw_head=False, draw_iris=False)
face_helper = FaceRestoreHelper(
    upscale_factor=1,
    face_size=512,
    crop_ratio=(1, 1),
    det_model='retinaface_resnet50',
    save_ext='png',
    device="cuda",
)

# Load siglip visual encoder
siglip = SiglipVisionModel.from_pretrained(siglip_name)
siglip_normalize = SiglipImageProcessor.from_pretrained(siglip_name)

# Load diffposetalk
diffposetalk = DiffPoseTalk()

# Load SkyReels models
transformer = CogVideoXTransformer3DModel.from_pretrained(
    model_name,
    subfolder="transformer",
).to(weight_dtype)

vae = AutoencoderKLCogVideoX.from_pretrained(
    model_name,
    subfolder="vae",
).to(weight_dtype)

lmk_encoder = AutoencoderKLCogVideoX.from_pretrained(
    model_name,
    subfolder="pose_guider",
).to(weight_dtype)

# Set up pipeline
pipe = SkyReelsA1ImagePoseToVideoPipeline.from_pretrained(
    model_name,
    transformer=transformer,
    vae=vae,
    lmk_encoder=lmk_encoder,
    image_encoder=siglip,
    feature_extractor=siglip_normalize,
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")
pipe.transformer = torch.compile(pipe.transformer)
pipe.vae.enable_tiling()
pipe.vae = torch.compile(pipe.vae)

print("Models loaded successfully!")


def process_image_audio(image_path, audio_path, guidance_scale=3.0, steps=10, progress=gr.Progress()):
    """Animate a portrait with a driving audio track.

    Args:
        image_path: Path of the portrait image.
        audio_path: Path of the driving audio file.
        guidance_scale: Guidance scale forwarded to the pipeline.
        steps: Number of inference steps forwarded to the pipeline.
        progress: Gradio progress reporter.

    Returns:
        ``(animation_path, comparison_path)``: the generated video with audio,
        and a side-by-side video (input image | generated frame) with audio.

    Raises:
        gr.Error: If an input is missing or no face could be aligned.
    """
    if not image_path:
        raise gr.Error("Please upload a portrait image.")
    if not audio_path:
        raise gr.Error("Please upload a driving audio file.")

    progress(0.1, desc="Processing inputs...")

    # Create a directory for outputs if it doesn't exist
    output_dir = "gradio_outputs"
    os.makedirs(output_dir, exist_ok=True)

    # Reserve two temp paths: one for the silent render, one for the final result.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_video_file, \
            tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_output_file:
        temp_video_path = temp_video_file.name
        final_output_path = temp_output_file.name

    target_h, target_w = sample_size[0], sample_size[1]
    # Each pipeline call takes one reference frame plus chunk_size driving frames.
    chunk_size = max_frame_num - 1

    try:
        progress(0.2, desc="Processing image...")

        # Load and process image
        image = load_image(image=image_path)
        image = processor.crop_and_resize(image, target_h, target_w)

        # Crop face
        ref_image, x1, y1 = processor.face_crop(np.array(image))
        face_h, face_w, _ = ref_image.shape
        source_image = ref_image

        progress(0.3, desc="Processing facial landmarks...")
        source_outputs, source_tform, image_original = processor.process_source_image(source_image)

        progress(0.4, desc="Processing audio...")
        # Process audio and generate driving outputs
        driving_outputs = diffposetalk.infer_from_file(
            audio_path,
            source_outputs["shape_params"].view(-1)[:100].detach().cpu().numpy(),
        )

        progress(0.5, desc="Processing landmarks from coefficients...")
        out_frames = processor.preprocess_lmk3d_from_coef(
            source_outputs, source_tform, image_original.shape, driving_outputs
        )
        out_frames, pad_length = pad_video(out_frames, chunk_size=chunk_size)
        print(len(out_frames), pad_length)

        # Paste every face-sized motion frame into a blank full-size canvas.
        rescale_motions = np.zeros_like(image)[np.newaxis, :].repeat(len(out_frames), axis=0)
        for canvas, motion_frame in zip(rescale_motions, out_frames):
            canvas[y1:y1 + face_h, x1:x1 + face_w] = motion_frame

        # Landmark drawing of the reference face, used as the first control frame.
        ref_image_resized = cv2.resize(ref_image, (512, 512))
        ref_lmk = lmk_extractor(ref_image_resized[:, :, ::-1])
        ref_img = vis.draw_landmarks_v3(
            (512, 512),
            (face_w, face_h),
            ref_lmk['lmks'].astype(np.float32),
            normed=True,
        )
        first_motion = np.zeros_like(np.array(image))
        first_motion[y1:y1 + face_h, x1:x1 + face_w] = ref_img
        first_motion = first_motion[np.newaxis, :]

        # Face alignment
        face_helper.clean_all()
        face_helper.read_image(np.array(image)[:, :, ::-1])
        face_helper.get_face_landmarks_5(only_center_face=True)
        face_helper.align_warp_face()
        if not face_helper.cropped_faces:
            raise gr.Error("No face could be detected in the portrait image.")
        align_face = face_helper.cropped_faces[0]
        image_face = align_face[:, :, ::-1]

        progress(0.6, desc="Generating animation (this may take a while)...")

        # Generate the video chunk by chunk
        out_samples = []
        for start in range(0, len(rescale_motions), chunk_size):
            motions = np.concatenate([first_motion, rescale_motions[start:start + chunk_size]])
            input_video = torch.from_numpy(motions).permute([3, 0, 1, 2]).unsqueeze(0)
            input_video = input_video / 255

            with torch.no_grad():
                sample = pipe(
                    image=image,
                    image_face=image_face,
                    control_video=input_video,
                    prompt="",
                    negative_prompt="",
                    height=target_h,
                    width=target_w,
                    num_frames=max_frame_num,
                    guidance_scale=guidance_scale,
                    num_inference_steps=int(steps),
                )

            chunk_frames = sample.frames[0]
            # Later chunks start with the repeated reference frame; skip it.
            out_samples.extend(chunk_frames if start == 0 else chunk_frames[1:])

        # Drop the leading reference frame and the frames generated for padding.
        out_samples = out_samples[1:len(out_samples) - pad_length]

        progress(0.8, desc="Creating output video...")
        export_to_video(out_samples, temp_video_path, fps=OUTPUT_FPS)

        progress(0.9, desc="Adding audio to video...")
        result_path = save_video_with_audio(temp_video_path, audio_path, final_output_path)

        # Create side-by-side comparison
        final_images = []
        for sample_frame in out_samples:
            generated = Image.fromarray(np.array(sample_frame)).convert("RGB")
            side_by_side = Image.new('RGB', (target_w * 2, target_h))
            side_by_side.paste(image, (0, 0))
            side_by_side.paste(generated, (target_w, 0))
            final_images.append(np.array(side_by_side))

        comparison_path = os.path.join(output_dir, "comparison.mp4")
        write_mp4(comparison_path, final_images, fps=OUTPUT_FPS)

        # Add audio to comparison video
        comparison_with_audio = os.path.join(output_dir, "comparison_with_audio.mp4")
        comparison_with_audio = save_video_with_audio(
            comparison_path, audio_path, comparison_with_audio
        )

        progress(1.0, desc="Done!")
        return result_path, comparison_with_audio
    except Exception:
        # The reserved output file is useless when generation failed.
        _remove_quietly(final_output_path)
        raise
    finally:
        # The silent intermediate render is never returned to the caller.
        _remove_quietly(temp_video_path)
        # Collect first so freed tensors can actually be released from the cache.
        gc.collect()
        torch.cuda.empty_cache()


INTRO_MARKDOWN = "\n\n".join([
    "Upload a portrait image and an audio file to animate the face.",
    "💡 Enjoying this demo? Share your feedback or review, and you might earn exclusive rewards! 🚀✨",
    "📩 [Contact us on Discord](https://discord.com/invite/PwM6NYtccQ) for details.",
    "🔥 [Code](https://github.com/SkyworkAI/SkyReels-A1) [Huggingface](https://huggingface.co/Skywork/SkyReels-A1)",
])

INSTRUCTIONS_MARKDOWN = "\n".join([
    "## Instructions",
    "1. Upload a portrait image (frontal face works best)",
    "2. Upload an audio file (wav format recommended)",
    "3. Adjust parameters if needed",
    '4. Click "Generate Animation" to create the video',
    "",
    "Note: Processing may take several minutes depending on your hardware.",
])

# Create Gradio interface
with gr.Blocks(title="SkyReels A1 Talking Head") as app:
    gr.Markdown("# SkyReels A1 Talking Head")
    gr.Markdown(INTRO_MARKDOWN)

    with gr.Row():
        with gr.Column():
            with gr.Row():
                image_input = gr.Image(type="filepath", label="Portrait Image")
                audio_input = gr.Audio(type="filepath", label="Driving Audio")

            with gr.Row():
                guidance_scale = gr.Slider(
                    minimum=1.0, maximum=7.0, value=3.0, step=0.1, label="Guidance Scale"
                )
                inference_steps = gr.Slider(
                    minimum=5, maximum=30, value=10, step=1, label="Inference Steps"
                )

            generate_button = gr.Button("Generate Animation", variant="primary")

        with gr.Column():
            output_video = gr.Video(label="Animation Result")
            comparison_video = gr.Video(label="Side-by-Side Comparison")

    generate_button.click(
        fn=process_image_audio,
        inputs=[image_input, audio_input, guidance_scale, inference_steps],
        outputs=[output_video, comparison_video],
    )

    gr.Markdown(INSTRUCTIONS_MARKDOWN)

if __name__ == "__main__":
    app.launch(share=True)