#!/usr/bin/env python
"""
Gradio demo for Wan2.1 FLF2V – First & Last Frame → Video
• Single global load (no repeated downloads)
• Balanced device_map to avoid OOM on 24 GB A10
• Fast CLIP processor via use_fast=True
• High-level streaming progress
• Auto-download via gr.File
"""
import os
# persist Hugging Face cache so safetensors only download once
os.environ["HF_HOME"] = "/mnt/data/huggingface"

import numpy as np
import torch
import gradio as gr
from diffusers import WanImageToVideoPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video
from transformers import CLIPVisionModel
from PIL import Image
import torchvision.transforms.functional as TF

# -----------------------------------------------------------------------------
# CONFIG
# -----------------------------------------------------------------------------
MODEL_ID       = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
DTYPE          = torch.float16
MAX_AREA       = 1280 * 720
DEFAULT_FRAMES = 81
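# MAX_AREA caps output at 1280 * 720 = 921,600 pixels; aspect_resize() below keeps
# the input aspect ratio within that budget. DEFAULT_FRAMES = 81 follows the
# 4k + 1 frame count Wan's temporal VAE expects (81 = 4 * 20 + 1); the diffusers
# Wan pipeline rounds counts that don't fit this pattern, with a warning.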

# -----------------------------------------------------------------------------
# LOAD PIPELINE ONCE
# -----------------------------------------------------------------------------
def load_pipeline():
    # 1) CLIP image encoder (fp32)
    image_encoder = CLIPVisionModel.from_pretrained(
        MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
    )
    # 2) VAE (fp16)
    vae = AutoencoderKLWan.from_pretrained(
        MODEL_ID, subfolder="vae", torch_dtype=DTYPE
    )
    # 3) Balanced device placement + fast processor
    pipe = WanImageToVideoPipeline.from_pretrained(
        MODEL_ID,
        image_encoder=image_encoder,
        vae=vae,
        torch_dtype=DTYPE,
        device_map="balanced",  # spread weights CPU↔GPU
        use_fast=True,          # fast CLIPImageProcessor, if this diffusers version forwards the flag
    )
    return pipe

PIPE = load_pipeline()
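
# If the balanced device map still OOMs (e.g. on a smaller GPU), diffusers'
# model CPU offload is a slower but lower-VRAM alternative. A sketch, not
# enabled here (device_map and offload are mutually exclusive, so drop the
# device_map="balanced" argument first):
#
#   pipe = WanImageToVideoPipeline.from_pretrained(
#       MODEL_ID, image_encoder=image_encoder, vae=vae, torch_dtype=DTYPE
#   )
#   pipe.enable_model_cpu_offload()  # keeps one component on GPU at a time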


# -----------------------------------------------------------------------------
# HELPERS
# -----------------------------------------------------------------------------
def aspect_resize(img: Image.Image, max_area=MAX_AREA):
    ar  = img.height / img.width
    mod = PIPE.vae_scale_factor_spatial * PIPE.transformer.config.patch_size[1]
    h   = int(np.sqrt(max_area * ar)) // mod * mod
    w   = int(np.sqrt(max_area / ar)) // mod * mod
    return img.resize((w, h), Image.LANCZOS), h, w
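
# Worked example, assuming mod == 16 for this checkpoint (vae_scale_factor_spatial
# 8 * transformer patch_size 2): a 1024x576 input has ar = 0.5625, so
#   h = int(sqrt(921600 * 0.5625)) // 16 * 16 = 720  // 16 * 16 = 720
#   w = int(sqrt(921600 / 0.5625)) // 16 * 16 = 1280 // 16 * 16 = 1280
# i.e. any 16:9 input snaps exactly to 1280x720.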

def center_crop_resize(img: Image.Image, h, w):
    ratio = max(w / img.width, h / img.height)
    img2 = img.resize(
        (round(img.width * ratio), round(img.height * ratio)),
        Image.LANCZOS
    )
    return TF.center_crop(img2, [h, w])
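
# Worked example: fitting an 800x600 last frame to (h, w) = (720, 1280):
#   ratio = max(1280/800, 720/600) = max(1.6, 1.2) = 1.6
#   resize -> 1280x960, then center-crop to 720x1280
# max() guarantees both target dimensions are covered before cropping, so the
# frame is cropped rather than letterboxed.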


# -----------------------------------------------------------------------------
# GENERATION + STREAMING
# -----------------------------------------------------------------------------
def generate(
    first_frame:   Image.Image,
    last_frame:    Image.Image,
    prompt:        str,
    negative:      str,
    steps:         int,
    guidance:      float,
    num_frames:    int,
    seed:          int,
    fps:           int,
    progress=      gr.Progress(),
):
    # choose seed
    if seed == -1:
        seed = torch.seed()
    gen = torch.Generator(device=PIPE.device).manual_seed(seed)

    # 0–15%: resize
    progress(0.0, desc="Resizing first frame…")
    f_resized, h, w = aspect_resize(first_frame)
    if last_frame.size != f_resized.size:
        progress(0.15, desc="Resizing last frame…")
        l_resized = center_crop_resize(last_frame, h, w)
    else:
        l_resized = f_resized

    # 25–90%: denoising (one blocking call; see the per-step callback sketch below)
    progress(0.25, desc="Launching inference…")
    out = PIPE(
        image=f_resized,
        last_image=l_resized,
        prompt=prompt,
        negative_prompt=negative or None,
        height=h,
        width=w,
        num_frames=num_frames,
        num_inference_steps=steps,
        guidance_scale=guidance,
        generator=gen,
    )
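
    # Progress sits at 25% for the whole denoise above. For per-step updates,
    # diffusers pipelines take a callback_on_step_end hook; a minimal sketch,
    # assuming the Wan pipeline forwards it like other diffusers pipelines:
    #
    #   def _step_cb(pipe, step, timestep, cb_kwargs):
    #       progress(0.25 + 0.65 * (step + 1) / steps, desc=f"Step {step + 1}/{steps}")
    #       return cb_kwargs
    #
    #   out = PIPE(..., callback_on_step_end=_step_cb)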

    # 90–100%: export
    progress(0.90, desc="Building video file…")
    video_path = export_to_video(out.frames[0], fps=fps)
    progress(1.0, desc="Done!")

    return video_path, seed


# -----------------------------------------------------------------------------
# GRADIO UI
# -----------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## Wan2.1 FLF2V – First & Last Frame → Video")

    with gr.Row():
        first_img = gr.Image(label="First frame", type="pil")
        last_img  = gr.Image(label="Last frame",  type="pil")

    prompt   = gr.Textbox(label="Prompt", placeholder="A blue bird takes off…")
    negative = gr.Textbox(label="Negative prompt (opt)", placeholder="blurry, lowres")

    with gr.Accordion("Advanced parameters", open=False):
        steps      = gr.Slider(10, 50, value=30, step=1, label="Steps")
        guidance   = gr.Slider(0.0, 10.0, value=5.5, step=0.1, label="Guidance")
        num_frames = gr.Slider(16, 129, value=DEFAULT_FRAMES, step=1, label="Frames")
        fps        = gr.Slider(4, 30, value=16, step=1, label="FPS")
        seed_input = gr.Number(value=-1, precision=0, label="Seed (-1=rand)")

    run_btn   = gr.Button("Generate")
    download  = gr.File(label="Download .mp4", interactive=False)
    seed_used = gr.Number(label="Seed used", interactive=False)

    run_btn.click(
        fn=generate,
        inputs=[ first_img, last_img, prompt, negative,
                 steps, guidance, num_frames, seed_input, fps ],
        outputs=[ download, seed_used ],
    )

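# gr.Progress updates require the request queue, hence demo.queue() before
# launch(). With this file saved as e.g. app.py (name assumed), `python app.py`
# serves the UI on http://<host>:7860.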
demo.queue().launch(server_name="0.0.0.0", server_port=7860)