ginipick committed (verified)
Commit cc1ee0b · 1 Parent(s): 3b05042

Update app.py

Files changed (1)
  1. app.py +107 -178
app.py CHANGED
@@ -47,13 +47,14 @@ from transformers import SiglipImageProcessor, SiglipVisionModel
47
  from diffusers_helper.clip_vision import hf_clip_vision_encode
48
  from diffusers_helper.bucket_tools import find_nearest_bucket
49
 
50
-
51
  free_mem_gb = get_cuda_free_memory_gb(gpu)
52
  high_vram = free_mem_gb > 60
53
 
54
  print(f'Free VRAM {free_mem_gb} GB')
55
  print(f'High-VRAM Mode: {high_vram}')
56
 
 
57
  text_encoder = LlamaModel.from_pretrained(
58
  "hunyuanvideo-community/HunyuanVideo",
59
  subfolder='text_encoder',
@@ -93,12 +94,14 @@ transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained(
93
  torch_dtype=torch.bfloat16
94
  ).cpu()
95
 
 
96
  vae.eval()
97
  text_encoder.eval()
98
  text_encoder_2.eval()
99
  image_encoder.eval()
100
  transformer.eval()
101
 
 
102
  if not high_vram:
103
  vae.enable_slicing()
104
  vae.enable_tiling()
@@ -106,20 +109,22 @@ if not high_vram:
106
  transformer.high_quality_fp32_output_for_inference = True
107
  print('transformer.high_quality_fp32_output_for_inference = True')
108
 
 
109
  transformer.to(dtype=torch.bfloat16)
110
  vae.to(dtype=torch.float16)
111
  image_encoder.to(dtype=torch.float16)
112
  text_encoder.to(dtype=torch.float16)
113
  text_encoder_2.to(dtype=torch.float16)
114
 
 
115
  vae.requires_grad_(False)
116
  text_encoder.requires_grad_(False)
117
  text_encoder_2.requires_grad_(False)
118
  image_encoder.requires_grad_(False)
119
  transformer.requires_grad_(False)
120
 
 
121
  if not high_vram:
122
- # DynamicSwapInstaller is same as huggingface's enable_sequential_offload but 3x faster
123
  DynamicSwapInstaller.install_model(transformer, device=gpu)
124
  DynamicSwapInstaller.install_model(text_encoder, device=gpu)
125
  else:
@@ -140,6 +145,7 @@ examples = [
140
  ["img_examples/3.png", "The woman dances elegantly among the blossoms, spinning slowly with flowing sleeves and graceful hand movements."]
141
  ]
142
 
 
143
  def generate_examples(input_image, prompt):
144
  t2v=False
145
  n_prompt=""
@@ -192,7 +198,8 @@ def generate_examples(input_image, prompt):
192
  yield (
193
  gr.update(),
194
  gr.update(visible=True, value=preview),
195
- desc, html,
 
196
  gr.update(interactive=False),
197
  gr.update(interactive=True)
198
  )
@@ -211,98 +218,69 @@ def generate_examples(input_image, prompt):
211
  @torch.no_grad()
212
  def worker(
213
  input_image, prompt, n_prompt, seed,
214
- total_second_length, latent_window_size,
215
- steps, cfg, gs, rs,
216
- gpu_memory_preservation, use_teacache, mp4_crf
217
  ):
 
218
  total_latent_sections = (total_second_length * 30) / (latent_window_size * 4)
219
  total_latent_sections = int(max(round(total_latent_sections), 1))
220
 
221
  job_id = generate_timestamp()
222
 
223
- stream.output_queue.push(
224
- ('progress', (None, '', make_progress_bar_html(0, 'Starting ...')))
225
- )
226
 
227
  try:
228
- # Clean GPU if VRAM is low
229
  if not high_vram:
230
  unload_complete_models(
231
  text_encoder, text_encoder_2, image_encoder, vae, transformer
232
  )
233
 
234
  # Text encoding
235
- stream.output_queue.push(
236
- ('progress', (None, '', make_progress_bar_html(0, 'Text encoding ...')))
237
- )
238
 
239
  if not high_vram:
240
  fake_diffusers_current_device(text_encoder, gpu)
241
  load_model_as_complete(text_encoder_2, target_device=gpu)
242
 
243
- llama_vec, clip_l_pooler = encode_prompt_conds(
244
- prompt, text_encoder, text_encoder_2,
245
- tokenizer, tokenizer_2
246
- )
247
 
248
  if cfg == 1:
249
- llama_vec_n, clip_l_pooler_n = (
250
- torch.zeros_like(llama_vec),
251
- torch.zeros_like(clip_l_pooler)
252
- )
253
  else:
254
- llama_vec_n, clip_l_pooler_n = encode_prompt_conds(
255
- n_prompt, text_encoder, text_encoder_2,
256
- tokenizer, tokenizer_2
257
- )
258
 
259
  llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
260
  llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n, length=512)
