#!/usr/bin/env python
"""
Gradio demo for Wan2.1 FLF2V – First & Last Frame → Video
• Single global load (no repeated downloads)
• Balanced device_map to avoid OOM on 24 GB A10
• Fast CLIP processor via use_fast=True
• High-level streaming progress
• Auto-download via gr.File
"""
import os
# persist Hugging Face cache so safetensors only download once
os.environ["HF_HOME"] = "/mnt/data/huggingface"
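# (set before importing diffusers/transformers so the cache location is
#  picked up when the Hugging Face hub libraries initialise)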
import numpy as np
import torch
import gradio as gr
from diffusers import WanImageToVideoPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video
from transformers import CLIPVisionModel
from PIL import Image
import torchvision.transforms.functional as TF
# -----------------------------------------------------------------------------
# CONFIG
# -----------------------------------------------------------------------------
MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
DTYPE = torch.float16
MAX_AREA = 1280 * 720
DEFAULT_FRAMES = 81
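# 81 frames ≈ 5 s at 16 fps; the Wan reference examples use frame counts of the
# form 4k + 1 (e.g. 81), matching the VAE's temporal compression factor of 4.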
# -----------------------------------------------------------------------------
# LOAD PIPELINE ONCE
# -----------------------------------------------------------------------------
def load_pipeline():
    # 1) CLIP image encoder (fp32)
    image_encoder = CLIPVisionModel.from_pretrained(
        MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
    )
    # 2) VAE (fp16)
    vae = AutoencoderKLWan.from_pretrained(
        MODEL_ID, subfolder="vae", torch_dtype=DTYPE
    )
    # 3) Balanced device placement + fast processor
    pipe = WanImageToVideoPipeline.from_pretrained(
        MODEL_ID,
        image_encoder=image_encoder,
        vae=vae,
        torch_dtype=DTYPE,
        device_map="balanced",  # spread weights CPU↔GPU
        use_fast=True,          # internal fast CLIPImageProcessor
    )
    return pipe
PIPE = load_pipeline()
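# Loaded once at module import: every Gradio request reuses this instance
# instead of re-initialising or re-downloading the weights.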
# -----------------------------------------------------------------------------
# HELPERS
# -----------------------------------------------------------------------------
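# aspect_resize keeps the input aspect ratio while snapping height/width down to
# multiples of the model's spatial stride (VAE scale factor × transformer patch
# size); center_crop_resize scales-then-crops the last frame so it matches the
# first frame's resolution exactly.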
def aspect_resize(img: Image.Image, max_area=MAX_AREA):
    ar = img.height / img.width
    mod = PIPE.vae_scale_factor_spatial * PIPE.transformer.config.patch_size[1]
    h = int(np.sqrt(max_area * ar)) // mod * mod
    w = int(np.sqrt(max_area / ar)) // mod * mod
    return img.resize((w, h), Image.LANCZOS), h, w
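# Worked example (assuming a stride of 16, as for this 720P checkpoint): a
# 1920×1080 input maps to exactly 1280×720, the largest stride-aligned size
# at or below MAX_AREA.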
def center_crop_resize(img: Image.Image, h, w):
    ratio = max(w / img.width, h / img.height)
    img2 = img.resize(
        (round(img.width * ratio), round(img.height * ratio)),
        Image.LANCZOS
    )
    return TF.center_crop(img2, [h, w])
# -----------------------------------------------------------------------------
# GENERATION + STREAMING
# -----------------------------------------------------------------------------
def generate(
    first_frame: Image.Image,
    last_frame: Image.Image,
    prompt: str,
    negative: str,
    steps: int,
    guidance: float,
    num_frames: int,
    seed: int,
    fps: int,
    progress=gr.Progress(),
):
    # choose seed
    if seed == -1:
        seed = torch.seed()
    gen = torch.Generator(device=PIPE.device).manual_seed(seed)
    # 0–15%: resize
    progress(0.0, desc="Resizing first frame…")
    f_resized, h, w = aspect_resize(first_frame)
    if last_frame.size != f_resized.size:
        progress(0.15, desc="Resizing last frame…")
        l_resized = center_crop_resize(last_frame, h, w)
    else:
        # already the right resolution – use the last frame as-is
        l_resized = last_frame
    # 15–25%: spin up pipeline
    progress(0.25, desc="Launching inference…")
    out = PIPE(
        image=f_resized,
        last_image=l_resized,
        prompt=prompt,
        negative_prompt=negative or None,
        height=h,
        width=w,
        num_frames=num_frames,
        num_inference_steps=steps,
        guidance_scale=guidance,
        generator=gen,
    )
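    # NOTE: progress jumps from 25% to 90% across the call above because the
    # denoising loop is not instrumented with a per-step callback here.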
    # 90–100%: export
    progress(0.90, desc="Building video file…")
    video_path = export_to_video(out.frames[0], fps=fps)
    progress(1.0, desc="Done!")
    return video_path, seed
# -----------------------------------------------------------------------------
# GRADIO UI
# -----------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## Wan2.1 FLF2V – First & Last Frame → Video")
    with gr.Row():
        first_img = gr.Image(label="First frame", type="pil")
        last_img = gr.Image(label="Last frame", type="pil")
    prompt = gr.Textbox(label="Prompt", placeholder="A blue bird takes off…")
    negative = gr.Textbox(label="Negative prompt (opt)", placeholder="blurry, lowres")
    with gr.Accordion("Advanced parameters", open=False):
        steps = gr.Slider(10, 50, value=30, step=1, label="Steps")
        guidance = gr.Slider(0.0, 10.0, value=5.5, step=0.1, label="Guidance")
        num_frames = gr.Slider(16, 129, value=DEFAULT_FRAMES, step=1, label="Frames")
        fps = gr.Slider(4, 30, value=16, step=1, label="FPS")
        seed_input = gr.Number(value=-1, precision=0, label="Seed (-1=rand)")
    run_btn = gr.Button("Generate")
    download = gr.File(label="Download .mp4", interactive=False)
    seed_used = gr.Number(label="Seed used", interactive=False)

    run_btn.click(
        fn=generate,
        inputs=[first_img, last_img, prompt, negative,
                steps, guidance, num_frames, seed_input, fps],
        outputs=[download, seed_used],
    )
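# queue() enables Gradio's event queue, which is required for gr.Progress
# updates to stream to the browser while a job runs.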
demo.queue().launch(server_name="0.0.0.0", server_port=7860)