avans06 commited on
Commit
645bf54
·
1 Parent(s): b4e0c45

Download output is now in zip file format.

Browse files
Files changed (1) hide show
  1. app.py +35 -14
app.py CHANGED
@@ -10,6 +10,7 @@ import math
10
  import time
11
  import ast
12
  import argparse
 
13
  from collections import defaultdict
14
  from facexlib.utils.misc import download_from_url
15
  from basicsr.utils.realesrganer import RealESRGANer
@@ -809,6 +810,17 @@ class Upscale:
809
  self.modelInUse = ""
810
 
811
  files = []
 
 
 
 
 
 
 
 
 
 
 
812
  is_auto_split_upscale = True
813
  # Dictionary to track counters for each filename
814
  name_counters = defaultdict(int)
@@ -820,10 +832,10 @@ class Upscale:
820
  # Increment the counter for the current name
821
  name_counters[img_name] += 1
822
  if name_counters[img_name] > 1:
823
- basename = f"{basename}_{name_counters[img_name]:02d}"
824
-
825
  img_cv2 = cv2.imdecode(np.fromfile(img_path, np.uint8), cv2.IMREAD_UNCHANGED) # numpy.ndarray
826
-
827
  img_mode = "RGBA" if len(img_cv2.shape) == 3 and img_cv2.shape[2] == 4 else None
828
  if len(img_cv2.shape) == 2: # for gray inputs
829
  img_cv2 = cv2.cvtColor(img_cv2, cv2.COLOR_GRAY2BGR)
@@ -835,30 +847,33 @@ class Upscale:
835
  current_progress += progressRatio/progressTotal;
836
  progress(current_progress, desc=f"image{gallery_idx:02d}, Background upscale Section")
837
  timer.checkpoint(f"image{gallery_idx:02d}, Background upscale Section")
838
-
839
  if self.face_enhancer:
840
  cropped_faces, restored_aligned, bg_upsample_img = self.face_enhancer.enhance(img_cv2, has_aligned=False, only_center_face=face_detection_only_center, paste_back=True, bg_upsample_img=bg_upsample_img, eye_dist_threshold=face_detection_threshold)
841
  # save faces
842
  if cropped_faces and restored_aligned:
843
  for idx, (cropped_face, restored_face) in enumerate(zip(cropped_faces, restored_aligned)):
844
  # save cropped face
845
- save_crop_path = f"output/{basename}{idx:02d}_cropped_faces{self.modelInUse}.png"
846
  self.imwriteUTF8(save_crop_path, cropped_face)
 
847
  # save restored face
848
- save_restore_path = f"output/{basename}{idx:02d}_restored_faces{self.modelInUse}.png"
849
  self.imwriteUTF8(save_restore_path, restored_face)
 
850
  # save comparison image
851
- save_cmp_path = f"output/{basename}{idx:02d}_cmp{self.modelInUse}.png"
852
  cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
853
  self.imwriteUTF8(save_cmp_path, cmp_img)
854
-
 
855
  files.append(save_crop_path)
856
  files.append(save_restore_path)
857
  files.append(save_cmp_path)
858
  current_progress += progressRatio/progressTotal;
859
  progress(current_progress, desc=f"image{gallery_idx:02d}, Face enhancer Section")
860
  timer.checkpoint(f"image{gallery_idx:02d}, Face enhancer Section")
861
-
862
  restored_img = bg_upsample_img
863
  timer.report()
864
 
@@ -866,15 +881,21 @@ class Upscale:
866
  extension = ".png" if img_mode == "RGBA" else ".jpg" # RGBA images should be saved in png format
867
  save_path = f"output/{basename}{self.modelInUse}{extension}"
868
  self.imwriteUTF8(save_path, restored_img)
 
869
 
870
  restored_img = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
871
  files.append(save_path)
872
  except RuntimeError as error:
873
  print(traceback.format_exc())
874
  print('Error', error)
875
-
876
  progress(1, desc=f"Execution completed")
877
  timer.report_all() # Print all recorded times
 
 
 
 
 
878
  except Exception as error:
879
  print(traceback.format_exc())
880
  print("global exception: ", error)
@@ -886,8 +907,8 @@ class Upscale:
886
  # Free GPU memory and clean up resources
887
  torch.cuda.empty_cache()
888
  gc.collect()
889
-
890
- return files, files
891
 
892
 
893
  def find_max_numbers(self, state_dict, findkeys):
@@ -898,7 +919,7 @@ class Upscale:
898
 
899
  for key in state_dict:
900
  for findkey, pattern in patterns.items():
901
- if match := pattern.match(key):
902
  num = int(match.group(1))
903
  max_values[findkey] = max(num, max_values[findkey] if max_values[findkey] is not None else num)
904
 
@@ -1186,7 +1207,7 @@ def main():
1186
  ], variant="secondary", size="lg",)
1187
  with gr.Column(variant="panel"):
1188
  gallerys = gr.Gallery(type="filepath", label="Output (The whole image)", format="png")
1189
- outputs = gr.File(label="Download the output image")
1190
  with gr.Row(variant="panel"):
1191
  # Generate output array
1192
  output_arr = []
 
10
  import time
11
  import ast
12
  import argparse
13
+ import zipfile
14
  from collections import defaultdict
15
  from facexlib.utils.misc import download_from_url
16
  from basicsr.utils.realesrganer import RealESRGANer
 
810
  self.modelInUse = ""
811
 
812
  files = []
