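"""Gradio demo for Wan2.1 FLF2V: turn a first frame, a last frame, and a text
prompt into a short video via the diffusers WanImageToVideoPipeline."""
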
import gradio as gr
import numpy as np
import torch
import torchvision.transforms.functional as TF

from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel

# Load the 14B pipeline once at import time; reloading it on every Gradio
# request would dominate each call's runtime.
model_id = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"

# The image encoder and VAE are kept in fp32 for numerical stability and must
# be passed to the pipeline explicitly (otherwise they are loaded again in
# bf16 and the fp32 copies go unused); everything else runs in bf16.
image_encoder = CLIPVisionModel.from_pretrained(
    model_id, subfolder="image_encoder", torch_dtype=torch.float32
)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanImageToVideoPipeline.from_pretrained(
    model_id,
    vae=vae,
    image_encoder=image_encoder,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
)
pipe.to("cuda")
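# Assumption: on GPUs too small to hold the 14B weights, diffusers' standard
# pipe.enable_model_cpu_offload() could stand in for pipe.to("cuda") above.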


def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
    """Resize to roughly max_area pixels, preserving aspect ratio and
    snapping each side to the model's spatial granularity."""
    aspect_ratio = image.height / image.width
    # Each side must be a multiple of the VAE spatial downscale factor times
    # the transformer's patch size.
    mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
    height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
    width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
    return image.resize((width, height)), height, width


def center_crop_resize(image, height, width):
    """Scale the image so it covers height x width, then center-crop to
    exactly that size (used to match the last frame to the first)."""
    resize_ratio = max(width / image.width, height / image.height)
    image = image.resize(
        (round(image.width * resize_ratio), round(image.height * resize_ratio))
    )
    # torchvision's center_crop expects (height, width), not (width, height).
    image = TF.center_crop(image, [height, width])
    return image, height, width


def generate_video(first_frame_url, last_frame_url, prompt):
    first_frame = load_image(first_frame_url)
    last_frame = load_image(last_frame_url)

    # Fit the first frame to the model's resolution constraints, then force
    # the last frame to the same size so both conditioning images agree.
    first_frame, height, width = aspect_ratio_resize(first_frame, pipe)
    if last_frame.size != first_frame.size:
        last_frame, _, _ = center_crop_resize(last_frame, height, width)

    output = pipe(
        image=first_frame,
        last_image=last_frame,
        prompt=prompt,
        height=height,
        width=width,
        guidance_scale=5.5,
    ).frames[0]

    video_path = "wan_output.mp4"
    export_to_video(output, video_path, fps=16)
    return video_path


iface = gr.Interface(
    fn=generate_video,
    inputs=[
        gr.Textbox(label="First Frame URL"),
        gr.Textbox(label="Last Frame URL"),
        gr.Textbox(label="Prompt"),
    ],
    outputs=gr.Video(label="Generated Video"),
    title="Wan2.1 FLF2V Video Generator",
)
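
# Generation takes minutes per request; Gradio's queue keeps the HTTP
# connection from timing out while the GPU works (max_size=4 is an untuned
# assumption).
iface.queue(max_size=4)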
iface.launch()