刘虹雨 committed on
Commit 5ccc87f · 1 parent: ee74488

update code

Files changed (1): app.py (+34, −23)
app.py CHANGED
@@ -446,7 +446,7 @@ def duplicate_batch(tensor, batch_size=2):
 
 @torch.no_grad()
 @spaces.GPU(duration=200)
-def avatar_generation(items, save_path_base, video_path_input, source_type, is_styled, styled_img):
+def avatar_generation(items, save_path_base, video_path_input, source_type, is_styled, styled_img, image_name_true):
     """
     Generate avatars from input images.
 
@@ -480,8 +480,6 @@ def avatar_generation(items, save_path_base, video_path_input, source_type, is_s
     # ws_avg.to(device)
     # DiT_model.to(device)
     # Set up face verse for amimation
-
-
     if source_type == 'example':
         input_img_fvid = './demo_data/source_img/img_generate_different_domain/coeffs/demo_imgs'
         input_img_motion = './demo_data/source_img/img_generate_different_domain/motions/demo_imgs'
@@ -502,8 +500,17 @@ def avatar_generation(items, save_path_base, video_path_input, source_type, is_s
             raise ValueError("Batch size > 1 not implemented")
 
         image_dir = chunk[0]
+        image_name = os.path.splitext(image_name_true)[0]
+
+        # # image_name = os.path.splitext(os.path.basename(image_dir))[0]
+        # if source_type == 'custom':
+        #     image_name = os.path.splitext(image_name_true)[0]
+        # else:
+        #     image_name = os.path.splitext(os.path.basename(image_dir))[0]
+
+
+
 
-        image_name = os.path.splitext(os.path.basename(image_dir))[0]
         dino_img, clip_image = image_process(image_dir, clip_image_processor, dino_img_processor, device)
 
         clip_feature = image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
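Why the signature change: once process_image hands back a PIL image instead of a file path (see its hunk further down), Gradio re-serializes that image to a temporary file, so the basename of image_dir no longer matches the original upload. Carrying the upload's filename in image_name_true keeps the output naming stable. A minimal sketch of the difference, using hypothetical paths:

    import os

    # Hypothetical values for illustration only.
    image_dir = '/tmp/gradio/3f9c/image.webp'   # temp path Gradio assigns to the processed image
    image_name_true = 'portrait.png'            # original upload filename, carried in a gr.State

    old_stem = os.path.splitext(os.path.basename(image_dir))[0]  # 'image'    -> original name lost
    new_stem = os.path.splitext(image_name_true)[0]              # 'portrait' -> original name kept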
@@ -606,8 +613,8 @@ def avatar_generation(items, save_path_base, video_path_input, source_type, is_s
             Image.fromarray(final_out_show, 'RGB').save(os.path.join(save_frames_path_outshow, frame_name))
 
         # Generate videos
-        images_to_video(save_frames_path_out, os.path.join(save_path_base, image_name + '_out.mp4'))
-        images_to_video(save_frames_path_depth, os.path.join(save_path_base, image_name + '_depth.mp4'))
+        images_to_video(save_frames_path_out, os.path.join(save_path_base, image_name + video_name+ '_out.mp4'))
+        images_to_video(save_frames_path_depth, os.path.join(save_path_base, image_name + video_name+ '_out.mp4'))
 
         logging.info(f"✅ Video generation completed successfully!")
         return os.path.join(save_path_base, image_name + video_name+ '_out.mp4'), os.path.join(save_path_base, image_name + video_name+'_depth.mp4')
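Both encode calls and the return statement now share the image_name + video_name prefix, and the return expects one '_out' and one '_depth' file under save_path_base. A quick sketch of the naming with made-up values:

    import os

    # Made-up values for illustration.
    save_path_base, image_name, video_name = './output/sess_01', 'portrait', '_driving'

    out_mp4 = os.path.join(save_path_base, image_name + video_name + '_out.mp4')      # ./output/sess_01/portrait_driving_out.mp4
    depth_mp4 = os.path.join(save_path_base, image_name + video_name + '_depth.mp4')  # ./output/sess_01/portrait_driving_depth.mp4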
@@ -624,30 +631,31 @@ def assert_input_image(input_image):
         raise gr.Error("No image selected or uploaded!")
 
 @spaces.GPU(duration=100)
-def process_image(input_image, source_type, is_style, save_dir):
+def process_image(input_image_dir, source_type, is_style, save_dir):
     """ 🎯 Process input_image; run different logic depending on whether it is an example image """
     process_img_input_dir = os.path.join(save_dir, 'input_image')
     process_img_save_dir = os.path.join(save_dir, 'processed_img')
+    image_name_true = os.path.basename(input_image_dir)
     os.makedirs(process_img_save_dir, exist_ok=True)
     os.makedirs(process_img_input_dir, exist_ok=True)
     if source_type == "example":
-        return input_image, source_type
+        input_image = Image.open(input_image)
+        return input_image, source_type, image_name_true
     else:
         # input_process_model.inference(input_image, process_img_save_dir)
-        shutil.copy(input_image, process_img_input_dir)
+        shutil.copy(input_image_dir, process_img_input_dir)
         input_process_model.inference(process_img_input_dir, process_img_save_dir, is_img=True, is_video=False)
-        img_name = os.path.basename(input_image)
-        imge_dir = os.path.join(save_dir, 'processed_img/dataset/images512x512/input_image', img_name)
-        return imge_dir, source_type  # replace this with the logic for handling user-uploaded images
+        imge_dir = os.path.join(save_dir, 'processed_img/dataset/images512x512/input_image', image_name_true)
+        image = Image.open(imge_dir)
+        return image, source_type, image_name_true  # replace this with the logic for handling user-uploaded images
 
 @spaces.GPU(duration=100)
 @torch.no_grad()
-def style_transfer(processed_image, style_prompt, cfg, strength, save_base):
+def style_transfer(processed_image, style_prompt, cfg, strength, save_base,image_name_true):
     """
     🎭 This function performs style transfer
     ✅ You can fill in your stylization code here
     """
-    pipeline_sd.to(device)
     src_img_pil = Image.open(processed_image)
     img_name = os.path.basename(processed_image)
     save_dir = os.path.join(save_base, 'style_img')
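process_image now returns three values (a PIL image, the source type, and the upload's original filename) instead of a processed-file path. A minimal sketch of the new contract, assuming the same output layout the preprocessor writes (processed_img/dataset/images512x512/input_image); input_process_model is the project's own preprocessor and is left as a comment:

    import os
    import shutil
    from PIL import Image

    def process_image_sketch(input_image_dir, source_type, save_dir):
        """Sketch only: keep the upload's filename and hand back a PIL image."""
        image_name_true = os.path.basename(input_image_dir)
        input_dir = os.path.join(save_dir, 'input_image')
        out_dir = os.path.join(save_dir, 'processed_img')
        os.makedirs(input_dir, exist_ok=True)
        os.makedirs(out_dir, exist_ok=True)

        if source_type == 'example':
            # Example images ship pre-aligned; just load them.
            return Image.open(input_image_dir), source_type, image_name_true

        # Custom uploads: copy into the working dir, run the project's preprocessor,
        # then load the aligned 512x512 crop it writes under the upload's filename.
        shutil.copy(input_image_dir, input_dir)
        # input_process_model.inference(input_dir, out_dir, is_img=True, is_video=False)
        aligned = os.path.join(out_dir, 'dataset/images512x512/input_image', image_name_true)
        return Image.open(aligned), source_type, image_name_true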
@@ -663,8 +671,8 @@ def style_transfer(processed_image, style_prompt, cfg, strength, save_base):
         num_inference_steps=30,
         controlnet_conditioning_scale=1.5
     )['images'][0]
-    trg_img_pil.save(os.path.join(save_dir, img_name))
-    return os.path.join(save_dir, img_name)  # 🚨 replace this with your style-transfer logic
+    trg_img_pil.save(os.path.join(save_dir, image_name_true))
+    return trg_img_pil  # 🚨 replace this with your style-transfer logic
 
 
 def reset_flag():
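style_transfer now saves the stylized frame under the preserved upload filename and returns the PIL image itself instead of a path; a gr.Image output renders a PIL image directly. A small sketch of just that tail end, with hypothetical arguments:

    import os
    from PIL import Image

    def save_and_return_styled(trg_img_pil: Image.Image, save_base: str, image_name_true: str) -> Image.Image:
        """Sketch: persist the stylized image under the original filename, return the PIL object."""
        save_dir = os.path.join(save_base, 'style_img')
        os.makedirs(save_dir, exist_ok=True)
        trg_img_pil.save(os.path.join(save_dir, image_name_true))
        return trg_img_pil  # a gr.Image output accepts this directly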
@@ -808,6 +816,8 @@ def launch_gradio_app():
         is_from_example = gr.State(value=True)
         is_styled = gr.State(value=False)
         working_dir = gr.State()
+        image_name_true = gr.State()
+
 
         with gr.Row():
             with gr.Column(variant='panel'):
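image_name_true = gr.State() adds a per-session slot so the original filename can travel between callbacks without a visible component. A self-contained sketch of the pattern; the component names here are hypothetical, not the ones in app.py:

    import os
    import gradio as gr

    def remember_name(path):
        # First callback: stash the upload's original filename in the State slot.
        return path, os.path.basename(path)

    def use_name(path, name):
        # Later callback: read the stashed filename back out of the State.
        return f"processing {path} as {name}"

    with gr.Blocks() as demo:
        name_state = gr.State()                  # analogous to image_name_true
        img = gr.Image(type="filepath")
        msg = gr.Textbox()
        go = gr.Button("Run")
        img.upload(remember_name, inputs=img, outputs=[img, name_state])
        go.click(use_name, inputs=[img, name_state], outputs=msg)
    # demo.launch()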
@@ -932,13 +942,14 @@ def launch_gradio_app():
                         format="mp4", height=512, width=512,
                         autoplay=True
                     )
-        def apply_style_and_mark(processed_image, style_choice, cfg, strength, working_dir):
-            styled = style_transfer(processed_image, styles[style_choice], cfg, strength, working_dir)
+        def apply_style_and_mark(processed_image, style_choice, cfg, strength, working_dir, image_name_true):
+            styled = style_transfer(processed_image, styles[style_choice], cfg, strength, working_dir, image_name_true)
             return styled, True
 
         def process_image_and_enable_style(input_image, source_type, is_styled, wd):
-            processed_result, updated_source_type = process_image(input_image, source_type, is_styled, wd)
-            return processed_result, updated_source_type, gr.update(interactive=True), gr.update(interactive=True)
+            processed_result, updated_source_type, image_name_true = process_image(input_image, source_type, is_styled, wd)
+            return processed_result, updated_source_type, gr.update(interactive=True), gr.update(interactive=True), image_name_true
+
         processed_image_button.click(
             fn=prepare_working_dir,
             inputs=[working_dir, is_styled],
@@ -947,17 +958,17 @@ def launch_gradio_app():
         ).success(
             fn=process_image_and_enable_style,
             inputs=[input_image, source_type, is_styled, working_dir],
-            outputs=[processed_image, source_type, style_button, submit],
+            outputs=[processed_image, source_type, style_button, submit, image_name_true],
             queue=True
         )
         style_button.click(
             fn=apply_style_and_mark,
-            inputs=[processed_image, style_choice, cfg_slider, strength_slider, working_dir],
+            inputs=[processed_image, style_choice, cfg_slider, strength_slider, working_dir, image_name_true],
             outputs=[style_image, is_styled]
         )
         submit.click(
             fn=avatar_generation,
-            inputs=[processed_image, working_dir, video_input, source_type, is_styled, style_image],
+            inputs=[processed_image, working_dir, video_input, source_type, is_styled, style_image, image_name_true],
             outputs=[output_video, output_video_1],  # ⏳ show the videos later
             queue=True
         )
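The wiring keeps the one-to-one mapping Gradio requires: each entry in an inputs/outputs list corresponds positionally to a callback parameter or return value, and .success() runs only if the preceding .click() handler finished without raising. A reduced sketch of the same chain with hypothetical components:

    import os
    import gradio as gr

    def prepare(wd):
        return wd or '/tmp/session'                             # 1 return value -> 1 output

    def preprocess(image_path, wd):
        name = os.path.basename(image_path) if image_path else ''
        return image_path, gr.update(interactive=True), name    # 3 return values -> 3 outputs

    with gr.Blocks() as demo:
        working_dir = gr.State()
        image_name_true = gr.State()
        img = gr.Image(type="filepath")
        submit = gr.Button("Generate", interactive=False)
        run = gr.Button("Preprocess")

        # .success() mirrors the processed_image_button.click(...).success(...) chain in app.py:
        # the second callback fires only after the first one succeeds.
        run.click(prepare, inputs=[working_dir], outputs=[working_dir]).success(
            preprocess, inputs=[img, working_dir],
            outputs=[img, submit, image_name_true],
        )
    # demo.launch()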
 