813
+ # Create zip files for each output type
814
+ unique_id = str(int(time.time())) # Use timestamp for uniqueness
815
+ zip_cropf_path = f"output/{unique_id}_cropped_faces{self.modelInUse}.zip"
816
+ zipf_cropf = zipfile.ZipFile(zip_cropf_path, 'w', zipfile.ZIP_DEFLATED)
817
+ zip_restoref_path = f"output/{unique_id}_restored_faces{self.modelInUse}.zip"
818
+ zipf_restoref = zipfile.ZipFile(zip_restoref_path, 'w', zipfile.ZIP_DEFLATED)
819
+ zip_cmp_path = f"output/{unique_id}_cmp{self.modelInUse}.zip"
820
+ zipf_cmp = zipfile.ZipFile(zip_cmp_path, 'w', zipfile.ZIP_DEFLATED)
821
+ zip_restore_path = f"output/{unique_id}_restored_images{self.modelInUse}.zip"
822
+ zipf_restore = zipfile.ZipFile(zip_restore_path, 'w', zipfile.ZIP_DEFLATED)
823
+
824
  is_auto_split_upscale = True
825
  # Dictionary to track counters for each filename
826
  name_counters = defaultdict(int)
 
832
  # Increment the counter for the current name
833
  name_counters[img_name] += 1
834
  if name_counters[img_name] > 1:
835
+ basename = f"{basename}_{name_counters[img_name] - 1:02d}"
836
+
837
  img_cv2 = cv2.imdecode(np.fromfile(img_path, np.uint8), cv2.IMREAD_UNCHANGED) # numpy.ndarray
838
+
839
  img_mode = "RGBA" if len(img_cv2.shape) == 3 and img_cv2.shape[2] == 4 else None
840
  if len(img_cv2.shape) == 2: # for gray inputs
841
  img_cv2 = cv2.cvtColor(img_cv2, cv2.COLOR_GRAY2BGR)
 
847
  current_progress += progressRatio/progressTotal;
848
  progress(current_progress, desc=f"image{gallery_idx:02d}, Background upscale Section")
849
  timer.checkpoint(f"image{gallery_idx:02d}, Background upscale Section")
850
+
851
  if self.face_enhancer:
852
  cropped_faces, restored_aligned, bg_upsample_img = self.face_enhancer.enhance(img_cv2, has_aligned=False, only_center_face=face_detection_only_center, paste_back=True, bg_upsample_img=bg_upsample_img, eye_dist_threshold=face_detection_threshold)
853
  # save faces
854
  if cropped_faces and restored_aligned:
855
  for idx, (cropped_face, restored_face) in enumerate(zip(cropped_faces, restored_aligned)):
856
  # save cropped face
857
+ save_crop_path = f"output/{basename}_{idx:02d}_cropped_faces{self.modelInUse}.png"
858
  self.imwriteUTF8(save_crop_path, cropped_face)
859
+ zipf_cropf.write(save_crop_path, arcname=os.path.basename(save_crop_path))
860
  # save restored face
861
+ save_restore_path = f"output/{basename}_{idx:02d}_restored_faces{self.modelInUse}.png"
862
  self.imwriteUTF8(save_restore_path, restored_face)
863
+ zipf_restoref.write(save_restore_path, arcname=os.path.basename(save_restore_path))
864
  # save comparison image
865
+ save_cmp_path = f"output/{basename}_{idx:02d}_cmp{self.modelInUse}.png"
866
  cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
867
  self.imwriteUTF8(save_cmp_path, cmp_img)
868
+ zipf_cmp.write(save_cmp_path, arcname=os.path.basename(save_cmp_path))
869
+
870
  files.append(save_crop_path)
871
  files.append(save_restore_path)
872
  files.append(save_cmp_path)
873
  current_progress += progressRatio/progressTotal;
874
  progress(current_progress, desc=f"image{gallery_idx:02d}, Face enhancer Section")
875
  timer.checkpoint(f"image{gallery_idx:02d}, Face enhancer Section")
876
+
877
  restored_img = bg_upsample_img
878
  timer.report()
879
 
 
881
  extension = ".png" if img_mode == "RGBA" else ".jpg" # RGBA images should be saved in png format
882
  save_path = f"output/{basename}{self.modelInUse}{extension}"
883
  self.imwriteUTF8(save_path, restored_img)
884
+ zipf_restore.write(save_path, arcname=os.path.basename(save_path))
885
 
886
  restored_img = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
887
  files.append(save_path)
888
  except RuntimeError as error:
889
  print(traceback.format_exc())
890
  print('Error', error)
891
+
892
  progress(1, desc=f"Execution completed")
893
  timer.report_all() # Print all recorded times
894
+ # Close zip files
895
+ zipf_cropf.close()
896
+ zipf_restoref.close()
897
+ zipf_cmp.close()
898
+ zipf_restore.close()
899
  except Exception as error:
900
  print(traceback.format_exc())
901
  print("global exception: ", error)
 
907
  # Free GPU memory and clean up resources
908
  torch.cuda.empty_cache()
909
  gc.collect()
910
+
911
+ return files, [zip_cropf_path, zip_restoref_path, zip_cmp_path, zip_restore_path]
912
 
913
 
914
  def find_max_numbers(self, state_dict, findkeys):
 
919
 
920
  for key in state_dict:
921
  for findkey, pattern in patterns.items():
922
+ if match := pattern.match(key):
923
  num = int(match.group(1))
924
  max_values[findkey] = max(num, max_values[findkey] if max_values[findkey] is not None else num)
925
 
 
1207
  ], variant="secondary", size="lg",)
1208
  with gr.Column(variant="panel"):
1209
  gallerys = gr.Gallery(type="filepath", label="Output (The whole image)", format="png")
1210
+ outputs = gr.File(label="Download the output ZIP file")
1211
  with gr.Row(variant="panel"):
1212
  # Generate output array
1213
  output_arr = []