GeradeHouse committed
Commit 64a6a24 · verified · 1 Parent(s): 1c8aab2

Update app.py

Files changed (1)
  1. app.py +117 -101
app.py CHANGED
@@ -1,10 +1,15 @@
 #!/usr/bin/env python
 """
-Gradio demo for Wan2.1 FLF2V – full streaming progress
-No globals: pipeline, resize utils all use the local `pipe`.
-Author: <your-handle>
 """
 
 import numpy as np
 import torch
 import gradio as gr
@@ -12,95 +17,108 @@ from diffusers import WanImageToVideoPipeline, AutoencoderKLWan
 from diffusers.utils import export_to_video
 from transformers import CLIPVisionModel, CLIPImageProcessor
 from PIL import Image
-import torchvision.transforms.functional as TF
-
-# ---------------------------------------------------------------------
-# CONFIG ----------------------------------------------------------------
-MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
-DTYPE = torch.float16
-MAX_AREA = 1280 * 720
-DEFAULT_FRAMES = 81
-# ----------------------------------------------------------------------
 
-def load_pipeline(progress):
-    """Load & shard the pipeline across CPU/GPU with streaming progress."""
-    progress(0.00, desc="Init: loading image encoder…")
-    image_encoder = CLIPVisionModel.from_pretrained(
         MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
     )
-    progress(0.10, desc="Loaded image encoder")
-
-    progress(0.10, desc="Loading VAE…")
     vae = AutoencoderKLWan.from_pretrained(
         MODEL_ID, subfolder="vae", torch_dtype=DTYPE
     )
-    progress(0.20, desc="Loaded VAE")
-
-    progress(0.20, desc="Assembling pipeline…")
     pipe = WanImageToVideoPipeline.from_pretrained(
         MODEL_ID,
         vae=vae,
-        image_encoder=image_encoder,
         torch_dtype=DTYPE,
-        low_cpu_mem_usage=True,
-        device_map="balanced",
     )
-    progress(0.30, desc="Pipeline assembled")
-
-    progress(0.30, desc="Loading fast image processor…")
-    pipe.image_processor = CLIPImageProcessor.from_pretrained(
-        MODEL_ID, subfolder="image_processor", use_fast=True
-    )
-    progress(0.40, desc="Processor ready")
-
-    return pipe
-
-def aspect_resize(img: Image.Image, pipe, max_area=MAX_AREA):
-    """Resize while respecting model patch multiples, using `pipe` for scale."""
     ar = img.height / img.width
-    mod = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
-    h = round(np.sqrt(max_area * ar)) // mod * mod
-    w = round(np.sqrt(max_area / ar)) // mod * mod
     return img.resize((w, h), Image.LANCZOS), h, w
 
-def center_crop_resize(img: Image.Image, pipe, h, w):
-    """Center-crop & resize to H×W, using same Lanczos filter."""
     ratio = max(w / img.width, h / img.height)
-    img = img.resize(
-        (round(img.width * ratio), round(img.height * ratio)),
-        Image.LANCZOS
-    )
-    return TF.center_crop(img, [h, w])
-
-def generate(first_frame, last_frame, prompt, negative_prompt,
-             steps, guidance, num_frames, seed, fps,
-             progress=gr.Progress()):  # Gradio progress hook
-
-    # 1) Load & shard pipeline
-    pipe = load_pipeline(progress)
-
-    # 2) Preprocess
-    progress(0.45, desc="Preprocessing first frame…")
-    first_frame, h, w = aspect_resize(first_frame, pipe)
-    if last_frame.size != first_frame.size:
-        progress(0.50, desc="Preprocessing last frame…")
-        last_frame = center_crop_resize(last_frame, pipe, h, w)
-    progress(0.55, desc="Frames ready")
-
-    # 3) Run inference with per-step callbacks
     if seed == -1:
         seed = torch.seed()
-    gen = torch.Generator(device=pipe.device).manual_seed(seed)
 
-    def _cb(step, timestep, latents):
-        frac = 0.55 + 0.35 * ((step + 1) / steps)
-        progress(frac, desc=f"Inference step {step+1}/{steps}")
 
-    progress(0.55, desc="Starting inference…")
-    output = pipe(
-        image=first_frame,
         last_image=last_frame,
-        prompt=prompt,
         negative_prompt=negative_prompt or None,
         height=h,
         width=w,
@@ -108,44 +126,42 @@ def generate(first_frame, last_frame, prompt, negative_prompt,
         num_inference_steps=steps,
         guidance_scale=guidance,
         generator=gen,
-        callback_on_step_end=_cb,
-        callback_steps=1,
     )
-    frames = output.frames[0]
-
-    # 4) Export video
-    progress(0.92, desc="Exporting video…")
-    video_path = export_to_video(frames, fps=fps)
 
-    # 5) Done
-    progress(1.0, desc="Complete!")
-    return video_path
 
 with gr.Blocks() as demo:
-    gr.Markdown("## Wan2.1 FLF2V – Full Streaming Progress")
-
     with gr.Row():
         first_img = gr.Image(label="First frame", type="pil")
         last_img = gr.Image(label="Last frame", type="pil")
-
-    prompt = gr.Textbox(label="Prompt", placeholder="A blue bird takes off…")
-    negative = gr.Textbox(label="Negative prompt (optional)", placeholder="ugly, blurry")
-
     with gr.Accordion("Advanced parameters", open=False):
-        steps = gr.Slider(10, 50, value=30, step=1, label="Steps")
-        guidance = gr.Slider(0.0, 10.0, value=5.5, step=0.1, label="Guidance")
-        num_frames = gr.Slider(16, 129, value=DEFAULT_FRAMES, label="Frames")
-        fps = gr.Slider(4, 30, value=16, label="FPS")
-        seed = gr.Number(value=-1, precision=0, label="Seed")
-
-    video = gr.Video(label="Result (.mp4)")
-
-    btn = gr.Button("Generate")
-    btn.click(
         fn=generate,
         inputs=[first_img, last_img, prompt, negative, steps, guidance, num_frames, seed, fps],
-        outputs=[video],
     )
 
-demo.queue()  # enable streaming updates
-demo.launch()
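The removed per-step hook is also worth a note: in recent diffusers releases, callback_on_step_end is invoked as callback(pipe, step, timestep, callback_kwargs) and must return that kwargs dict, and newer pipelines such as this one no longer take a separate callback_steps argument, which is presumably why the rewritten app.py below drops the callback in favour of gr.Progress(track_tqdm=True). A minimal sketch of a compatible per-step progress hook, assuming that newer callback API and reusing the `progress` and `steps` names from generate():

    def _cb(pipe_ref, step, timestep, callback_kwargs):
        # map denoising progress into the 0.55 to 0.90 band of the overall bar
        progress(0.55 + 0.35 * (step + 1) / steps, desc=f"Inference step {step + 1}/{steps}")
        return callback_kwargs  # new-style callbacks must return the kwargs dict

    # passed as: pipe(..., callback_on_step_end=_cb), with no callback_steps argument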
 
 #!/usr/bin/env python
 """
+Gradio demo for Wan2.1 FLF2V – First & Last Frame → Video
+Loads the huge model lazily (only once), streams **all** tqdm bars
+(from HF downloads, shard loading, to denoising) into Gradio's UI,
+and outputs a direct File download for the generated video.
 """
 
+import os
+import tempfile
+
+import ftfy
 import numpy as np
 import torch
 import gradio as gr
 from diffusers import WanImageToVideoPipeline, AutoencoderKLWan
 from diffusers.utils import export_to_video
 from transformers import CLIPVisionModel, CLIPImageProcessor
 from PIL import Image
 
+# -----------------------------------------------------------------------------
+# CONFIG
+# -----------------------------------------------------------------------------
+MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
+DTYPE = torch.float16        # or torch.bfloat16 on AMP-friendly cards
+MAX_AREA = 1280 * 720        # ≤720p
+DEFAULT_FRAMES = 81          # ~5s @ 16fps
+
+# -----------------------------------------------------------------------------
+# GLOBAL PIPELINE (lazy)
+# -----------------------------------------------------------------------------
+PIPE = None
+
+def load_pipeline():
+    """
+    Load the Wan2.1-FLF2V pipeline once, with fast processor,
+    CPU-offload for large models, and in half-precision.
+    """
+    # 1) full-precision CLIP encoder
+    vision = CLIPVisionModel.from_pretrained(
         MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
     )
+    # 2) fast CLIP image processor
+    processor = CLIPImageProcessor.from_pretrained(
+        MODEL_ID, subfolder="preprocessor", use_fast=True
+    )
+    # 3) reduced-precision VAE
     vae = AutoencoderKLWan.from_pretrained(
         MODEL_ID, subfolder="vae", torch_dtype=DTYPE
     )
+    # 4) assemble pipeline
     pipe = WanImageToVideoPipeline.from_pretrained(
         MODEL_ID,
         vae=vae,
+        image_encoder=vision,
+        image_processor=processor,
         torch_dtype=DTYPE,
     )
+    # 5) offload to CPU/AutoDevice
+    pipe.enable_model_cpu_offload()
+    # (we drop .enable_slicing() because it's unsupported here)
+    return pipe.to("cuda" if torch.cuda.is_available() else "cpu")
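One hedged caveat on step 5, since the behaviour depends on the installed diffusers version: enable_model_cpu_offload() installs accelerate hooks that move each sub-model to the GPU on demand, and recent diffusers releases warn or refuse when an offloaded pipeline is afterwards moved wholesale with .to("cuda"). If offloading is the intended memory strategy, a safer ending for load_pipeline would be:

    pipe.enable_model_cpu_offload()  # accelerate hooks now manage device placement
    return pipe                      # skip the follow-up .to(...) once offloading is enabled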
+
+# -----------------------------------------------------------------------------
+# UTILS
+# -----------------------------------------------------------------------------
+def aspect_resize(img: Image.Image, max_area=MAX_AREA):
+    """
+    Resize while respecting the model's patch size (multiple of 8 * transformer patch).
+    """
     ar = img.height / img.width
+    mod = PIPE.transformer.config.patch_size[1] * PIPE.vae_scale_factor_spatial
+    h = (int(np.sqrt(max_area * ar)) // mod) * mod
+    w = (int(np.sqrt(max_area / ar)) // mod) * mod
     return img.resize((w, h), Image.LANCZOS), h, w
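To make the rounding above concrete, here is the arithmetic for a hypothetical 1000×800 input, assuming the usual Wan2.1 factors (vae_scale_factor_spatial = 8 and transformer patch_size[1] = 2, so mod = 16):

    ar = 800 / 1000                                  # 0.8
    mod = 8 * 2                                      # 16
    h = (int((921_600 * ar) ** 0.5) // mod) * mod    # int(858.65) -> 858 -> 848
    w = (int((921_600 / ar) ** 0.5) // mod) * mod    # int(1073.31) -> 1073 -> 1072
    # 1072 * 848 = 909,056 pixels, just under MAX_AREA = 1280 * 720 = 921,600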
 
+def center_crop_resize(img: Image.Image, h: int, w: int):
+    """
+    Center-crop + resize to exactly h×w.
+    """
     ratio = max(w / img.width, h / img.height)
+    img2 = img.resize((round(img.width * ratio), round(img.height * ratio)), Image.LANCZOS)
+    return TF.center_crop(img2, [h, w])
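Note that center_crop_resize still calls TF.center_crop, while the old `import torchvision.transforms.functional as TF` line is removed above and never re-added anywhere visible in this diff, so the new module would hit a NameError here. Assuming the import really is gone from the file (and not merely hidden by the diff view), the fix is a one-liner back in the import block:

    import torchvision.transforms.functional as TF  # needed by center_crop_resize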
+
+# -----------------------------------------------------------------------------
+# GENERATION (with full tqdm → Gradio progress streaming)
+# -----------------------------------------------------------------------------
+def generate(
+    first_frame: Image.Image,
+    last_frame: Image.Image,
+    prompt: str,
+    negative_prompt: str,
+    steps: int,
+    guidance: float,
+    num_frames: int,
+    seed: int,
+    fps: int,
+    progress=gr.Progress(track_tqdm=True),
+):
+    global PIPE
+    # lazy instantiate
+    if PIPE is None:
+        progress(0, desc="Loading pipeline…")
+        PIPE = load_pipeline()
+
+    # seeding
     if seed == -1:
         seed = torch.seed()
+    gen = torch.Generator(device=PIPE.device).manual_seed(seed)
 
+    # preprocess
+    progress(0, desc="Preprocessing…")
+    frame1, h, w = aspect_resize(first_frame)
+    if last_frame.size != frame1.size:
+        last_frame = center_crop_resize(last_frame, h, w)
 
+    # inference (all tqdm inside will stream to UI)
+    result = PIPE(
+        image=frame1,
         last_image=last_frame,
+        prompt=whitespace_clean(basic_clean(prompt)),
         negative_prompt=negative_prompt or None,
         height=h,
         width=w,
         num_frames=num_frames,
         num_inference_steps=steps,
         guidance_scale=guidance,
         generator=gen,
+        # no callback_steps here!
     )
+    frames = result.frames[0]  # list of PIL images
 
+    # export to MP4
+    progress(1.0, desc="Assembling video…")
+    out_path = export_to_video(frames, fps=fps)
+    return out_path, seed
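The PIPE call above wraps the prompt in whitespace_clean(basic_clean(prompt)), but neither helper is defined or imported anywhere in the visible diff; only ftfy is. A minimal sketch of the CLIP/Wan-style text cleaners these names usually refer to, offered as an assumption rather than the commit's actual definitions:

    import html
    import re

    import ftfy

    def basic_clean(text: str) -> str:
        # fix mojibake, then unescape HTML entities (twice, as in the CLIP reference cleaner)
        text = ftfy.fix_text(text)
        return html.unescape(html.unescape(text)).strip()

    def whitespace_clean(text: str) -> str:
        # collapse any run of whitespace into a single space
        return re.sub(r"\s+", " ", text).strip()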
 
+# -----------------------------------------------------------------------------
+# BUILD UI
+# -----------------------------------------------------------------------------
 with gr.Blocks() as demo:
+    gr.Markdown("## Wan 2.1 FLF2V – First & Last Frame → Video (Diffusers)")
     with gr.Row():
         first_img = gr.Image(label="First frame", type="pil")
         last_img = gr.Image(label="Last frame", type="pil")
+    prompt = gr.Textbox(label="Prompt", placeholder="A small blue bird takes off…")
+    negative = gr.Textbox(label="Negative prompt (optional)", placeholder="ugly, blurry")
     with gr.Accordion("Advanced parameters", open=False):
+        steps = gr.Slider(10, 50, value=30, step=1, label="Sampling steps")
+        guidance = gr.Slider(0.0, 10.0, value=5.5, step=0.1, label="Guidance scale")
+        num_frames = gr.Slider(16, 129, value=DEFAULT_FRAMES, step=1, label="Frames")
+        fps = gr.Slider(4, 30, value=16, step=1, label="FPS")
+        seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
+    run_btn = gr.Button("Generate")
+    # **File** component for direct download link:
+    download = gr.File(label="Download video (.mp4)")
+    used_seed = gr.Number(label="Seed used", interactive=False)
+
+    # queue() for async + progress
+    run_btn.click(
         fn=generate,
         inputs=[first_img, last_img, prompt, negative, steps, guidance, num_frames, seed, fps],
+        outputs=[download, used_seed],
     )
 
+# MUST call .queue() to enable gr.Progress()
+demo.queue(concurrency_count=1).launch()
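A final, version-dependent note (an assumption about which Gradio release the Space runs, not something the commit states): demo.queue(concurrency_count=1) matches the Gradio 3.x signature; in Gradio 4.x that keyword was removed, and the closest equivalent wiring would be:

    demo.queue(default_concurrency_limit=1).launch()  # Gradio 4.x spelling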