Spaces:
Running
Running
Rishi Desai
committed on
Commit
·
dc6215b
1
Parent(s):
ca96bd8
init dump for demo
Browse files
- caption.py +110 -44
- demo.py +378 -0
- main.py +54 -39
- requirements.txt +5 -0
caption.py
CHANGED
@@ -10,7 +10,15 @@ load_dotenv()
|
|
10 |
def get_prompt():
|
11 |
return """Automated Image Captioning (for LoRA Training)
|
12 |
|
13 |
-
Role: You are an expert AI captioning system generating precise, structured descriptions for
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
General Guidelines:
|
16 |
1. Prioritize Consistency – Maintain uniform descriptions across all images in a dataset. Avoid introducing variation in features that should remain constant (e.g., fixed traits like eye color, hair color, or markings that are inherently part of the concept and handled during model training).
|
@@ -28,8 +36,8 @@ Avoid Describing These Unless Variable Across Dataset or Uncertain from Concept:
|
|
28 |
- Tattoos or markings if core to the concept
|
29 |
- Known accessories that always appear (unless outfit-specific)
|
30 |
|
31 |
-
|
32 |
-
tr1gger [Style], [Notable Visual Features], [Clothing], [Pose], [Expression], [Lighting], [Camera Angle]
|
33 |
|
34 |
Captioning Principles:
|
35 |
- Emphasize visual variation and context-specific details (outfit, pose, lighting, expression, camera angle).
|
@@ -41,14 +49,20 @@ Captioning Principles:
|
|
41 |
- Avoid mentioning real or fictional identities.
|
42 |
- Always prefix with the trigger word "tr1gger."
|
43 |
|
44 |
-
Examples:
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
|
|
|
|
49 |
"""
|
50 |
|
51 |
-
|
|
|
|
|
|
|
|
|
52 |
# Convert PIL images to base64 encoded strings
|
53 |
image_strings = []
|
54 |
for image in images:
|
@@ -66,30 +80,103 @@ def caption_images(images):
|
|
66 |
client = Together(api_key=api_key)
|
67 |
captions = []
|
68 |
|
69 |
-
#
|
70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
messages = [
|
72 |
{"role": "system", "content": get_prompt()},
|
73 |
-
{
|
74 |
-
"role": "user",
|
75 |
-
"content": [
|
76 |
-
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_str}"}},
|
77 |
-
{"type": "text", "text": "Describe this image."}
|
78 |
-
]
|
79 |
-
}
|
80 |
]
|
81 |
|
82 |
-
# Request caption for the image using Llama 4 Maverick
|
83 |
response = client.chat.completions.create(
|
84 |
model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
|
85 |
messages=messages
|
86 |
)
|
87 |
|
88 |
-
#
|
89 |
full_response = response.choices[0].message.content.strip()
|
90 |
-
|
91 |
-
|
92 |
-
captions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
|
94 |
return captions
|
95 |
|
@@ -100,24 +187,3 @@ def extract_captions(file_path):
|
|
100 |
if line.startswith("tr1gger"):
|
101 |
captions.append(line.strip())
|
102 |
return captions
|
103 |
-
|
104 |
-
# Example usage
|
105 |
-
if __name__ == "__main__":
|
106 |
-
if not os.environ.get("TOGETHER_API_KEY"):
|
107 |
-
print("Please update the environment with your Together AI API key.")
|
108 |
-
exit(1)
|
109 |
-
|
110 |
-
# Load images
|
111 |
-
image_paths = ['input/daenyrs_hd.jpg', 'input/girl_body.png']
|
112 |
-
images = [Image.open(path).convert("RGB") for path in image_paths]
|
113 |
-
|
114 |
-
# Generate captions
|
115 |
-
captions = caption_images(images)
|
116 |
-
for i, caption in enumerate(captions):
|
117 |
-
print(f"Generated Caption for Image {i+1}: {caption}")
|
118 |
-
|
119 |
-
# Extract captions from a file
|
120 |
-
file_path = 'post_girl/multiview_0.txt'
|
121 |
-
extracted_captions = extract_captions(file_path)
|
122 |
-
for caption in extracted_captions:
|
123 |
-
print(caption)
|
|
|
10 |
def get_prompt():
|
11 |
return """Automated Image Captioning (for LoRA Training)
|
12 |
|
13 |
+
Role: You are an expert AI captioning system generating precise, structured descriptions for character images optimized for LoRA model training in Stable Diffusion and Flux.1-dev.
|
14 |
+
|
15 |
+
IMPORTANT: You MUST follow these rules EXACTLY:
|
16 |
+
1. EVERY caption MUST start with the word "tr1gger" (exactly like this, no variations)
|
17 |
+
2. You MUST use the exact format: tr1gger [Style], [Notable Visual Features], [Clothing], [Pose], [Expression], [Lighting], [Camera Angle]
|
18 |
+
3. DO NOT include any additional text, explanations, or formatting
|
19 |
+
4. DO NOT use bullet points, lists, or any other formatting
|
20 |
+
5. DO NOT include any text before or after the caption
|
21 |
+
6. If you don't follow this format exactly, the caption will be rejected
|
22 |
|
23 |
General Guidelines:
|
24 |
1. Prioritize Consistency – Maintain uniform descriptions across all images in a dataset. Avoid introducing variation in features that should remain constant (e.g., fixed traits like eye color, hair color, or markings that are inherently part of the concept and handled during model training).
|
|
|
36 |
- Tattoos or markings if core to the concept
|
37 |
- Known accessories that always appear (unless outfit-specific)
|
38 |
|
39 |
+
Caption Format (MUST FOLLOW EXACTLY):
|
40 |
+
tr1gger [Style], [Notable Visual Features], [Clothing], [Pose], [Expression/Mood], [Background/Setting], [Lighting], [Camera Angle]
|
41 |
|
42 |
Captioning Principles:
|
43 |
- Emphasize visual variation and context-specific details (outfit, pose, lighting, expression, camera angle).
|
|
|
49 |
- Avoid mentioning real or fictional identities.
|
50 |
- Always prefix with the trigger word "tr1gger."
|
51 |
|
52 |
+
Examples (MUST FOLLOW THIS EXACT FORMAT):
|
53 |
+
tr1gger photorealistic, combat gear, tactical vest and gloves, standing in profile, neutral, empty room, overcast lighting, side profile
|
54 |
+
tr1gger 3D-rendered, digital patterns, hooded cloak, seated cross-legged, calm, meditation chamber, low ambient lighting, front view
|
55 |
+
tr1gger anime-style, school uniform with blue necktie, standing with arms behind back, gentle smile, classroom, soft daylight, three-quarter view
|
56 |
+
tr1gger photorealistic, long trench coat and combat boots, walking, determined, rain-soaked street, dramatic shadows, low-angle view
|
57 |
+
|
58 |
+
REMEMBER: Your response must be a single line starting with "tr1gger" and following the exact format above. No additional text, formatting, or explanations are allowed.
|
59 |
"""
|
60 |
|
61 |
+
class CaptioningError(Exception):
|
62 |
+
"""Exception raised for errors in the captioning process."""
|
63 |
+
pass
|
64 |
+
|
65 |
+
def caption_images(images, category=None, batch_mode=False):
|
66 |
# Convert PIL images to base64 encoded strings
|
67 |
image_strings = []
|
68 |
for image in images:
|
|
|
80 |
client = Together(api_key=api_key)
|
81 |
captions = []
|
82 |
|
83 |
+
# If batch_mode is True, process all images in a single API call
|
84 |
+
if batch_mode and category:
|
85 |
+
# Create a content array with all images
|
86 |
+
content = [{"type": "text", "text": f"Here is the batch of images for {category}. Please caption each image on a separate line, starting each caption with 'tr1gger'."}]
|
87 |
+
|
88 |
+
# Add all images to the content array
|
89 |
+
for i, img_str in enumerate(image_strings):
|
90 |
+
content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_str}"}})
|
91 |
+
content.append({"type": "text", "text": f"Image {i+1}"})
|
92 |
+
|
93 |
+
# Send the batch request
|
94 |
messages = [
|
95 |
{"role": "system", "content": get_prompt()},
|
96 |
+
{"role": "user", "content": content}
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
]
|
98 |
|
|
|
99 |
response = client.chat.completions.create(
|
100 |
model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
|
101 |
messages=messages
|
102 |
)
|
103 |
|
104 |
+
# Parse the response to extract captions for each image
|
105 |
full_response = response.choices[0].message.content.strip()
|
106 |
+
lines = full_response.splitlines()
|
107 |
+
|
108 |
+
# Extract captions from the response
|
109 |
+
image_count = len(image_strings)
|
110 |
+
captions = [""] * image_count # Initialize with empty strings
|
111 |
+
|
112 |
+
# Extract lines that start with or contain "tr1gger"
|
113 |
+
tr1gger_lines = [line for line in lines if "tr1gger" in line]
|
114 |
+
|
115 |
+
# Assign captions to images
|
116 |
+
for i in range(image_count):
|
117 |
+
if i < len(tr1gger_lines):
|
118 |
+
caption = tr1gger_lines[i]
|
119 |
+
# If caption contains but doesn't start with tr1gger, extract just that part
|
120 |
+
if not caption.startswith("tr1gger") and "tr1gger" in caption:
|
121 |
+
caption = caption[caption.index("tr1gger"):]
|
122 |
+
captions[i] = caption
|
123 |
+
|
124 |
+
# Check if all captions are empty or don't contain the trigger word
|
125 |
+
valid_captions = [c for c in captions if c and "tr1gger" in c]
|
126 |
+
if not valid_captions:
|
127 |
+
error_msg = "Failed to parse any valid captions from batch response. Response contained no lines with 'tr1gger'"
|
128 |
+
error_msg += f"\n\nActual response:\n{full_response}"
|
129 |
+
raise CaptioningError(error_msg)
|
130 |
+
|
131 |
+
# Check if some captions are missing
|
132 |
+
if len(valid_captions) < len(images):
|
133 |
+
missing_count = len(images) - len(valid_captions)
|
134 |
+
invalid_captions = [(i, c) for i, c in enumerate(captions) if not c or "tr1gger" not in c]
|
135 |
+
error_msg = f"Failed to parse captions for {missing_count} of {len(images)} images in batch mode"
|
136 |
+
error_msg += "\n\nMalformed captions:"
|
137 |
+
for idx, caption in invalid_captions:
|
138 |
+
error_msg += f"\nImage {idx+1}: '{caption}'"
|
139 |
+
raise CaptioningError(error_msg)
|
140 |
+
else:
|
141 |
+
# Original method: process each image separately
|
142 |
+
for img_str in image_strings:
|
143 |
+
messages = [
|
144 |
+
{"role": "system", "content": get_prompt()},
|
145 |
+
{
|
146 |
+
"role": "user",
|
147 |
+
"content": [
|
148 |
+
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_str}"}},
|
149 |
+
{"type": "text", "text": "Describe this image."}
|
150 |
+
]
|
151 |
+
}
|
152 |
+
]
|
153 |
+
|
154 |
+
# Request caption for the image using Llama 4 Maverick
|
155 |
+
response = client.chat.completions.create(
|
156 |
+
model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
|
157 |
+
messages=messages
|
158 |
+
)
|
159 |
+
|
160 |
+
# Extract caption from the response
|
161 |
+
full_response = response.choices[0].message.content.strip()
|
162 |
+
# Post-process to extract only the caption part
|
163 |
+
caption = ""
|
164 |
+
for line in full_response.splitlines():
|
165 |
+
if "tr1gger" in line:
|
166 |
+
# If caption has a numbered prefix, extract just the caption part
|
167 |
+
if not line.startswith("tr1gger"):
|
168 |
+
caption = line[line.index("tr1gger"):]
|
169 |
+
else:
|
170 |
+
caption = line
|
171 |
+
break
|
172 |
+
|
173 |
+
# Check if caption is valid
|
174 |
+
if not caption:
|
175 |
+
error_msg = "Failed to extract a valid caption (containing 'tr1gger') from the response"
|
176 |
+
error_msg += f"\n\nActual response:\n{full_response}"
|
177 |
+
raise CaptioningError(error_msg)
|
178 |
+
|
179 |
+
captions.append(caption)
|
180 |
|
181 |
return captions
|
182 |
|
|
|
187 |
if line.startswith("tr1gger"):
|
188 |
captions.append(line.strip())
|
189 |
return captions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
demo.py
ADDED
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import os
|
3 |
+
import zipfile
|
4 |
+
from io import BytesIO
|
5 |
+
import PIL.Image
|
6 |
+
import time
|
7 |
+
import tempfile
|
8 |
+
from main import process_images, collect_images_by_category, write_captions # Import the CLI functions
|
9 |
+
from dotenv import load_dotenv
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
# Load environment variables
|
13 |
+
load_dotenv()
|
14 |
+
|
15 |
+
# Maximum number of images
|
16 |
+
MAX_IMAGES = 30
|
17 |
+
|
18 |
+
def create_download_file(image_paths, captions):
|
19 |
+
"""Create a zip file with images and their captions"""
|
20 |
+
zip_io = BytesIO()
|
21 |
+
with zipfile.ZipFile(zip_io, 'w') as zip_file:
|
22 |
+
for i, (image_path, caption) in enumerate(zip(image_paths, captions)):
|
23 |
+
# Get original filename without extension
|
24 |
+
base_name = os.path.splitext(os.path.basename(image_path))[0]
|
25 |
+
img_name = f"{base_name}.png"
|
26 |
+
caption_name = f"{base_name}.txt"
|
27 |
+
|
28 |
+
# Add image to zip
|
29 |
+
with open(image_path, 'rb') as img_file:
|
30 |
+
zip_file.writestr(img_name, img_file.read())
|
31 |
+
|
32 |
+
# Add caption to zip
|
33 |
+
zip_file.writestr(caption_name, caption)
|
34 |
+
|
35 |
+
return zip_io.getvalue()
|
36 |
+
|
37 |
+
def process_uploaded_images(image_paths, batch_by_category=False):
|
38 |
+
"""Process uploaded images using the same code path as CLI"""
|
39 |
+
try:
|
40 |
+
print(f"Processing {len(image_paths)} images, batch_by_category={batch_by_category}")
|
41 |
+
# Create a temporary directory to store the images
|
42 |
+
with tempfile.TemporaryDirectory() as temp_dir:
|
43 |
+
# Copy images to temp directory and maintain original order
|
44 |
+
temp_image_paths = []
|
45 |
+
original_to_temp = {} # Map original paths to temp paths
|
46 |
+
for path in image_paths:
|
47 |
+
filename = os.path.basename(path)
|
48 |
+
temp_path = os.path.join(temp_dir, filename)
|
49 |
+
with open(path, 'rb') as src, open(temp_path, 'wb') as dst:
|
50 |
+
dst.write(src.read())
|
51 |
+
temp_image_paths.append(temp_path)
|
52 |
+
original_to_temp[path] = temp_path
|
53 |
+
|
54 |
+
print(f"Created {len(temp_image_paths)} temporary files")
|
55 |
+
|
56 |
+
# Convert temp_dir to Path object for collect_images_by_category
|
57 |
+
temp_dir_path = Path(temp_dir)
|
58 |
+
|
59 |
+
# Process images using the CLI code path
|
60 |
+
images_by_category, image_paths_by_category = collect_images_by_category(temp_dir_path)
|
61 |
+
print(f"Collected images into {len(images_by_category)} categories")
|
62 |
+
|
63 |
+
# Get all images and paths in the correct order
|
64 |
+
all_images = []
|
65 |
+
all_image_paths = []
|
66 |
+
for path in image_paths: # Use original order
|
67 |
+
temp_path = original_to_temp[path]
|
68 |
+
found = False
|
69 |
+
for category, paths in image_paths_by_category.items():
|
70 |
+
if temp_path in [str(p) for p in paths]: # Convert Path objects to strings for comparison
|
71 |
+
idx = [str(p) for p in paths].index(temp_path)
|
72 |
+
all_images.append(images_by_category[category][idx])
|
73 |
+
all_image_paths.append(path) # Use original path
|
74 |
+
found = True
|
75 |
+
break
|
76 |
+
if not found:
|
77 |
+
print(f"Warning: Could not find image {path} in categorized data")
|
78 |
+
|
79 |
+
print(f"Collected {len(all_images)} images in correct order")
|
80 |
+
|
81 |
+
# Process based on batch setting
|
82 |
+
if batch_by_category:
|
83 |
+
# Process each category separately
|
84 |
+
captions = [""] * len(image_paths) # Initialize with empty strings
|
85 |
+
for category, images in images_by_category.items():
|
86 |
+
category_paths = image_paths_by_category[category]
|
87 |
+
print(f"Processing category '{category}' with {len(images)} images")
|
88 |
+
# Use the same code path as CLI
|
89 |
+
from caption import caption_images
|
90 |
+
category_captions = caption_images(images, category=category, batch_mode=True)
|
91 |
+
print(f"Generated {len(category_captions)} captions for category '{category}'")
|
92 |
+
print("Category captions:", category_captions) # Debug print category captions
|
93 |
+
|
94 |
+
# Map captions back to original paths
|
95 |
+
for temp_path, caption in zip(category_paths, category_captions):
|
96 |
+
temp_path_str = str(temp_path)
|
97 |
+
for orig_path, orig_temp in original_to_temp.items():
|
98 |
+
if orig_temp == temp_path_str:
|
99 |
+
idx = image_paths.index(orig_path)
|
100 |
+
captions[idx] = caption
|
101 |
+
break
|
102 |
+
else:
|
103 |
+
# Process all images at once
|
104 |
+
from caption import caption_images
|
105 |
+
print(f"Processing all {len(all_images)} images at once")
|
106 |
+
all_captions = caption_images(all_images, batch_mode=False)
|
107 |
+
print(f"Generated {len(all_captions)} captions")
|
108 |
+
print("All captions:", all_captions) # Debug print all captions
|
109 |
+
captions = [""] * len(image_paths)
|
110 |
+
for path, caption in zip(all_image_paths, all_captions):
|
111 |
+
idx = image_paths.index(path)
|
112 |
+
captions[idx] = caption
|
113 |
+
|
114 |
+
print(f"Returning {len(captions)} captions")
|
115 |
+
print("Final captions:", captions) # Debug print final captions
|
116 |
+
return captions
|
117 |
+
|
118 |
+
except Exception as e:
|
119 |
+
print(f"Error in processing: {e}")
|
120 |
+
raise
|
121 |
+
|
122 |
+
# Main Gradio interface
|
123 |
+
with gr.Blocks() as demo:
|
124 |
+
gr.Markdown("# Image Autocaptioner")
|
125 |
+
|
126 |
+
# Store uploaded images
|
127 |
+
stored_image_paths = gr.State([])
|
128 |
+
batch_by_category = gr.State(True) # State to track if batch by category is enabled
|
129 |
+
|
130 |
+
# Upload component
|
131 |
+
with gr.Row():
|
132 |
+
with gr.Column(scale=2):
|
133 |
+
gr.Markdown("### Upload your images")
|
134 |
+
image_upload = gr.File(
|
135 |
+
file_count="multiple",
|
136 |
+
label="Drop your files here",
|
137 |
+
file_types=["image"],
|
138 |
+
type="filepath"
|
139 |
+
)
|
140 |
+
|
141 |
+
with gr.Column(scale=1):
|
142 |
+
autocaption_btn = gr.Button("Autocaption Images", variant="primary", interactive=False)
|
143 |
+
status_text = gr.Markdown("Upload images to begin", visible=True)
|
144 |
+
|
145 |
+
# Advanced settings dropdown
|
146 |
+
with gr.Accordion("Advanced", open=False):
|
147 |
+
batch_category_checkbox = gr.Checkbox(
|
148 |
+
label="Batch by category",
|
149 |
+
value=True,
|
150 |
+
info="Group similar images together when processing"
|
151 |
+
)
|
152 |
+
|
153 |
+
# Create a container for the captioning area (initially hidden)
|
154 |
+
with gr.Column(visible=False) as captioning_area:
|
155 |
+
gr.Markdown("### Your images and captions")
|
156 |
+
|
157 |
+
# Create individual image and caption rows
|
158 |
+
image_rows = []
|
159 |
+
image_components = []
|
160 |
+
caption_components = []
|
161 |
+
|
162 |
+
for i in range(MAX_IMAGES):
|
163 |
+
with gr.Row(visible=False) as img_row:
|
164 |
+
image_rows.append(img_row)
|
165 |
+
|
166 |
+
img = gr.Image(
|
167 |
+
label=f"Image {i+1}",
|
168 |
+
type="filepath",
|
169 |
+
show_label=False,
|
170 |
+
height=200,
|
171 |
+
width=200,
|
172 |
+
scale=1
|
173 |
+
)
|
174 |
+
image_components.append(img)
|
175 |
+
|
176 |
+
caption = gr.Textbox(
|
177 |
+
label=f"Caption {i+1}",
|
178 |
+
lines=3,
|
179 |
+
scale=2
|
180 |
+
)
|
181 |
+
caption_components.append(caption)
|
182 |
+
|
183 |
+
# Add download button
|
184 |
+
download_btn = gr.Button("Download Images with Captions", variant="secondary", interactive=False)
|
185 |
+
download_output = gr.File(label="Download Zip", visible=False)
|
186 |
+
|
187 |
+
def load_captioning(files):
|
188 |
+
"""Process uploaded images and show them in the UI"""
|
189 |
+
if not files:
|
190 |
+
return [], gr.update(visible=False), gr.update(interactive=False), gr.update(interactive=False), gr.update(visible=False), gr.update(value="Upload images to begin"), *[gr.update(visible=False) for _ in range(MAX_IMAGES)]
|
191 |
+
|
192 |
+
# Filter to only keep image files
|
193 |
+
image_paths = [f for f in files if f.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp'))]
|
194 |
+
|
195 |
+
if not image_paths or len(image_paths) < 1:
|
196 |
+
gr.Warning(f"Please upload at least one image")
|
197 |
+
return [], gr.update(visible=False), gr.update(interactive=False), gr.update(interactive=False), gr.update(visible=False), gr.update(value="No valid images found"), *[gr.update(visible=False) for _ in range(MAX_IMAGES)]
|
198 |
+
|
199 |
+
if len(image_paths) > MAX_IMAGES:
|
200 |
+
gr.Warning(f"Only the first {MAX_IMAGES} images will be processed")
|
201 |
+
image_paths = image_paths[:MAX_IMAGES]
|
202 |
+
|
203 |
+
# Update row visibility
|
204 |
+
row_updates = []
|
205 |
+
for i in range(MAX_IMAGES):
|
206 |
+
if i < len(image_paths):
|
207 |
+
row_updates.append(gr.update(visible=True))
|
208 |
+
else:
|
209 |
+
row_updates.append(gr.update(visible=False))
|
210 |
+
|
211 |
+
return (
|
212 |
+
image_paths, # stored_image_paths
|
213 |
+
gr.update(visible=True), # captioning_area
|
214 |
+
gr.update(interactive=True), # autocaption_btn
|
215 |
+
gr.update(interactive=True), # download_btn
|
216 |
+
gr.update(visible=False), # download_output
|
217 |
+
gr.update(value=f"{len(image_paths)} images ready for captioning"), # status_text
|
218 |
+
*row_updates # image_rows
|
219 |
+
)
|
220 |
+
|
221 |
+
def update_images(image_paths):
|
222 |
+
"""Update the image components with the uploaded images"""
|
223 |
+
print(f"Updating images with paths: {image_paths}")
|
224 |
+
updates = []
|
225 |
+
for i in range(MAX_IMAGES):
|
226 |
+
if i < len(image_paths):
|
227 |
+
updates.append(gr.update(value=image_paths[i]))
|
228 |
+
else:
|
229 |
+
updates.append(gr.update(value=None))
|
230 |
+
return updates
|
231 |
+
|
232 |
+
def update_caption_labels(image_paths):
|
233 |
+
"""Update caption labels to include the image filename"""
|
234 |
+
updates = []
|
235 |
+
for i in range(MAX_IMAGES):
|
236 |
+
if i < len(image_paths):
|
237 |
+
filename = os.path.basename(image_paths[i])
|
238 |
+
updates.append(gr.update(label=filename))
|
239 |
+
else:
|
240 |
+
updates.append(gr.update(label=""))
|
241 |
+
return updates
|
242 |
+
|
243 |
+
def run_captioning(image_paths, batch_category):
|
244 |
+
"""Generate captions for the images using the CLI code path"""
|
245 |
+
if not image_paths:
|
246 |
+
return [gr.update(value="") for _ in range(MAX_IMAGES)] + [gr.update(value="No images to process")]
|
247 |
+
|
248 |
+
try:
|
249 |
+
print(f"Starting captioning for {len(image_paths)} images")
|
250 |
+
captions = process_uploaded_images(image_paths, batch_category)
|
251 |
+
print(f"Generated {len(captions)} captions")
|
252 |
+
print("Sample captions:", captions[:2]) # Debug print first two captions
|
253 |
+
|
254 |
+
gr.Info("Captioning complete!")
|
255 |
+
status = gr.update(value="✅ Captioning complete")
|
256 |
+
except Exception as e:
|
257 |
+
print(f"Error in captioning: {str(e)}")
|
258 |
+
gr.Error(f"Captioning failed: {str(e)}")
|
259 |
+
captions = [f"Error: {str(e)}" for _ in image_paths]
|
260 |
+
status = gr.update(value=f"❌ Error: {str(e)}")
|
261 |
+
|
262 |
+
# Update caption textboxes
|
263 |
+
caption_updates = []
|
264 |
+
for i in range(MAX_IMAGES):
|
265 |
+
if i < len(captions):
|
266 |
+
caption_updates.append(gr.update(value=captions[i]))
|
267 |
+
else:
|
268 |
+
caption_updates.append(gr.update(value=""))
|
269 |
+
|
270 |
+
print(f"Returning {len(caption_updates)} caption updates")
|
271 |
+
return caption_updates + [status]
|
272 |
+
|
273 |
+
def update_batch_setting(value):
|
274 |
+
"""Update the batch by category setting"""
|
275 |
+
return value
|
276 |
+
|
277 |
+
def create_zip_from_ui(image_paths, *captions_list):
|
278 |
+
"""Create a zip file from the current images and captions in the UI"""
|
279 |
+
# Filter out empty captions for non-existent images
|
280 |
+
valid_captions = [cap for i, cap in enumerate(captions_list) if i < len(image_paths) and cap]
|
281 |
+
valid_image_paths = image_paths[:len(valid_captions)]
|
282 |
+
|
283 |
+
if not valid_image_paths:
|
284 |
+
gr.Warning("No images to download")
|
285 |
+
return None
|
286 |
+
|
287 |
+
# Create zip file
|
288 |
+
zip_data = create_download_file(valid_image_paths, valid_captions)
|
289 |
+
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
290 |
+
|
291 |
+
# Create a temporary file to store the zip
|
292 |
+
temp_dir = tempfile.gettempdir()
|
293 |
+
zip_filename = f"image_captions_{timestamp}.zip"
|
294 |
+
zip_path = os.path.join(temp_dir, zip_filename)
|
295 |
+
|
296 |
+
# Write the zip data to the temporary file
|
297 |
+
with open(zip_path, "wb") as f:
|
298 |
+
f.write(zip_data)
|
299 |
+
|
300 |
+
# Return the path to the temporary file
|
301 |
+
return zip_path
|
302 |
+
|
303 |
+
# Update the upload_outputs
|
304 |
+
upload_outputs = [
|
305 |
+
stored_image_paths,
|
306 |
+
captioning_area,
|
307 |
+
autocaption_btn,
|
308 |
+
download_btn,
|
309 |
+
download_output,
|
310 |
+
status_text,
|
311 |
+
*image_rows
|
312 |
+
]
|
313 |
+
|
314 |
+
# Update both paths and images in a single flow
|
315 |
+
def process_upload(files):
|
316 |
+
# First get paths and visibility updates
|
317 |
+
image_paths, captioning_update, autocaption_update, download_btn_update, download_output_update, status_update, *row_updates = load_captioning(files)
|
318 |
+
|
319 |
+
# Then get image updates
|
320 |
+
image_updates = update_images(image_paths)
|
321 |
+
|
322 |
+
# Update caption labels with filenames
|
323 |
+
caption_label_updates = update_caption_labels(image_paths)
|
324 |
+
|
325 |
+
# Return all updates together
|
326 |
+
return [image_paths, captioning_update, autocaption_update, download_btn_update, download_output_update, status_update] + row_updates + image_updates + caption_label_updates
|
327 |
+
|
328 |
+
# Combined outputs for both functions
|
329 |
+
combined_outputs = upload_outputs + image_components + caption_components
|
330 |
+
|
331 |
+
image_upload.change(
|
332 |
+
process_upload,
|
333 |
+
inputs=[image_upload],
|
334 |
+
outputs=combined_outputs
|
335 |
+
)
|
336 |
+
|
337 |
+
# Set up batch category checkbox
|
338 |
+
batch_category_checkbox.change(
|
339 |
+
update_batch_setting,
|
340 |
+
inputs=[batch_category_checkbox],
|
341 |
+
outputs=[batch_by_category]
|
342 |
+
)
|
343 |
+
|
344 |
+
# Manage the captioning status
|
345 |
+
def on_captioning_start():
|
346 |
+
return gr.update(value="⏳ Processing captions... please wait"), gr.update(interactive=False)
|
347 |
+
|
348 |
+
def on_captioning_complete():
|
349 |
+
return gr.update(value="✅ Captioning complete"), gr.update(interactive=True)
|
350 |
+
|
351 |
+
# Set up captioning button
|
352 |
+
autocaption_btn.click(
|
353 |
+
on_captioning_start,
|
354 |
+
inputs=None,
|
355 |
+
outputs=[status_text, autocaption_btn]
|
356 |
+
).success(
|
357 |
+
run_captioning,
|
358 |
+
inputs=[stored_image_paths, batch_by_category],
|
359 |
+
outputs=caption_components + [status_text]
|
360 |
+
).success(
|
361 |
+
on_captioning_complete,
|
362 |
+
inputs=None,
|
363 |
+
outputs=[status_text, autocaption_btn]
|
364 |
+
)
|
365 |
+
|
366 |
+
# Set up download button
|
367 |
+
download_btn.click(
|
368 |
+
create_zip_from_ui,
|
369 |
+
inputs=[stored_image_paths] + caption_components,
|
370 |
+
outputs=[download_output]
|
371 |
+
).then(
|
372 |
+
lambda: gr.update(visible=True),
|
373 |
+
inputs=None,
|
374 |
+
outputs=[download_output]
|
375 |
+
)
|
376 |
+
|
377 |
+
if __name__ == "__main__":
|
378 |
+
demo.launch(share=True)
|
main.py
CHANGED
@@ -43,32 +43,16 @@ def validate_input_directory(input_dir):
|
|
43 |
sys.exit(1)
|
44 |
|
45 |
if text_files:
|
46 |
-
print("
|
47 |
-
print("The
|
48 |
-
print("The following text files were found:")
|
49 |
for file in text_files:
|
50 |
print(f" - {file}")
|
51 |
-
sys.exit(1)
|
52 |
|
53 |
-
def
|
54 |
-
"""
|
55 |
-
input_path = Path(input_dir)
|
56 |
-
output_path = Path(output_dir) if output_dir else input_path
|
57 |
-
|
58 |
-
# Validate the input directory first
|
59 |
-
validate_input_directory(input_dir)
|
60 |
-
|
61 |
-
# Create output directory if it doesn't exist
|
62 |
-
os.makedirs(output_path, exist_ok=True)
|
63 |
-
|
64 |
-
# Track the number of processed images
|
65 |
-
processed_count = 0
|
66 |
-
|
67 |
-
# Collect all images into a dictionary grouped by category
|
68 |
images_by_category = {}
|
69 |
image_paths_by_category = {}
|
70 |
|
71 |
-
# Get all files in the input directory
|
72 |
for file_path in input_path.iterdir():
|
73 |
if file_path.is_file() and is_image_file(file_path.name):
|
74 |
try:
|
@@ -87,6 +71,53 @@ def process_images(input_dir, output_dir, fix_outfit=False, batch_images=False):
|
|
87 |
image_paths_by_category[category].append(file_path)
|
88 |
except Exception as e:
|
89 |
print(f"Error loading {file_path.name}: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
|
91 |
# Log the number of images found
|
92 |
total_images = sum(len(images) for images in images_by_category.values())
|
@@ -96,27 +127,11 @@ def process_images(input_dir, output_dir, fix_outfit=False, batch_images=False):
|
|
96 |
print("No valid images found to process.")
|
97 |
return
|
98 |
|
99 |
-
# Process images
|
100 |
if batch_images:
|
101 |
-
|
102 |
-
image_paths = image_paths_by_category[category]
|
103 |
-
try:
|
104 |
-
# Generate captions for the entire category
|
105 |
-
captions = caption_images(images)
|
106 |
-
write_captions(image_paths, captions, input_path, output_path)
|
107 |
-
processed_count += len(images)
|
108 |
-
except Exception as e:
|
109 |
-
print(f"Error generating captions for category '{category}': {e}")
|
110 |
else:
|
111 |
-
|
112 |
-
all_images = [img for imgs in images_by_category.values() for img in imgs]
|
113 |
-
all_image_paths = [path for paths in image_paths_by_category.values() for path in paths]
|
114 |
-
try:
|
115 |
-
captions = caption_images(all_images)
|
116 |
-
write_captions(all_image_paths, captions, input_path, output_path)
|
117 |
-
processed_count += len(all_images)
|
118 |
-
except Exception as e:
|
119 |
-
print(f"Error generating captions: {e}")
|
120 |
|
121 |
print(f"\nProcessing complete. {processed_count} images were captioned.")
|
122 |
|
|
|
43 |
sys.exit(1)
|
44 |
|
45 |
if text_files:
|
46 |
+
print("Warning: Text files detected in the input directory.")
|
47 |
+
print("The following text files will be overwritten:")
|
|
|
48 |
for file in text_files:
|
49 |
print(f" - {file}")
|
|
|
50 |
|
51 |
+
def collect_images_by_category(input_path):
|
52 |
+
"""Collect all valid images and group them by category."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
images_by_category = {}
|
54 |
image_paths_by_category = {}
|
55 |
|
|
|
56 |
for file_path in input_path.iterdir():
|
57 |
if file_path.is_file() and is_image_file(file_path.name):
|
58 |
try:
|
|
|
71 |
image_paths_by_category[category].append(file_path)
|
72 |
except Exception as e:
|
73 |
print(f"Error loading {file_path.name}: {e}")
|
74 |
+
|
75 |
+
return images_by_category, image_paths_by_category
|
76 |
+
|
77 |
+
def process_by_category(images_by_category, image_paths_by_category, input_path, output_path):
|
78 |
+
"""Process images in batches by category."""
|
79 |
+
processed_count = 0
|
80 |
+
|
81 |
+
for category, images in images_by_category.items():
|
82 |
+
image_paths = image_paths_by_category[category]
|
83 |
+
try:
|
84 |
+
# Generate captions for the entire category using batch mode
|
85 |
+
captions = caption_images(images, category=category, batch_mode=True)
|
86 |
+
write_captions(image_paths, captions, input_path, output_path)
|
87 |
+
processed_count += len(images)
|
88 |
+
except Exception as e:
|
89 |
+
print(f"Error generating captions for category '{category}': {e}")
|
90 |
+
|
91 |
+
return processed_count
|
92 |
+
|
93 |
+
def process_all_at_once(images_by_category, image_paths_by_category, input_path, output_path):
|
94 |
+
"""Process all images at once."""
|
95 |
+
all_images = [img for imgs in images_by_category.values() for img in imgs]
|
96 |
+
all_image_paths = [path for paths in image_paths_by_category.values() for path in paths]
|
97 |
+
processed_count = 0
|
98 |
+
|
99 |
+
try:
|
100 |
+
captions = caption_images(all_images, batch_mode=False)
|
101 |
+
write_captions(all_image_paths, captions, input_path, output_path)
|
102 |
+
processed_count += len(all_images)
|
103 |
+
except Exception as e:
|
104 |
+
print(f"Error generating captions: {e}")
|
105 |
+
|
106 |
+
return processed_count
|
107 |
+
|
108 |
+
def process_images(input_dir, output_dir, fix_outfit=False, batch_images=False):
|
109 |
+
"""Process all images in the input directory and generate captions."""
|
110 |
+
input_path = Path(input_dir)
|
111 |
+
output_path = Path(output_dir) if output_dir else input_path
|
112 |
+
|
113 |
+
# Validate the input directory first
|
114 |
+
validate_input_directory(input_dir)
|
115 |
+
|
116 |
+
# Create output directory if it doesn't exist
|
117 |
+
os.makedirs(output_path, exist_ok=True)
|
118 |
+
|
119 |
+
# Collect images by category
|
120 |
+
images_by_category, image_paths_by_category = collect_images_by_category(input_path)
|
121 |
|
122 |
# Log the number of images found
|
123 |
total_images = sum(len(images) for images in images_by_category.values())
|
|
|
127 |
print("No valid images found to process.")
|
128 |
return
|
129 |
|
130 |
+
# Process images based on batch setting
|
131 |
if batch_images:
|
132 |
+
processed_count = process_by_category(images_by_category, image_paths_by_category, input_path, output_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
else:
|
134 |
+
processed_count = process_all_at_once(images_by_category, image_paths_by_category, input_path, output_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
|
136 |
print(f"\nProcessing complete. {processed_count} images were captioned.")
|
137 |
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio==4.44.1
|
2 |
+
Pillow==10.0.0
|
3 |
+
pydantic>=2.0.0
|
4 |
+
together
|
5 |
+
fastapi>=0.100.0
|