Deadmon commited on
Commit
1c7c2cf
·
verified ·
1 Parent(s): 50c16d7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -0
app.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from PIL import Image
3
+ import time
4
+ from gradio_client import Client, handle_file
5
+ import zipfile
6
+ import gradio as gr
7
+
8
+ # Configuration
9
+ INPUT_DIR = "input_images" # Folder with original images
10
+ OUTPUT_DIR = "output_images" # Base output folder
11
+ DATASET_DIR = os.path.join(OUTPUT_DIR, "dataset") # Subfolder for organized dataset
12
+ ZIP_FILE = os.path.join(OUTPUT_DIR, "dataset.zip") # Path for the output ZIP file
13
+ TARGET_SIZE = (512, 512) # Target size for Stable Diffusion (SD 1.5)
14
+ HUGGINGFACE_SPACE_URL = "bdsqlsz/Florence-2-SD3-Captioner"
15
+
16
+ # Ensure output and dataset directories exist
17
+ os.makedirs(DATASET_DIR, exist_ok=True)
18
+
19
+ def resize_and_crop_image(input_path, output_path, target_size):
20
+ """Resize and crop an image to the target size while preserving aspect ratio."""
21
+ try:
22
+ img = Image.open(input_path).convert("RGB")
23
+ width, height = img.size
24
+ target_width, target_height = target_size
25
+ img_ratio = width / height
26
+ target_ratio = target_width / target_height
27
+
28
+ if img_ratio > target_ratio:
29
+ new_height = target_height
30
+ new_width = int(new_height * img_ratio)
31
+ else:
32
+ new_width = target_width
33
+ new_height = int(new_width / img_ratio)
34
+
35
+ img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
36
+ left = (new_width - target_width) // 2
37
+ top = (new_height - target_height) // 2
38
+ right = left + target_width
39
+ bottom = top + target_height
40
+ img = img.crop((left, top, right, bottom))
41
+ img.save(output_path, "JPEG", quality=95)
42
+ return True
43
+ except Exception as e:
44
+ print(f"Error processing {input_path}: {e}")
45
+ return False
46
+
47
+ def get_caption_from_florence(image_path):
48
+ """Call the Florence-2-SD3-Captioner /process_image endpoint via Gradio API."""
49
+ try:
50
+ client = Client(HUGGINGFACE_SPACE_URL)
51
+ result = client.predict(
52
+ image=handle_file(image_path),
53
+ api_name="/process_image"
54
+ )
55
+ return result if isinstance(result, str) else "No caption returned"
56
+ except Exception as e:
57
+ print(f"Error captioning {image_path}: {e}")
58
+ return "Captioning failed"
59
+
60
+ def create_zip_file():
61
+ """Create a ZIP file of the dataset folder."""
62
+ with zipfile.ZipFile(ZIP_FILE, 'w', zipfile.ZIP_DEFLATED) as zipf:
63
+ for root, _, files in os.walk(DATASET_DIR):
64
+ for file in files:
65
+ file_path = os.path.join(root, file)
66
+ arcname = os.path.relpath(file_path, OUTPUT_DIR)
67
+ zipf.write(file_path, arcname)
68
+ return ZIP_FILE
69
+
70
+ def process_images():
71
+ """Process all images and return status."""
72
+ if not os.path.exists(INPUT_DIR):
73
+ return f"Input directory '{INPUT_DIR}' not found."
74
+
75
+ image_files = [f for f in os.listdir(INPUT_DIR) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
76
+ if not image_files:
77
+ return f"No images found in '{INPUT_DIR}'."
78
+
79
+ for idx, filename in enumerate(image_files, 1):
80
+ input_path = os.path.join(INPUT_DIR, filename)
81
+ base_name = f"img{idx}"
82
+ output_image_path = os.path.join(DATASET_DIR, f"{base_name}.jpg")
83
+ caption_file_path = os.path.join(DATASET_DIR, f"{base_name}.txt")
84
+
85
+ print(f"Processing {idx}/{len(image_files)}: {filename}")
86
+
87
+ if resize_and_crop_image(input_path, output_image_path, TARGET_SIZE):
88
+ caption = get_caption_from_florence(output_image_path)
89
+ print(f"Caption: {caption}")
90
+ with open(caption_file_path, "w", encoding="utf-8") as f:
91
+ f.write(caption)
92
+ else:
93
+ print(f"Skipping captioning for {filename} due to processing error.")
94
+
95
+ time.sleep(1) # Avoid overwhelming the API
96
+
97
+ # Create ZIP file after processing
98
+ zip_path = create_zip_file()
99
+ return f"Processing complete! ZIP file created at {zip_path}"
100
+
101
+ def launch_interface():
102
+ """Launch Gradio interface with a download button after processing."""
103
+ status = process_images()
104
+
105
+ with gr.Blocks(title="Image Processing and Download") as demo:
106
+ gr.Markdown("### Image Processing Status")
107
+ status_text = gr.Textbox(value=status, label="Status", interactive=False)
108
+
109
+ if "Processing complete" in status:
110
+ gr.Markdown("### Download Your Dataset")
111
+ download_button = gr.File(label="Download ZIP", value=ZIP_FILE)
112
+ else:
113
+ gr.Markdown("No ZIP file available due to processing errors.")
114
+
115
+ demo.launch()
116
+
117
+ if __name__ == "__main__":
118
+ print("Starting image processing and captioning...")
119
+ launch_interface()