刘虹雨 commited on
Commit
7492542
·
1 Parent(s): 7fa0d80

update code

Browse files
Files changed (1) hide show
  1. app.py +12 -2
app.py CHANGED
@@ -4,6 +4,7 @@ import sys
4
  import warnings
5
  import logging
6
  import spaces
 
7
 
8
  # Configure logging settings
9
  logging.basicConfig(
@@ -635,7 +636,9 @@ def process_image(input_image_dir, source_type, is_style, save_dir):
635
  """ 🎯 处理 input_image,根据是否是示例图片执行不同逻辑 """
636
  process_img_input_dir = os.path.join(save_dir, 'input_image')
637
  process_img_save_dir = os.path.join(save_dir, 'processed_img')
638
- image_name_true = os.path.basename(input_image_dir)
 
 
639
  os.makedirs(process_img_save_dir, exist_ok=True)
640
  os.makedirs(process_img_input_dir, exist_ok=True)
641
  if source_type == "example":
@@ -645,8 +648,15 @@ def process_image(input_image_dir, source_type, is_style, save_dir):
645
  # input_process_model.inference(input_image, process_img_save_dir)
646
  shutil.copy(input_image_dir, process_img_input_dir)
647
  input_process_model.inference(process_img_input_dir, process_img_save_dir, is_img=True, is_video=False)
648
- imge_dir = os.path.join(save_dir, 'processed_img/dataset/images512x512/input_image', image_name_true)
 
 
 
 
 
 
649
  image = Image.open(imge_dir)
 
650
  return image, source_type, image_name_true # 这里替换成 处理用户上传图片的逻辑
651
 
652
  @spaces.GPU(duration=100)
 
4
  import warnings
5
  import logging
6
  import spaces
7
+ import difflib
8
 
9
  # Configure logging settings
10
  logging.basicConfig(
 
636
  """ 🎯 处理 input_image,根据是否是示例图片执行不同逻辑 """
637
  process_img_input_dir = os.path.join(save_dir, 'input_image')
638
  process_img_save_dir = os.path.join(save_dir, 'processed_img')
639
+ base_name = os.path.basename(input_image_dir) # abc123.jpg
640
+ name_without_ext = os.path.splitext(base_name)[0] # abc123
641
+ image_name_true = name_without_ext + ".png"
642
  os.makedirs(process_img_save_dir, exist_ok=True)
643
  os.makedirs(process_img_input_dir, exist_ok=True)
644
  if source_type == "example":
 
648
  # input_process_model.inference(input_image, process_img_save_dir)
649
  shutil.copy(input_image_dir, process_img_input_dir)
650
  input_process_model.inference(process_img_input_dir, process_img_save_dir, is_img=True, is_video=False)
651
+
652
+ files = os.listdir(os.path.join(process_img_save_dir, 'dataset/images512x512/input_image'))
653
+ image_files = [f for f in files if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.webp'))]
654
+ # 使用 difflib 查找相似文件名
655
+ matches = difflib.get_close_matches(image_name_true, image_files, n=1, cutoff=0.1)
656
+ closest_match = matches[0]
657
+ imge_dir = os.path.join(process_img_save_dir, 'processed_img/dataset/images512x512/input_image', closest_match)
658
  image = Image.open(imge_dir)
659
+ image_name_true = closest_match
660
  return image, source_type, image_name_true # 这里替换成 处理用户上传图片的逻辑
661
 
662
  @spaces.GPU(duration=100)