刘虹雨 committed on
Commit b1ae546 · 1 Parent(s): 66886ce

update code

Files changed (1): app.py +219 -193
app.py CHANGED
@@ -3,7 +3,17 @@ import subprocess
 import sys
 import warnings
 import logging
-import spaces
+if os.environ.get("SPACES_ZERO_GPU") is not None:
+    import spaces
+else:
+    class spaces:
+        @staticmethod
+        def GPU(*decorator_args, **decorator_kwargs):
+            def decorator(func):
+                def wrapper(*args, **kwargs):
+                    return func(*args, **kwargs)
+                return wrapper
+            return decorator
 import difflib
 
 # Configure logging settings

@@ -448,6 +458,7 @@ def duplicate_batch(tensor, batch_size=2):
 @torch.no_grad()
 @spaces.GPU(duration=200)
 def avatar_generation(items, save_path_base, video_path_input, source_type, is_styled, styled_img, image_name_true):
+
     """
     Generate avatars from input images.
 

@@ -463,162 +474,165 @@ def avatar_generation(items, save_path_base, video_path_input, source_type, is_s
         mean (torch.Tensor): Mean normalization tensor.
         ws_avg (torch.Tensor): Latent average tensor.
     """
-    if is_styled:
-        items = [styled_img]
-    else:
-        items = [items]
-    video_folder = "./demo_data/target_video"
-    video_name = os.path.basename(video_path_input).split(".")[0]
-    target_path = os.path.join(video_folder, 'data_' + video_name)
-    exp_base_dir = os.path.join(target_path, 'coeffs')
-    exp_img_base_dir = os.path.join(target_path, 'images512x512')
-    motion_base_dir = os.path.join(target_path, 'motions')
-    label_file_test = os.path.join(target_path, 'images512x512/dataset_realcam.json')
-    # render_model.to(device)
-    # image_encoder.to(device)
-    # vae_triplane.to(device)
-    # dinov2.to(device)
-    # ws_avg.to(device)
-    # DiT_model.to(device)
-    # Set up FaceVerse for animation
-    if source_type == 'example':
-        input_img_fvid = './demo_data/source_img/img_generate_different_domain/coeffs/demo_imgs'
-        input_img_motion = './demo_data/source_img/img_generate_different_domain/motions/demo_imgs'
-    elif source_type == 'custom':
-        input_img_fvid = os.path.join(save_path_base, 'processed_img/dataset/coeffs/input_image')
-        input_img_motion = os.path.join(save_path_base, 'processed_img/dataset/motions/input_image')
-    else:
-        raise ValueError("Wrong type")
-    bs = 1
-    sample_steps = 20
-    cfg_scale = 4.5
-    pitch_range = 0.25
-    yaw_range = 0.35
-    triplane_size = (256 * 4, 256)
-    latent_size = (triplane_size[0] // 8, triplane_size[1] // 8)
-    for chunk in tqdm(list(get_chunks(items, 1)), unit='batch'):
-        if bs != 1:
-            raise ValueError("Batch size > 1 not implemented")
-
-        image_dir = chunk[0]
-        image_name = os.path.splitext(image_name_true)[0]
-
-        # # image_name = os.path.splitext(os.path.basename(image_dir))[0]
-        # if source_type == 'custom':
-        #     image_name = os.path.splitext(image_name_true)[0]
-        # else:
-        #     image_name = os.path.splitext(os.path.basename(image_dir))[0]
-
-
-
-
-        dino_img, clip_image = image_process(image_dir, clip_image_processor, dino_img_processor, device)
-
-        clip_feature = image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
-        uncond_clip_feature = image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[
-            -2]
-        dino_feature = dinov2(dino_img).last_hidden_state
-        uncond_dino_feature = dinov2(torch.zeros_like(dino_img)).last_hidden_state
-
-        samples = generate_samples(DiT_model, cfg_scale, sample_steps, clip_feature, dino_feature,
-                                   uncond_clip_feature, uncond_dino_feature, device, latent_size,
-                                   'dpm-solver')
-
-        samples = (samples / 0.3994218)
-        samples = rearrange(samples, "b c (f h) w -> b c f h w", f=4)
-        samples = vae_triplane.decode(samples)
-        samples = rearrange(samples, "b c f h w -> b f c h w")
-        samples = samples * std + mean
-        torch.cuda.empty_cache()
-        torch.cuda.ipc_collect()
-        save_frames_path_out = os.path.join(save_path_base, image_name, video_name, 'out')
-        save_frames_path_outshow = os.path.join(save_path_base, image_name, video_name,'out_show')
-        save_frames_path_depth = os.path.join(save_path_base, image_name, video_name, 'depth')
-
-        os.makedirs(save_frames_path_out, exist_ok=True)
-        os.makedirs(save_frames_path_outshow, exist_ok=True)
-        os.makedirs(save_frames_path_depth, exist_ok=True)
-
-        img_ref = np.array(Image.open(image_dir))
-        img_ref_out = img_ref.copy()
-        img_ref = torch.from_numpy(img_ref.astype(np.float32) / 127.5 - 1).permute(2, 0, 1).unsqueeze(0).to(device)
-
-        motion_app_dir = os.path.join(input_img_motion, image_name + '.npy')
-        motion_app = torch.tensor(np.load(motion_app_dir), dtype=torch.float32).unsqueeze(0).to(device)
-
-        id_motions = os.path.join(input_img_fvid, image_name + '.npy')
-
-        all_pose = json.loads(open(label_file_test).read())['labels']
-        all_pose = dict(all_pose)
-        if os.path.exists(id_motions):
-            coeff = np.load(id_motions).astype(np.float32)
-            coeff = torch.from_numpy(coeff).to(device).float().unsqueeze(0)
-            Faceverse.id_coeff = Faceverse.recon_model.split_coeffs(coeff)[0]
-        motion_dir = os.path.join(motion_base_dir, video_name)
-        exp_dir = os.path.join(exp_base_dir, video_name)
-        for frame_index, motion_name in enumerate(
-                tqdm(natsorted(os.listdir(motion_dir), alg=ns.PATH), desc="Processing Frames")):
-            exp_each_dir_img = os.path.join(exp_img_base_dir, video_name, motion_name.replace('.npy', '.png'))
-            exp_each_dir = os.path.join(exp_dir, motion_name)
-            motion_each_dir = os.path.join(motion_dir, motion_name)
-
-            # Load pose data
-            pose_key = os.path.join(video_name, motion_name.replace('.npy', '.png'))
-
-            cam2world_pose = LookAtPoseSampler.sample(
-                3.14 / 2 + yaw_range * np.sin(2 * 3.14 * frame_index / len(os.listdir(motion_dir))),
-                3.14 / 2 - 0.05 + pitch_range * np.cos(2 * 3.14 * frame_index / len(os.listdir(motion_dir))),
-                torch.tensor([0, 0, 0], device=device), radius=2.7, device=device)
-            pose_show = torch.cat([cam2world_pose.reshape(-1, 16),
-                                   FOV_to_intrinsics(fov_degrees=18.837, device=device).reshape(-1, 9)], 1).to(device)
-
-            pose = torch.tensor(np.array(all_pose[pose_key]).astype(np.float32)).float().unsqueeze(0).to(device)
-
-            # Load and resize expression image
-            exp_img = np.array(Image.open(exp_each_dir_img).resize((512, 512)))
-
-            # Load expression coefficients
-            exp_coeff = torch.from_numpy(np.load(exp_each_dir).astype(np.float32)).to(device).float().unsqueeze(0)
-            exp_target = Faceverse.make_driven_rendering(exp_coeff, res=256)
-
-            # Load motion data
-            motion = torch.tensor(np.load(motion_each_dir)).float().unsqueeze(0).to(device)
-
-            # img_ref_double = duplicate_batch(img_ref, batch_size=2)
-            # motion_app_double = duplicate_batch(motion_app, batch_size=2)
-            # motion_double = duplicate_batch(motion, batch_size=2)
-            # pose_double = torch.cat([pose_show, pose], dim=0)
-            # exp_target_double = duplicate_batch(exp_target, batch_size=2)
-            # samples_double = duplicate_batch(samples, batch_size=2)
-            # Select refine_net processing method
-            final_out = render_model(
-                img_ref, None, motion_app, motion, c=pose, mesh=exp_target,
-                triplane_recon=samples,
-                ws_avg=ws_avg, motion_scale=1.
-            )
-
-            # Process output image
-            final_out_show = trans(final_out['image_sr'][0].unsqueeze(0))
-            final_out_notshow = trans(final_out['image_sr'][0].unsqueeze(0))
-            depth = final_out['image_depth'][0].unsqueeze(0)
-            depth = -depth
-            depth = (depth - depth.min()) / (depth.max() - depth.min()) * 2 - 1
-            depth = trans(depth)
-
-            depth = np.repeat(depth[:, :, :], 3, axis=2)
-            # Save output images
-            frame_name = f'{str(frame_index).zfill(4)}.png'
-            Image.fromarray(depth, 'RGB').save(os.path.join(save_frames_path_depth, frame_name))
-            Image.fromarray(final_out_notshow, 'RGB').save(os.path.join(save_frames_path_out, frame_name))
-
-            Image.fromarray(final_out_show, 'RGB').save(os.path.join(save_frames_path_outshow, frame_name))
-
-        # Generate videos
-        images_to_video(save_frames_path_out, os.path.join(save_path_base, image_name + video_name+ '_out.mp4'))
-        images_to_video(save_frames_path_depth, os.path.join(save_path_base, image_name + video_name+ '_out.mp4'))
-
-        logging.info(f"✅ Video generation completed successfully!")
-        return os.path.join(save_path_base, image_name + video_name+ '_out.mp4'), os.path.join(save_path_base, image_name + video_name+'_depth.mp4')
+    try:
+        if is_styled:
+            items = [styled_img]
+        else:
+            items = [items]
+        video_folder = "./demo_data/target_video"
+        video_name = os.path.basename(video_path_input).split(".")[0]
+        target_path = os.path.join(video_folder, 'data_' + video_name)
+        exp_base_dir = os.path.join(target_path, 'coeffs')
+        exp_img_base_dir = os.path.join(target_path, 'images512x512')
+        motion_base_dir = os.path.join(target_path, 'motions')
+        label_file_test = os.path.join(target_path, 'images512x512/dataset_realcam.json')
+        # render_model.to(device)
+        # image_encoder.to(device)
+        # vae_triplane.to(device)
+        # dinov2.to(device)
+        # ws_avg.to(device)
+        # DiT_model.to(device)
+        # Set up FaceVerse for animation
+        if source_type == 'example':
+            input_img_fvid = './demo_data/source_img/img_generate_different_domain/coeffs/demo_imgs'
+            input_img_motion = './demo_data/source_img/img_generate_different_domain/motions/demo_imgs'
+        elif source_type == 'custom':
+            input_img_fvid = os.path.join(save_path_base, 'processed_img/dataset/coeffs/input_image')
+            input_img_motion = os.path.join(save_path_base, 'processed_img/dataset/motions/input_image')
+        else:
+            raise ValueError("Wrong type")
+        bs = 1
+        sample_steps = 20
+        cfg_scale = 4.5
+        pitch_range = 0.25
+        yaw_range = 0.35
+        triplane_size = (256 * 4, 256)
+        latent_size = (triplane_size[0] // 8, triplane_size[1] // 8)
+        for chunk in tqdm(list(get_chunks(items, 1)), unit='batch'):
+            if bs != 1:
+                raise ValueError("Batch size > 1 not implemented")
+
+            image_dir = chunk[0]
+            image_name = os.path.splitext(image_name_true)[0]
+
+            # # image_name = os.path.splitext(os.path.basename(image_dir))[0]
+            # if source_type == 'custom':
+            #     image_name = os.path.splitext(image_name_true)[0]
+            # else:
+            #     image_name = os.path.splitext(os.path.basename(image_dir))[0]
+
+
+
+
+            dino_img, clip_image = image_process(image_dir, clip_image_processor, dino_img_processor, device)
+
+            clip_feature = image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
+            uncond_clip_feature = image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[
+                -2]
+            dino_feature = dinov2(dino_img).last_hidden_state
+            uncond_dino_feature = dinov2(torch.zeros_like(dino_img)).last_hidden_state
+
+            samples = generate_samples(DiT_model, cfg_scale, sample_steps, clip_feature, dino_feature,
+                                       uncond_clip_feature, uncond_dino_feature, device, latent_size,
+                                       'dpm-solver')
+
+            samples = (samples / 0.3994218)
+            samples = rearrange(samples, "b c (f h) w -> b c f h w", f=4)
+            samples = vae_triplane.decode(samples)
+            samples = rearrange(samples, "b c f h w -> b f c h w")
+            samples = samples * std + mean
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+            save_frames_path_out = os.path.join(save_path_base, image_name, video_name, 'out')
+            save_frames_path_outshow = os.path.join(save_path_base, image_name, video_name,'out_show')
+            save_frames_path_depth = os.path.join(save_path_base, image_name, video_name, 'depth')
+
+            os.makedirs(save_frames_path_out, exist_ok=True)
+            os.makedirs(save_frames_path_outshow, exist_ok=True)
+            os.makedirs(save_frames_path_depth, exist_ok=True)
+
+            img_ref = np.array(Image.open(image_dir))
+            img_ref_out = img_ref.copy()
+            img_ref = torch.from_numpy(img_ref.astype(np.float32) / 127.5 - 1).permute(2, 0, 1).unsqueeze(0).to(device)
+
+            motion_app_dir = os.path.join(input_img_motion, image_name + '.npy')
+            motion_app = torch.tensor(np.load(motion_app_dir), dtype=torch.float32).unsqueeze(0).to(device)
+
+            id_motions = os.path.join(input_img_fvid, image_name + '.npy')
+
+            all_pose = json.loads(open(label_file_test).read())['labels']
+            all_pose = dict(all_pose)
+            if os.path.exists(id_motions):
+                coeff = np.load(id_motions).astype(np.float32)
+                coeff = torch.from_numpy(coeff).to(device).float().unsqueeze(0)
+                Faceverse.id_coeff = Faceverse.recon_model.split_coeffs(coeff)[0]
+            motion_dir = os.path.join(motion_base_dir, video_name)
+            exp_dir = os.path.join(exp_base_dir, video_name)
+            for frame_index, motion_name in enumerate(
+                    tqdm(natsorted(os.listdir(motion_dir), alg=ns.PATH), desc="Processing Frames")):
+                exp_each_dir_img = os.path.join(exp_img_base_dir, video_name, motion_name.replace('.npy', '.png'))
+                exp_each_dir = os.path.join(exp_dir, motion_name)
+                motion_each_dir = os.path.join(motion_dir, motion_name)
+
+                # Load pose data
+                pose_key = os.path.join(video_name, motion_name.replace('.npy', '.png'))
+
+                cam2world_pose = LookAtPoseSampler.sample(
+                    3.14 / 2 + yaw_range * np.sin(2 * 3.14 * frame_index / len(os.listdir(motion_dir))),
+                    3.14 / 2 - 0.05 + pitch_range * np.cos(2 * 3.14 * frame_index / len(os.listdir(motion_dir))),
+                    torch.tensor([0, 0, 0], device=device), radius=2.7, device=device)
+                pose_show = torch.cat([cam2world_pose.reshape(-1, 16),
+                                       FOV_to_intrinsics(fov_degrees=18.837, device=device).reshape(-1, 9)], 1).to(device)
+
+                pose = torch.tensor(np.array(all_pose[pose_key]).astype(np.float32)).float().unsqueeze(0).to(device)
+
+                # Load and resize expression image
+                exp_img = np.array(Image.open(exp_each_dir_img).resize((512, 512)))
+
+                # Load expression coefficients
+                exp_coeff = torch.from_numpy(np.load(exp_each_dir).astype(np.float32)).to(device).float().unsqueeze(0)
+                exp_target = Faceverse.make_driven_rendering(exp_coeff, res=256)
+
+                # Load motion data
+                motion = torch.tensor(np.load(motion_each_dir)).float().unsqueeze(0).to(device)
+
+                # img_ref_double = duplicate_batch(img_ref, batch_size=2)
+                # motion_app_double = duplicate_batch(motion_app, batch_size=2)
+                # motion_double = duplicate_batch(motion, batch_size=2)
+                # pose_double = torch.cat([pose_show, pose], dim=0)
+                # exp_target_double = duplicate_batch(exp_target, batch_size=2)
+                # samples_double = duplicate_batch(samples, batch_size=2)
+                # Select refine_net processing method
+                final_out = render_model(
+                    img_ref, None, motion_app, motion, c=pose, mesh=exp_target,
+                    triplane_recon=samples,
+                    ws_avg=ws_avg, motion_scale=1.
+                )
+
+                # Process output image
+                final_out_show = trans(final_out['image_sr'][0].unsqueeze(0))
+                # final_out_notshow = trans(final_out['image_sr'][0].unsqueeze(0))
+                depth = final_out['image_depth'][0].unsqueeze(0)
+                depth = -depth
+                depth = (depth - depth.min()) / (depth.max() - depth.min()) * 2 - 1
+                depth = trans(depth)
+
+                depth = np.repeat(depth[:, :, :], 3, axis=2)
+                # Save output images
+                frame_name = f'{str(frame_index).zfill(4)}.png'
+                Image.fromarray(depth, 'RGB').save(os.path.join(save_frames_path_depth, frame_name))
+                # Image.fromarray(final_out_notshow, 'RGB').save(os.path.join(save_frames_path_out, frame_name))
+
+                Image.fromarray(final_out_show, 'RGB').save(os.path.join(save_frames_path_outshow, frame_name))
+
+            # Generate videos
+            images_to_video(save_frames_path_out, os.path.join(save_path_base, image_name + video_name+ '_out.mp4'))
+            images_to_video(save_frames_path_depth, os.path.join(save_path_base, image_name + video_name+ '_depth.mp4'))
+
+            logging.info(f"✅ Video generation completed successfully!")
+            return os.path.join(save_path_base, image_name + video_name+ '_out.mp4'), os.path.join(save_path_base, image_name + video_name+'_depth.mp4')
+    except Exception as e:
+        return None, None, f"❌ error:{str(e)}"
 
 
 def get_image_base64(path):

@@ -631,35 +645,38 @@ def assert_input_image(input_image):
     if input_image is None:
         raise gr.Error("No image selected or uploaded!")
 
-@spaces.GPU(duration=100)
+@spaces.GPU(duration=30)
 def process_image(input_image_dir, source_type, is_style, save_dir):
-    """ 🎯 Process input_image; branch depending on whether it is an example image """
-    process_img_input_dir = os.path.join(save_dir, 'input_image')
-    process_img_save_dir = os.path.join(save_dir, 'processed_img')
-    base_name = os.path.basename(input_image_dir)  # abc123.jpg
-    name_without_ext = os.path.splitext(base_name)[0]  # abc123
-    image_name_true = name_without_ext + ".png"
-    os.makedirs(process_img_save_dir, exist_ok=True)
-    os.makedirs(process_img_input_dir, exist_ok=True)
-    if source_type == "example":
-        image = Image.open(input_image_dir)
-        return image, source_type, image_name_true
-    else:
-        # input_process_model.inference(input_image, process_img_save_dir)
-        shutil.copy(input_image_dir, process_img_input_dir)
-        input_process_model.inference(process_img_input_dir, process_img_save_dir, is_img=True, is_video=False)
-
-        files = os.listdir(os.path.join(process_img_save_dir, 'dataset/images512x512/input_image'))
-        image_files = [f for f in files if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.webp'))]
-        # Use difflib to find the closest matching filename
-        matches = difflib.get_close_matches(image_name_true, image_files, n=1, cutoff=0.1)
-        closest_match = matches[0]
-        imge_dir = os.path.join(process_img_save_dir, 'dataset/images512x512/input_image', closest_match)
-        image = Image.open(imge_dir)
-        image_name_true = closest_match
-        return image, source_type, image_name_true  # Replace this with the logic for handling a user-uploaded image
-
-@spaces.GPU(duration=100)
+
+    """ 🎯 Process input_image; branch depending on whether it is an example image """
+    process_img_input_dir = os.path.join(save_dir, 'input_image')
+    process_img_save_dir = os.path.join(save_dir, 'processed_img')
+    base_name = os.path.basename(input_image_dir)  # abc123.jpg
+    name_without_ext = os.path.splitext(base_name)[0]  # abc123
+    image_name_true = name_without_ext + ".png"
+    os.makedirs(process_img_save_dir, exist_ok=True)
+    os.makedirs(process_img_input_dir, exist_ok=True)
+    if source_type == "example":
+        image = Image.open(input_image_dir)
+        return image, source_type, image_name_true, ""
+    else:
+        # input_process_model.inference(input_image, process_img_save_dir)
+        shutil.copy(input_image_dir, process_img_input_dir)
+        input_process_model.inference(process_img_input_dir, process_img_save_dir, is_img=True, is_video=False)
+
+        files = os.listdir(os.path.join(process_img_save_dir, 'dataset/images512x512/input_image'))
+        image_files = [f for f in files if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.webp'))]
+        # Use difflib to find the closest matching filename
+        matches = difflib.get_close_matches(image_name_true, image_files, n=1, cutoff=0.1)
+        closest_match = matches[0]
+        imge_dir = os.path.join(process_img_save_dir, 'dataset/images512x512/input_image', closest_match)
+        image = Image.open(imge_dir)
+        image_name_true = closest_match
+        return image, source_type, image_name_true, ""  # Replace this with the logic for handling a user-uploaded image
+
+
+
+@spaces.GPU(duration=30)
 @torch.no_grad()
 def style_transfer(processed_image, style_prompt, cfg, strength, save_base,image_name_true):
     """

@@ -682,7 +699,8 @@ def style_transfer(processed_image, style_prompt, cfg, strength, save_base,image
         controlnet_conditioning_scale=1.5
     )['images'][0]
     trg_img_pil.save(os.path.join(save_dir, image_name_true))
-    return trg_img_pil  # 🚨 Replace this with your style-transfer logic
+    return trg_img_pil, ""  # 🚨 Replace this with your style-transfer logic
+
 
 
 def reset_flag():

@@ -827,6 +845,7 @@ def launch_gradio_app():
         is_styled = gr.State(value=False)
         working_dir = gr.State()
        image_name_true = gr.State()
+        error_box = gr.Textbox(label="error hint", lines=3, interactive=False, visible=True)
 
 
         with gr.Row():

@@ -953,12 +972,19 @@ def launch_gradio_app():
                 autoplay=True
             )
         def apply_style_and_mark(processed_image, style_choice, cfg, strength, working_dir, image_name_true):
-            styled = style_transfer(processed_image, styles[style_choice], cfg, strength, working_dir, image_name_true)
-            return styled, True
+            try:
+                styled = style_transfer(processed_image, styles[style_choice], cfg, strength, working_dir, image_name_true)
+                return styled, True, ""
+            except Exception as e:
+                return None, True, f"❌ error:{str(e)}"
 
         def process_image_and_enable_style(input_image, source_type, is_styled, wd):
-            processed_result, updated_source_type, image_name_true = process_image(input_image, source_type, is_styled, wd)
-            return processed_result, updated_source_type, gr.update(interactive=True), gr.update(interactive=True), image_name_true
+            try:
+                processed_result, updated_source_type, image_name_true = process_image(input_image, source_type, is_styled, wd)
+
+                return processed_result, updated_source_type, gr.update(interactive=True), gr.update(interactive=True), image_name_true, ""
+            except Exception as e:
+                return None, updated_source_type, gr.update(interactive=False), gr.update(interactive=False), image_name_true, f"❌ error:{str(e)}"
 
         processed_image_button.click(
             fn=prepare_working_dir,

@@ -968,18 +994,18 @@ def launch_gradio_app():
         ).success(
             fn=process_image_and_enable_style,
             inputs=[input_image, source_type, is_styled, working_dir],
-            outputs=[processed_image, source_type, style_button, submit, image_name_true],
+            outputs=[processed_image, source_type, style_button, submit, image_name_true, error_box],
             queue=True
         )
         style_button.click(
            fn=apply_style_and_mark,
             inputs=[processed_image, style_choice, cfg_slider, strength_slider, working_dir, image_name_true],
-            outputs=[style_image, is_styled]
+            outputs=[style_image, is_styled, error_box]
         )
         submit.click(
             fn=avatar_generation,
             inputs=[processed_image, working_dir, video_input, source_type, is_styled, style_image, image_name_true],
-            outputs=[output_video, output_video_1],  # ⏳ Videos are shown later
+            outputs=[output_video, output_video_1, error_box],  # ⏳ Videos are shown later
             queue=True
         )
 