261
 
262
- # Processing input image
263
- stream.output_queue.push(
264
- ('progress', (None, '', make_progress_bar_html(0, 'Image processing ...')))
265
- )
266
 
267
  H, W, C = input_image.shape
268
  height, width = find_nearest_bucket(H, W, resolution=640)
269
- input_image_np = resize_and_center_crop(
270
- input_image,
271
- target_width=width,
272
- target_height=height
273
- )
274
 
275
- Image.fromarray(input_image_np).save(
276
- os.path.join(outputs_folder, f'{job_id}.png')
277
- )
278
 
279
  input_image_pt = torch.from_numpy(input_image_np).float() / 127.5 - 1
280
  input_image_pt = input_image_pt.permute(2, 0, 1)[None, :, None]
281
 
282
  # VAE encoding
283
- stream.output_queue.push(
284
- ('progress', (None, '', make_progress_bar_html(0, 'VAE encoding ...')))
285
- )
286
 
287
  if not high_vram:
288
  load_model_as_complete(vae, target_device=gpu)
289
-
290
  start_latent = vae_encode(input_image_pt, vae)
291
 
292
  # CLIP Vision
293
- stream.output_queue.push(
294
- ('progress', (None, '', make_progress_bar_html(0, 'CLIP Vision encoding ...')))
295
- )
296
 
297
  if not high_vram:
298
  load_model_as_complete(image_encoder, target_device=gpu)
299
-
300
- image_encoder_output = hf_clip_vision_encode(
301
- input_image_np, feature_extractor, image_encoder
302
- )
303
  image_encoder_last_hidden_state = image_encoder_output.last_hidden_state
304
 
305
- # Dtype
306
  llama_vec = llama_vec.to(transformer.dtype)
307
  llama_vec_n = llama_vec_n.to(transformer.dtype)
308
  clip_l_pooler = clip_l_pooler.to(transformer.dtype)
