GeradeHouse committed (verified)
Commit 2b5109d · 1 parent: d8d26ca

Update app.py

Files changed (1)
  1. app.py +83 -89
app.py CHANGED
@@ -1,134 +1,127 @@
  #!/usr/bin/env python
  """
- Gradio demo for Wan2.1-FLF2V First & Last Frame → Video
  """
-
  import os
- # Persist HF cache between launches
- os.environ["HF_HOME"] = "/mnt/data/huggingface"
-
- import torch
  import numpy as np
  import gradio as gr
- from PIL import Image
- import torchvision.transforms.functional as TF
- from transformers import CLIPVisionModel, CLIPImageProcessor
  from diffusers import WanImageToVideoPipeline, AutoencoderKLWan
  from diffusers.utils import export_to_video

- # -----------------------------------------------------------------------------
- # CONFIGURATION
- # -----------------------------------------------------------------------------
- MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
- DTYPE = torch.float16
- MAX_AREA = 1280 * 720
- DEFAULT_FRAMES = 81
-
- # -----------------------------------------------------------------------------
- # PIPELINE LOADING (ONCE)
- # -----------------------------------------------------------------------------
  def load_pipeline():
-     # 1) Vision encoder (fp32)
-     clip_encoder = CLIPVisionModel.from_pretrained(
          MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
      )
-     # 2) VAE (reduced precision)
      vae = AutoencoderKLWan.from_pretrained(
          MODEL_ID, subfolder="vae", torch_dtype=DTYPE
      )
-     # 3) CLIPImageProcessor (exactly the type Wan expects)
-     img_processor = CLIPImageProcessor.from_pretrained(
-         "openai/clip-vit-base-patch32", use_fast=True
      )
-     # 4) Load the Wan-to-Video pipeline, balanced across GPU & CPU
      pipe = WanImageToVideoPipeline.from_pretrained(
          MODEL_ID,
-         image_encoder=clip_encoder,
          vae=vae,
-         image_processor=img_processor,
          torch_dtype=DTYPE,
-         device_map="balanced",
      )
-     # 5) Slice the VAE to cut VRAM spikes
-     try:
-         pipe.vae.enable_slicing()
-     except AttributeError:
-         pass
      return pipe

- # instantiate once
  PIPE = load_pipeline()

- # -----------------------------------------------------------------------------
- # IMAGE RESIZE HELPERS
- # -----------------------------------------------------------------------------
  def aspect_resize(img: Image.Image, max_area=MAX_AREA):
-     ar = img.height / img.width
-     mod = PIPE.vae_scale_factor_spatial * PIPE.transformer.config.patch_size[1]
-     h = int(np.sqrt(max_area * ar)) // mod * mod
-     w = int(np.sqrt(max_area / ar)) // mod * mod
      return img.resize((w, h), Image.LANCZOS), h, w

- def center_crop_resize(img: Image.Image, h: int, w: int):
      ratio = max(w / img.width, h / img.height)
-     img = img.resize(
          (round(img.width * ratio), round(img.height * ratio)),
          Image.LANCZOS
      )
      return TF.center_crop(img, [h, w])

- # -----------------------------------------------------------------------------
- # GENERATION (STREAMING)
- # -----------------------------------------------------------------------------
  def generate(
      first_frame: Image.Image,
-     last_frame: Image.Image,
-     prompt: str,
-     negative: str,
-     steps: int,
-     guidance: float,
-     num_frames: int,
-     seed: int,
-     fps: int,
-     progress= gr.Progress()
  ):
-     # Seed management
      if seed == -1:
          seed = torch.seed()
      gen = torch.Generator(device=PIPE.device).manual_seed(seed)

-     # Preprocessing update
-     progress(0, steps, desc="Preprocessing images")
-     f0, h, w = aspect_resize(first_frame)
-     if last_frame.size != f0.size:
          last_frame = center_crop_resize(last_frame, h, w)

-     # Step callback
-     def cb(step, timestep, latents):
-         progress(step, steps, desc=f"Inference step {step}/{steps}")
-
-     # Run the pipeline
-     out = PIPE(
-         image=f0,
          last_image=last_frame,
          prompt=prompt,
-         negative_prompt=negative or None,
          height=h,
          width=w,
          num_frames=num_frames,
          num_inference_steps=steps,
          guidance_scale=guidance,
          generator=gen,
-         callback=cb
      )

-     # Export video
-     video_path = export_to_video(out.frames[0], fps=fps)
      return video_path, seed

- # -----------------------------------------------------------------------------
- # GRADIO APP
- # -----------------------------------------------------------------------------
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
      gr.Markdown("## Wan2.1 FLF2V – First & Last Frame → Video")

@@ -136,25 +129,26 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
      first_img = gr.Image(label="First frame", type="pil")
      last_img = gr.Image(label="Last frame", type="pil")

-     prompt_box = gr.Textbox(label="Prompt", placeholder="A blue bird takes off…")
-     negative_box = gr.Textbox(label="Negative prompt (optional)", placeholder="ugly, blurry")

      with gr.Accordion("Advanced parameters", open=False):
-         steps = gr.Slider(10, 50, value=30, step=1, label="Steps")
-         guidance = gr.Slider(0.0, 10.0, value=5.5, step=0.1, label="Guidance")
          num_frames = gr.Slider(16, 129, value=DEFAULT_FRAMES, step=1, label="Frames")
-         fps = gr.Slider(4, 30, value=16, step=1, label="FPS")
-         seed_input = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")

-     run_btn = gr.Button("Generate")
-     video_out = gr.Video(label="Result (.mp4)")
-     seed_out = gr.Number(label="Seed used", interactive=False)

      run_btn.click(
          fn=generate,
-         inputs=[ first_img, last_img, prompt_box, negative_box,
-                  steps, guidance, num_frames, seed_input, fps ],
-         outputs=[ video_out, seed_out ]
      )

- demo.launch()
 
 
  #!/usr/bin/env python
  """
+ Gradio demo for Wan2.1 First-Last-Frame-to-Video (FLF2V)
+ Loads the huge model once, uses balanced device placement,
+ streams high-level progress, and auto-offers the .mp4 for download.
  """
  import os
  import numpy as np
+ import torch
  import gradio as gr
  from diffusers import WanImageToVideoPipeline, AutoencoderKLWan
  from diffusers.utils import export_to_video
+ from transformers import CLIPImageProcessor, CLIPVisionModel
+ from PIL import Image
+ import torchvision.transforms.functional as TF
+
+ # --------------------------------------------------------------------
+ # CONFIG
+ MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
+ DTYPE = torch.float16          # half-precision
+ MAX_AREA = 1280 * 720          # ≤720p
+ DEFAULT_FRAMES = 81            # ≈5s @16fps
+ # --------------------------------------------------------------------

  def load_pipeline():
+     # 1) image encoder in full precision
+     image_encoder = CLIPVisionModel.from_pretrained(
          MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
      )
+     # 2) VAE in reduced precision
      vae = AutoencoderKLWan.from_pretrained(
          MODEL_ID, subfolder="vae", torch_dtype=DTYPE
      )
+     # 3) CLIPImageProcessor so we get the right class
+     image_processor = CLIPImageProcessor.from_pretrained(
+         MODEL_ID, subfolder="", torch_dtype=DTYPE
      )
+     # 4) load everything with a balanced device map
      pipe = WanImageToVideoPipeline.from_pretrained(
          MODEL_ID,
          vae=vae,
+         image_encoder=image_encoder,
+         image_processor=image_processor,
          torch_dtype=DTYPE,
+         device_map="balanced",    # splits weights CPU/GPU
      )
      return pipe

+ # load once at import
  PIPE = load_pipeline()

+
+ # --------------------------------------------------------------------
+ # UTILS
  def aspect_resize(img: Image.Image, max_area=MAX_AREA):
+     """Resize while respecting multiples of the model's patch size."""
+     ar = img.height / img.width
+     mod = PIPE.vae_scale_factor_spatial * PIPE.transformer.config.patch_size[1]
+     h = round(np.sqrt(max_area * ar)) // mod * mod
+     w = round(np.sqrt(max_area / ar)) // mod * mod
      return img.resize((w, h), Image.LANCZOS), h, w

+ def center_crop_resize(img: Image.Image, h, w):
+     """Crop-and-resize to exactly (h, w)."""
      ratio = max(w / img.width, h / img.height)
+     img = img.resize(
          (round(img.width * ratio), round(img.height * ratio)),
          Image.LANCZOS
      )
      return TF.center_crop(img, [h, w])

+
+ # --------------------------------------------------------------------
+ # GENERATE (with simple progress streaming)
  def generate(
      first_frame: Image.Image,
+     last_frame: Image.Image,
+     prompt: str,
+     negative_prompt: str,
+     steps: int,
+     guidance: float,
+     num_frames: int,
+     seed: int,
+     fps: int,
+     progress=gr.Progress(),    # gradio's built-in progress callback
  ):
+     # pick or set seed
      if seed == -1:
          seed = torch.seed()
      gen = torch.Generator(device=PIPE.device).manual_seed(seed)

+     # 0→10%: resize
+     progress(0.0, desc="Resizing first frame…")
+     first_frame, h, w = aspect_resize(first_frame)
+     if last_frame.size != first_frame.size:
+         progress(0.1, desc="Resizing last frame…")
          last_frame = center_crop_resize(last_frame, h, w)

+     # 10→20%: ready to run
+     progress(0.2, desc="Starting video inference…")
+     result = PIPE(
+         image=first_frame,
          last_image=last_frame,
          prompt=prompt,
+         negative_prompt=negative_prompt or None,
          height=h,
          width=w,
          num_frames=num_frames,
          num_inference_steps=steps,
          guidance_scale=guidance,
          generator=gen,
      )

+     # 80→100%: export
+     progress(0.8, desc="Assembling video file…")
+     video_path = export_to_video(result.frames[0], fps=fps)
+     progress(1.0, desc="Done!")
+
+     # return path so gr.File offers immediate download, plus seed used
      return video_path, seed

+
+ # --------------------------------------------------------------------
+ # UI
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
      gr.Markdown("## Wan2.1 FLF2V – First & Last Frame → Video")

      first_img = gr.Image(label="First frame", type="pil")
      last_img = gr.Image(label="Last frame", type="pil")

+     prompt = gr.Textbox(label="Prompt", placeholder="A blue bird takes off…")
+     negative = gr.Textbox(label="Negative prompt (optional)", placeholder="ugly, blurry")

      with gr.Accordion("Advanced parameters", open=False):
+         steps = gr.Slider(10, 50, value=30, step=1, label="Sampling steps")
+         guidance = gr.Slider(0.0, 10.0, value=5.5, step=0.1, label="Guidance scale")
          num_frames = gr.Slider(16, 129, value=DEFAULT_FRAMES, step=1, label="Frames")
+         fps = gr.Slider(4, 30, value=16, step=1, label="FPS")
+         seed = gr.Number(value=-1, precision=0, label="Seed (-1=random)")

+     run_btn = gr.Button("Generate")
+     download = gr.File(label="Download video", interactive=False)
+     used_seed = gr.Number(label="Seed used", interactive=False)

      run_btn.click(
          fn=generate,
+         inputs=[first_img, last_img, prompt, negative,
+                 steps, guidance, num_frames, seed, fps],
+         outputs=[download, used_seed],
      )

+ # queue tasks so users see the little task-queue progress bar
+ demo.queue().launch(server_name="0.0.0.0", server_port=7860)
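Note on per-step progress: the new version drops the old callback=cb hook, so the Gradio bar sits between the 20% and 80% marks for the whole denoising loop. Below is a minimal sketch of how finer-grained updates could be wired back in, assuming the installed diffusers version of WanImageToVideoPipeline accepts the standard callback_on_step_end hook (not shown in this commit); the helper name make_step_callback is illustrative, not part of the committed code.

# Sketch only: per-step progress via diffusers' callback_on_step_end.
# Assumes the pipeline forwards this standard hook; verify against the
# installed diffusers version before relying on it.
def make_step_callback(progress, steps):
    def on_step_end(pipe, step, timestep, callback_kwargs):
        # map denoising steps onto the 20-80% band used by generate()
        progress(0.2 + 0.6 * (step + 1) / steps,
                 desc=f"Inference step {step + 1}/{steps}")
        return callback_kwargs
    return on_step_end

# Usage inside generate(), added to the existing PIPE(...) call:
#     result = PIPE(..., callback_on_step_end=make_step_callback(progress, steps))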