@@ -310,9 +288,7 @@ def worker(
310
  image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(transformer.dtype)
311
 
312
  # Start sampling
313
- stream.output_queue.push(
314
- ('progress', (None, '', make_progress_bar_html(0, 'Start sampling ...')))
315
- )
316
 
317
  rnd = torch.Generator("cpu").manual_seed(seed)
318
 
@@ -322,10 +298,8 @@ def worker(
322
  ).cpu()
323
  history_pixels = None
324
 
325
- history_latents = torch.cat(
326
- [history_latents, start_latent.to(history_latents)],
327
- dim=2
328
- )
329
  total_generated_latent_frames = 1
330
 
331
  for section_index in range(total_latent_sections):
@@ -351,10 +325,7 @@ def worker(
351
  preview = d['denoised']
352
  preview = vae_decode_fake(preview)
353
  preview = (preview * 255.0).detach().cpu().numpy().clip(0, 255).astype(np.uint8)
354
- preview = einops.rearrange(
355
- preview,
356
- 'b c t h w -> (b h) (t w) c'
357
- )
358
 
359
  if stream.input_queue.top() == 'end':
360
  stream.output_queue.push(('end', None))
@@ -363,15 +334,12 @@ def worker(
363
  current_step = d['i'] + 1
364
  percentage = int(100.0 * current_step / steps)
365
  hint = f'Sampling {current_step}/{steps}'
366
- desc = f'Section {section_index+1}/{total_latent_sections}'
367
- stream.output_queue.push(
368
- ('progress', (preview, desc, make_progress_bar_html(percentage, hint)))
369
- )
370
  return
371
 
372
  indices = torch.arange(
373
- 0,
374
- sum([1, 16, 2, 1, latent_window_size])
375
  ).unsqueeze(0)
376
  (
377
  clean_latent_indices_start,
@@ -380,14 +348,13 @@ def worker(
380
  clean_latent_1x_indices,
381
  latent_indices
382
  ) = indices.split([1, 16, 2, 1, latent_window_size], dim=1)
383
- clean_latent_indices = torch.cat(
384
- [clean_latent_indices_start, clean_latent_1x_indices],
385
- dim=1
386
- )
387
 
388
  clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents[
389
  :, :, -sum([16, 2, 1]):, :, :
390
  ].split([16, 2, 1], dim=2)
 
391
  clean_latents = torch.cat(
392
  [start_latent.to(history_latents), clean_latents_1x],
393
  dim=2
@@ -424,21 +391,13 @@ def worker(
424
  )
425
 
426
  total_generated_latent_frames += int(generated_latents.shape[2])
427
- history_latents = torch.cat(
428
- [history_latents, generated_latents.to(history_latents)],
429
- dim=2
430
- )
431
 
432
  if not high_vram:
433
- offload_model_from_device_for_memory_preservation(
434
- transformer, target_device=gpu,
435
- preserved_memory_gb=8
436
- )
437
  load_model_as_complete(vae, target_device=gpu)
438
 
439
- real_history_latents = history_latents[
440
- :, :, -total_generated_latent_frames:, :, :
441
- ]
442
 
443
  if history_pixels is None:
444
  history_pixels = vae_decode(real_history_latents, vae).cpu()
@@ -456,75 +415,55 @@ def worker(
456
  if not high_vram:
457
  unload_complete_models()
458
 
459
- output_filename = os.path.join(
460
- outputs_folder, f'{job_id}_{total_generated_latent_frames}.mp4'
461
- )
462
-
463
- save_bcthw_as_mp4(
464
- history_pixels, output_filename,
465
- fps=30, crf=mp4_crf
466
- )
467
 
468
- print(
469
- f'Decoded. Current latent shape {real_history_latents.shape}; pixel shape {history_pixels.shape}'
470
- )
471
 
472
  stream.output_queue.push(('file', output_filename))
473
 
474
  except:
475
  traceback.print_exc()
476
  if not high_vram:
477
- unload_complete_models(
478
- text_encoder, text_encoder_2, image_encoder, vae, transformer
479
- )
480
 
481
  stream.output_queue.push(('end', None))
482
  return
483
 
484
  def get_duration(
485
- input_image, prompt, t2v, n_prompt, seed,
486
- total_second_length, latent_window_size, steps,
487
- cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf
 
488
  ):
489
  return total_second_length * 60
490
 
491
  @spaces.GPU(duration=get_duration)
492
  def process(
493
- input_image, prompt, t2v=False, n_prompt="",
494
- seed=31337, total_second_length=5, latent_window_size=9,
495
- steps=25, cfg=1.0, gs=10.0, rs=0.0,
496
- gpu_memory_preservation=6, use_teacache=True, mp4_crf=16
497
  ):
498
  global stream
499
-
500
  if t2v:
501
  default_height, default_width = 640, 640
502
- input_image = np.ones(
503
- (default_height, default_width, 3),
504
- dtype=np.uint8
505
- ) * 255
506
  print("No input image provided. Using a blank white image.")
507
  else:
508
- # Split the composite RGBA received from the ImageEditor
509
  composite_rgba_uint8 = input_image["composite"]
510
 
511
- # rgb_uint8: (H,W,3)
512
  rgb_uint8 = composite_rgba_uint8[:, :, :3]
513
- # mask_uint8: (H,W)
514
  mask_uint8 = composite_rgba_uint8[:, :, 3]
515
 
516
- # White background
517
  h, w = rgb_uint8.shape[:2]
518
  background_uint8 = np.full((h, w, 3), 255, dtype=np.uint8)
519
 
520
- # Normalize alpha
521
  alpha_normalized_float32 = mask_uint8.astype(np.float32) / 255.0
522
  alpha_mask_float32 = np.stack([alpha_normalized_float32]*3, axis=2)
523
 
524
- # Alpha blending
525
- blended_image_float32 = \
526
- rgb_uint8.astype(np.float32) * alpha_mask_float32 + \
527
- background_uint8.astype(np.float32) * (1.0 - alpha_mask_float32)
528
 
529
  input_image = np.clip(blended_image_float32, 0, 255).astype(np.uint8)
530
 
@@ -559,7 +498,8 @@ def process(
559
  yield (
560
  gr.update(),
561
  gr.update(visible=True, value=preview),
562
- desc, html,
 
563
  gr.update(interactive=False),
564
  gr.update(interactive=True)
565
  )
@@ -578,16 +518,16 @@ def process(
578
  def end_process():
579
  stream.input_queue.push('end')
580
 
 
581
  quick_prompts = [
582
  'The girl dances gracefully, with clear movements, full of charm.',
583
  'A character doing some simple body movements.'
584
  ]
585
  quick_prompts = [[x] for x in quick_prompts]
586
 
587
- # Existing CSS plus extra UI improvements
588
  def make_custom_css():
589
  base_progress_css = make_progress_bar_css()
590
- # Below is an example of a slightly more pastel-toned style and card-style UI
591
  extra_css = """
592
  body {
593
  background: #fafbfe !important;
@@ -595,14 +535,14 @@ def make_custom_css():
595
  }
596
  #title-container {
597
  text-align: center;
598
- padding: 30px;
599
  background: linear-gradient(135deg, #a8c0ff 0%, #fbc2eb 100%);
600
- border-radius: 0 0 16px 16px;
601
  margin-bottom: 20px;
602
  }
603
  #title-container h1 {
604
  color: white;
605
- font-size: 2.2rem;
606
  margin: 0;
607
  font-weight: 800;
608
  text-shadow: 1px 2px 2px rgba(0,0,0,0.1);
@@ -650,35 +590,30 @@ css = make_custom_css()
650
 
651
  block = gr.Blocks(css=css).queue()
652
  with block:
653
- # Top gradient area
654
- with gr.Box(elem_id="title-container"):
655
  gr.Markdown("<h1>FramePack I2V</h1>")
656
 
657
- # Description section
658
  gr.Markdown("""
659
  ### Video diffusion, but feels like image diffusion
660
- FramePack I2V - a model that predicts future frames from history frames,
661
- enabling you to generate short animations from a single image and a text prompt.<br><br>
662
- ***beta FramePack Fill*** - You can also paint over the input image to inpaint the video output.
663
  """)
664
 
665
  with gr.Row():
666
  with gr.Column():
667
  input_image = gr.ImageEditor(
668
  type="numpy",
669
- label="Image (click 'Brush' tool to mask)",
670
  height=320,
671
  brush=gr.Brush(colors=["#ffffff"])
672
  )
673
  prompt = gr.Textbox(label="Prompt", value='')
 
674
 
675
- t2v = gr.Checkbox(
676
- label="Generate from Text Only (no image)?",
677
- value=False
678
- )
679
  example_quick_prompts = gr.Dataset(
680
  samples=quick_prompts,
681
- label="Quick Prompt Picks",
682
  samples_per_page=1000,
683
  components=[prompt]
684
  )
@@ -695,7 +630,7 @@ with block:
695
  end_button = gr.Button(value="Stop Generation", elem_id="stop-button", interactive=False)
696
 
697
  total_second_length = gr.Slider(
698
- label="Total Video Length (sec)",
699
  minimum=1,
700
  maximum=5,
701
  value=2,
@@ -707,87 +642,81 @@ with block:
707
  use_teacache = gr.Checkbox(
708
  label='Use TeaCache',
709
  value=True,
710
- info='Faster speed but can degrade finger/hand details'
711
  )
712
  n_prompt = gr.Textbox(label="Negative Prompt", value="", visible=False)
713
  seed = gr.Number(label="Seed", value=31337, precision=0)
714
-
715
  latent_window_size = gr.Slider(
716
  label="Latent Window Size",
717
- minimum=1,
718
- maximum=33,
719
- value=9,
720
- step=1,
721
  visible=False
722
  )
723
  steps = gr.Slider(
724
  label="Steps",
725
- minimum=1,
726
- maximum=100,
727
- value=25,
728
- step=1,
729
- info='Not recommended to change significantly.'
730
  )
731
  cfg = gr.Slider(
732
  label="CFG Scale",
733
- minimum=1.0,
734
- maximum=32.0,
735
- value=1.0,
736
- step=0.01,
737
  visible=False
738
  )
739
  gs = gr.Slider(
740
  label="Distilled CFG Scale",
741
- minimum=1.0,
742
- maximum=32.0,
743
- value=10.0,
744
- step=0.01,
745
- info='Not recommended to change significantly.'
746
  )
747
  rs = gr.Slider(
748
  label="CFG Re-Scale",
749
- minimum=0.0,
750
- maximum=1.0,
751
- value=0.0,
752
- step=0.01,
753
  visible=False
754
  )
755
  gpu_memory_preservation = gr.Slider(
756
  label="GPU Memory Preservation (GB)",
757
- minimum=6,
758
- maximum=128,
759
- value=6,
760
- step=0.1,
761
- info="Increase if OOM occurs (slower speed)."
762
  )
763
  mp4_crf = gr.Slider(
764
  label="MP4 Compression (CRF)",
765
- minimum=0,
766
- maximum=100,
767
- value=16,
768
- step=1,
769
- info="Lower is higher quality. 16 is recommended."
770
  )
771
 
772
  with gr.Column():
773
- preview_image = gr.Image(label="Preview Latents", height=200, visible=False)
774
- result_video = gr.Video(label="Generated Video", autoplay=True, height=512, loop=True)
775
-
 
 
 
 
 
 
 
 
 
776
  progress_desc = gr.Markdown('', elem_classes='no-generating-animation')
777
  progress_bar = gr.HTML('', elem_classes='no-generating-animation')
778
 
 
779
  gr.HTML("""
780
  <div style="text-align:center; margin-top:20px;">
781
- Share your creations or find inspiration by searching
782
- <a href="https://x.com/search?q=framepack&f=live" target="_blank">#framepack</a> on Twitter (X)!
783
  </div>
784
  """)
785
 
786
- # Wire up the functions
787
  ips = [
788
  input_image, prompt, t2v, n_prompt, seed,
789
- total_second_length, latent_window_size, steps,
790
- cfg, gs, rs, gpu_memory_preservation,
791
  use_teacache, mp4_crf
792
  ]
793
  start_button.click(
@@ -797,7 +726,7 @@ with block:
797
  )
798
  end_button.click(fn=end_process)
799
 
800
- # Example button (uncomment if desired)
801
  # gr.Examples(
802
  # examples=examples,
803
  # inputs=[input_image, prompt],
 
47
  from diffusers_helper.clip_vision import hf_clip_vision_encode
48
  from diffusers_helper.bucket_tools import find_nearest_bucket
49
 
50
+ # Check GPU memory
51
  free_mem_gb = get_cuda_free_memory_gb(gpu)
52
  high_vram = free_mem_gb > 60
53
 
54
  print(f'Free VRAM {free_mem_gb} GB')
55
  print(f'High-VRAM Mode: {high_vram}')
56
 
57
+ # Load models
58
  text_encoder = LlamaModel.from_pretrained(
59
  "hunyuanvideo-community/HunyuanVideo",
60
  subfolder='text_encoder',
 
94
  torch_dtype=torch.bfloat16
95
  ).cpu()
96
 
97
+ # Evaluation mode
98
  vae.eval()
99
  text_encoder.eval()
100
  text_encoder_2.eval()
101
  image_encoder.eval()
102
  transformer.eval()
103
 
104
+ # Slicing/Tiling for low VRAM
105
  if not high_vram:
106
  vae.enable_slicing()
107
  vae.enable_tiling()
 
109
  transformer.high_quality_fp32_output_for_inference = True
110
  print('transformer.high_quality_fp32_output_for_inference = True')
111
 
112
+ # Move to correct dtype
113
  transformer.to(dtype=torch.bfloat16)
114
  vae.to(dtype=torch.float16)
115
  image_encoder.to(dtype=torch.float16)
116
  text_encoder.to(dtype=torch.float16)
117
  text_encoder_2.to(dtype=torch.float16)
118
 
119
+ # No gradient
120
  vae.requires_grad_(False)
121
  text_encoder.requires_grad_(False)
122
  text_encoder_2.requires_grad_(False)
123
  image_encoder.requires_grad_(False)
124
  transformer.requires_grad_(False)
125
 
126
+ # DynamicSwap if low VRAM
127
  if not high_vram:
 
128
  DynamicSwapInstaller.install_model(transformer, device=gpu)
129
  DynamicSwapInstaller.install_model(text_encoder, device=gpu)
130
  else:
 
145
  ["img_examples/3.png", "The woman dances elegantly among the blossoms, spinning slowly with flowing sleeves and graceful hand movements."]
146
  ]
147
 
148
+ # Example generation (optional)
149
  def generate_examples(input_image, prompt):
150
  t2v=False
151
  n_prompt=""
 
198
  yield (
199
  gr.update(),
200
  gr.update(visible=True, value=preview),
201
+ desc,
202
+ html,
203
  gr.update(interactive=False),
204
  gr.update(interactive=True)
205
  )
 
218
  @torch.no_grad()
219
  def worker(
220
  input_image, prompt, n_prompt, seed,
221
+ total_second_length, latent_window_size, steps,
222
+ cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf
 
223
  ):
224
+ # Calculate total sections
225
  total_latent_sections = (total_second_length * 30) / (latent_window_size * 4)
226
  total_latent_sections = int(max(round(total_latent_sections), 1))
227
 
228
  job_id = generate_timestamp()
229
 
230
+ stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Starting ...'))))
 
 
231
 
232
  try:
233
+ # Unload if VRAM is low
234
  if not high_vram:
235
  unload_complete_models(
236
  text_encoder, text_encoder_2, image_encoder, vae, transformer
237
  )
238
 
239
  # Text encoding
240
+ stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Text encoding ...'))))
 
 
241
 
242
  if not high_vram:
243
  fake_diffusers_current_device(text_encoder, gpu)
244
  load_model_as_complete(text_encoder_2, target_device=gpu)
245
 
246
+ llama_vec, clip_l_pooler = encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
 
 
 
247
 
248
  if cfg == 1:
249
+ llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vec), torch.zeros_like(clip_l_pooler)
 
 
 
250
  else:
251
+ llama_vec_n, clip_l_pooler_n = encode_prompt_conds(n_prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
 
 
 
252
 
253
  llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
254
  llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n, length=512)
255
 
256
+ # Process image
257
+ stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Image processing ...'))))
 
 
258
 
259
  H, W, C = input_image.shape
260
  height, width = find_nearest_bucket(H, W, resolution=640)
261
+ input_image_np = resize_and_center_crop(input_image, target_width=width, target_height=height)
 
 
 
 
262
 
263
+ Image.fromarray(input_image_np).save(os.path.join(outputs_folder, f'{job_id}.png'))
 
 
264
 
265
  input_image_pt = torch.from_numpy(input_image_np).float() / 127.5 - 1
266
  input_image_pt = input_image_pt.permute(2, 0, 1)[None, :, None]
267
 
268
  # VAE encoding
269
+ stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'VAE encoding ...'))))
 
 
270
 
271
  if not high_vram:
272
  load_model_as_complete(vae, target_device=gpu)
 
273
  start_latent = vae_encode(input_image_pt, vae)
274
 
275
  # CLIP Vision
276
+ stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'CLIP Vision encoding ...'))))
 
 
277
 
278
  if not high_vram:
279
  load_model_as_complete(image_encoder, target_device=gpu)
280
+ image_encoder_output = hf_clip_vision_encode(input_image_np, feature_extractor, image_encoder)
 
 
 
281
  image_encoder_last_hidden_state = image_encoder_output.last_hidden_state
282
 
283
+ # Convert dtype
284
  llama_vec = llama_vec.to(transformer.dtype)
285
  llama_vec_n = llama_vec_n.to(transformer.dtype)
286
  clip_l_pooler = clip_l_pooler.to(transformer.dtype)
 
288
  image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(transformer.dtype)
289
 
290
  # Start sampling
291
+ stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Start sampling ...'))))
 
 
292
 
293
  rnd = torch.Generator("cpu").manual_seed(seed)
294
 
 
298
  ).cpu()
299
  history_pixels = None
300
 
301
+ # Add start_latent
302
+ history_latents = torch.cat([history_latents, start_latent.to(history_latents)], dim=2)
 
 
303
  total_generated_latent_frames = 1
304
 
305
  for section_index in range(total_latent_sections):
 
325
  preview = d['denoised']
326
  preview = vae_decode_fake(preview)
327
  preview = (preview * 255.0).detach().cpu().numpy().clip(0, 255).astype(np.uint8)
328
+ preview = einops.rearrange(preview, 'b c t h w -> (b h) (t w) c')
 
 
 
329
 
330
  if stream.input_queue.top() == 'end':
331
  stream.output_queue.push(('end', None))
 
334
  current_step = d['i'] + 1
335
  percentage = int(100.0 * current_step / steps)
336
  hint = f'Sampling {current_step}/{steps}'
337
+ desc = f'Total generated frames: {int(max(0, total_generated_latent_frames * 4 - 3))}'
338
+ stream.output_queue.push(('progress', (preview, desc, make_progress_bar_html(percentage, hint))))
 
 
339
  return
340
 
341
  indices = torch.arange(
342
+ 0, sum([1, 16, 2, 1, latent_window_size])
 
343
  ).unsqueeze(0)
344
  (
345
  clean_latent_indices_start,
 
348
  clean_latent_1x_indices,
349
  latent_indices
350
  ) = indices.split([1, 16, 2, 1, latent_window_size], dim=1)
351
+
352
+ clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)
 
 
353
 
354
  clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents[
355
  :, :, -sum([16, 2, 1]):, :, :
356
  ].split([16, 2, 1], dim=2)
357
+
358
  clean_latents = torch.cat(
359
  [start_latent.to(history_latents), clean_latents_1x],
360
  dim=2
 
391
  )
392
 
393
  total_generated_latent_frames += int(generated_latents.shape[2])
394
+ history_latents = torch.cat([history_latents, generated_latents.to(history_latents)], dim=2)
 
 
 
395
 
396
  if not high_vram:
397
+ offload_model_from_device_for_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=8)
 
 
 
398
  load_model_as_complete(vae, target_device=gpu)
399
 
400
+ real_history_latents = history_latents[:, :, -total_generated_latent_frames:, :, :]
 
 
401
 
402
  if history_pixels is None:
403
  history_pixels = vae_decode(real_history_latents, vae).cpu()
 
415
  if not high_vram:
416
  unload_complete_models()
417
 
418
+ output_filename = os.path.join(outputs_folder, f'{job_id}_{total_generated_latent_frames}.mp4')
419
+ save_bcthw_as_mp4(history_pixels, output_filename, fps=30, crf=mp4_crf)
 
 
 
 
 
 
420
 
421
+ print(f'Decoded. Latent shape {real_history_latents.shape}; pixel shape {history_pixels.shape}')
 
 
422
 
423
  stream.output_queue.push(('file', output_filename))
424
 
425
  except:
426
  traceback.print_exc()
427
  if not high_vram:
428
+ unload_complete_models(text_encoder, text_encoder_2, image_encoder, vae, transformer)
 
 
429
 
430
  stream.output_queue.push(('end', None))
431
  return
432
 
433
  def get_duration(
434
+ input_image, prompt, t2v, n_prompt,
435
+ seed, total_second_length, latent_window_size,
436
+ steps, cfg, gs, rs, gpu_memory_preservation,
437
+ use_teacache, mp4_crf
438
  ):
439
  return total_second_length * 60
440
 
441
  @spaces.GPU(duration=get_duration)
442
  def process(
443
+ input_image, prompt, t2v=False, n_prompt="", seed=31337,
444
+ total_second_length=5, latent_window_size=9, steps=25,
445
+ cfg=1.0, gs=10.0, rs=0.0, gpu_memory_preservation=6,
446
+ use_teacache=True, mp4_crf=16
447
  ):
448
  global stream
 
449
  if t2v:
450
  default_height, default_width = 640, 640
451
+ input_image = np.ones((default_height, default_width, 3), dtype=np.uint8) * 255
 
 
 
452
  print("No input image provided. Using a blank white image.")
453
  else:
 
454
  composite_rgba_uint8 = input_image["composite"]
455
 
 
456
  rgb_uint8 = composite_rgba_uint8[:, :, :3]
 
457
  mask_uint8 = composite_rgba_uint8[:, :, 3]
458
 
 
459
  h, w = rgb_uint8.shape[:2]
460
  background_uint8 = np.full((h, w, 3), 255, dtype=np.uint8)
461
 
 
462
  alpha_normalized_float32 = mask_uint8.astype(np.float32) / 255.0
463
  alpha_mask_float32 = np.stack([alpha_normalized_float32]*3, axis=2)
464
 
465
+ blended_image_float32 = rgb_uint8.astype(np.float32) * alpha_mask_float32 + \
466
+ background_uint8.astype(np.float32) * (1.0 - alpha_mask_float32)
 
 
467
 
468
  input_image = np.clip(blended_image_float32, 0, 255).astype(np.uint8)
469
 
 
498
  yield (
499
  gr.update(),
500
  gr.update(visible=True, value=preview),
501
+ desc,
502
+ html,
503
  gr.update(interactive=False),
504
  gr.update(interactive=True)
505
  )
 
518
  def end_process():
519
  stream.input_queue.push('end')
520
 
521
+
522
  quick_prompts = [
523
  'The girl dances gracefully, with clear movements, full of charm.',
524
  'A character doing some simple body movements.'
525
  ]
526
  quick_prompts = [[x] for x in quick_prompts]
527
 
528
+
529
  def make_custom_css():
530
  base_progress_css = make_progress_bar_css()
 
531
  extra_css = """
532
  body {
533
  background: #fafbfe !important;
 
535
  }
536
  #title-container {
537
  text-align: center;
538
+ padding: 20px 0;
539
  background: linear-gradient(135deg, #a8c0ff 0%, #fbc2eb 100%);
540
+ border-radius: 0 0 10px 10px;
541
  margin-bottom: 20px;
542
  }
543
  #title-container h1 {
544
  color: white;
545
+ font-size: 2rem;
546
  margin: 0;
547
  font-weight: 800;
548
  text-shadow: 1px 2px 2px rgba(0,0,0,0.1);
 
590
 
591
  block = gr.Blocks(css=css).queue()
592
  with block:
593
+ # Title (use gr.Group instead of gr.Box, which newer Gradio versions removed)
594
+ with gr.Group(elem_id="title-container"):
595
  gr.Markdown("<h1>FramePack I2V</h1>")
596
 
 
597
  gr.Markdown("""
598
  ### Video diffusion, but feels like image diffusion
599
+ FramePack I2V - a model that predicts future frames from past frames,
600
+ letting you generate short animations from a single image plus a text prompt.
 
601
  """)
602
 
603
  with gr.Row():
604
  with gr.Column():
605
  input_image = gr.ImageEditor(
606
  type="numpy",
607
+ label="Image Editor (use Brush for mask)",
608
  height=320,
609
  brush=gr.Brush(colors=["#ffffff"])
610
  )
611
  prompt = gr.Textbox(label="Prompt", value='')
612
+ t2v = gr.Checkbox(label="Only Text to Video (ignore image)?", value=False)
613
 
 
 
 
 
614
  example_quick_prompts = gr.Dataset(
615
  samples=quick_prompts,
616
+ label="Quick Prompts",
617
  samples_per_page=1000,
618
  components=[prompt]
619
  )
 
630
  end_button = gr.Button(value="Stop Generation", elem_id="stop-button", interactive=False)
631
 
632
  total_second_length = gr.Slider(
633
+ label="Total Video Length (Seconds)",
634
  minimum=1,
635
  maximum=5,
636
  value=2,
 
642
  use_teacache = gr.Checkbox(
643
  label='Use TeaCache',
644
  value=True,
645
+ info='Faster speed, but may worsen hands/fingers.'
646
  )
647
  n_prompt = gr.Textbox(label="Negative Prompt", value="", visible=False)
648
  seed = gr.Number(label="Seed", value=31337, precision=0)
 
649
  latent_window_size = gr.Slider(
650
  label="Latent Window Size",
651
+ minimum=1, maximum=33,
652
+ value=9, step=1,
 
 
653
  visible=False
654
  )
655
  steps = gr.Slider(
656
  label="Steps",
657
+ minimum=1, maximum=100,
658
+ value=25, step=1,
659
+ info='Not recommended to change drastically.'
 
 
660
  )
661
  cfg = gr.Slider(
662
  label="CFG Scale",
663
+ minimum=1.0, maximum=32.0,
664
+ value=1.0, step=0.01,
 
 
665
  visible=False
666
  )
667
  gs = gr.Slider(
668
  label="Distilled CFG Scale",
669
+ minimum=1.0, maximum=32.0,
670
+ value=10.0, step=0.01,
671
+ info='Not recommended to change drastically.'
 
 
672
  )
673
  rs = gr.Slider(
674
  label="CFG Re-Scale",
675
+ minimum=0.0, maximum=1.0,
676
+ value=0.0, step=0.01,
 
 
677
  visible=False
678
  )
679
  gpu_memory_preservation = gr.Slider(
680
  label="GPU Memory Preservation (GB)",
681
+ minimum=6, maximum=128,
682
+ value=6, step=0.1,
683
+ info="Increase if OOM occurs, but slower."
 
 
684
  )
685
  mp4_crf = gr.Slider(
686
  label="MP4 Compression (CRF)",
687
+ minimum=0, maximum=100,
688
+ value=16, step=1,
689
+ info="Lower = better quality. 16 recommended."
 
 
690
  )
691
 
692
  with gr.Column():
693
+ preview_image = gr.Image(
694
+ label="Preview Latents",
695
+ height=200,
696
+ visible=False
697
+ )
698
+ result_video = gr.Video(
699
+ label="Finished Frames",
700
+ autoplay=True,
701
+ show_share_button=False,
702
+ height=512,
703
+ loop=True
704
+ )
705
  progress_desc = gr.Markdown('', elem_classes='no-generating-animation')
706
  progress_bar = gr.HTML('', elem_classes='no-generating-animation')
707
 
708
+ # Extra info
709
  gr.HTML("""
710
  <div style="text-align:center; margin-top:20px;">
711
+ Share your outputs or get inspired by searching
712
+ <a href="https://x.com/search?q=framepack&f=live" target="_blank">#framepack</a> on Twitter!
713
  </div>
714
  """)
715
 
 
716
  ips = [
717
  input_image, prompt, t2v, n_prompt, seed,
718
+ total_second_length, latent_window_size,
719
+ steps, cfg, gs, rs, gpu_memory_preservation,
720
  use_teacache, mp4_crf
721
  ]
722
  start_button.click(
 
726
  )
727
  end_button.click(fn=end_process)
728
 
729
+ # If you want examples, uncomment below:
730
  # gr.Examples(
731
  # examples=examples,
732
  # inputs=[input_image, prompt],