Rishi Desai committed · Commit ab00f6b
Parent(s): 8373fd9

some clean up

Files changed:
- caption.py  +42 -46
- main.py     +25 -29
- prompt.py   +18 -26
caption.py CHANGED

@@ -2,12 +2,9 @@ import base64
 import io
 import os
 from together import Together
-from PIL import Image
-from dotenv import load_dotenv

-load_dotenv()

-def
+def get_system_prompt():
     return """Automated Image Captioning (for LoRA Training)

 Role: You are an expert AI captioning system generating precise, structured descriptions for character images optimized for LoRA model training in Stable Diffusion and Flux.1-dev.

@@ -58,10 +55,12 @@ tr1gger photorealistic, long trench coat and combat boots, walking, determined,
 REMEMBER: Your response must be a single line starting with "tr1gger" and following the exact format above. No additional text, formatting, or explanations are allowed.
 """

+
 class CaptioningError(Exception):
     """Exception raised for errors in the captioning process."""
     pass

+
 def images_to_base64(images):
     """Convert a list of PIL images to base64 encoded strings."""
     image_strings = []

@@ -72,15 +71,17 @@ def images_to_base64(images):
         image_strings.append(img_str)
     return image_strings

+
 def get_together_client():
     """Initialize and return the Together API client."""
     api_key = os.environ.get("TOGETHER_API_KEY")
     if not api_key:
-        raise ValueError("TOGETHER_API_KEY
+        raise ValueError("TOGETHER_API_KEY not set!")
     return Together(api_key=api_key)

-
-
+
+def extract_caption(line):
+    """Extract caption from a line of text."""
     if "tr1gger" in line:
         # If caption doesn't start with tr1gger but contains it, extract just that part
         if not line.startswith("tr1gger"):

@@ -88,10 +89,11 @@ def extract_trigger_caption(line):
             return line
     return ""

+
 def caption_single_image(client, img_str):
     """Process and caption a single image."""
     messages = [
-        {"role": "system", "content":
+        {"role": "system", "content": get_system_prompt()},
         {
             "role": "user",
             "content": [

@@ -100,85 +102,81 @@ def caption_single_image(client, img_str):
             ]
         }
     ]
-
+
     # Request caption for the image using Llama 4 Maverick
     response = client.chat.completions.create(
         model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
         messages=messages
     )
-
-    # Extract caption from the response
+
     full_response = response.choices[0].message.content.strip()
-
-    # Look for the trigger line in the response
     caption = ""
     for line in full_response.splitlines():
-        caption =
+        caption = extract_caption(line)
         if caption:
             break
-
-    # Check if caption is valid
+
     if not caption:
         error_msg = "Failed to extract a valid caption (containing 'tr1gger') from the response"
         error_msg += f"\n\nActual response:\n{full_response}"
         raise CaptioningError(error_msg)
-
+
     return caption

-
+
+def caption_image_batch(client, image_strings, category):
     """Process and caption multiple images in a single batch request."""
     # Create a content array with all images
-    content = [{"type": "text",
-
-
+    content = [{"type": "text",
+                "text": f"Here is the batch of images for {category}. "
+                        f"Caption each image on a separate line."}]
+
     for i, img_str in enumerate(image_strings):
         content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_str}"}})
-        content.append({"type": "text", "text": f"Image {i+1}"})
-
-    # Send the batch request
+        content.append({"type": "text", "text": f"Image {i + 1}"})
+
     messages = [
-        {"role": "system", "content":
+        {"role": "system", "content": get_system_prompt()},
         {"role": "user", "content": content}
     ]
-
     response = client.chat.completions.create(
         model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
         messages=messages
     )
-
     return process_batch_response(response, image_strings)

+
 def process_batch_response(response, image_strings):
     """Process the API response from a batch request and extract captions."""
-    # Parse the response to extract captions for each image
     full_response = response.choices[0].message.content.strip()
     lines = full_response.splitlines()
-
+
     # Extract captions from the response
     image_count = len(image_strings)
-    captions = [""] * image_count
-
+    captions = [""] * image_count
+
     # Extract lines that start with or contain "tr1gger"
-
-
+    caption_lines = [line for line in lines if "tr1gger" in line]
+
     # Assign captions to images
     for i in range(image_count):
-        if i < len(
-            caption =
+        if i < len(caption_lines):
+            caption = extract_caption(caption_lines[i])
             captions[i] = caption
-
+
     validate_batch_captions(captions, image_count, full_response)
     return captions

+
 def validate_batch_captions(captions, image_count, full_response):
     """Validate captions extracted from a batch response."""
     # Check if all captions are empty or don't contain the trigger word
     valid_captions = [c for c in captions if c and "tr1gger" in c]
     if not valid_captions:
-        error_msg = "Failed to parse any valid captions from batch response.
+        error_msg = "Failed to parse any valid captions from batch response."
         error_msg += f"\n\nActual response:\n{full_response}"
         raise CaptioningError(error_msg)
-
+
     # Check if some captions are missing
     if len(valid_captions) < image_count:
         missing_count = image_count - len(valid_captions)

@@ -186,24 +184,22 @@ def validate_batch_captions(captions, image_count, full_response):
         error_msg = f"Failed to parse captions for {missing_count} of {image_count} images in batch mode"
         error_msg += "\n\nMalformed captions:"
         for idx, caption in invalid_captions:
-            error_msg += f"\nImage {idx+1}: '{caption}'"
+            error_msg += f"\nImage {idx + 1}: '{caption}'"
         raise CaptioningError(error_msg)

+
 def caption_images(images, category=None, batch_mode=False):
     """Caption a list of images, either individually or in batch mode."""
-    # Convert PIL images to base64 encoded strings
     image_strings = images_to_base64(images)
-
-    # Initialize the API client
+
     client = get_together_client()
-
-    # Process images based on the mode
+
     if batch_mode and category:
-        return
+        return caption_image_batch(client, image_strings, category)
     else:
-        # Process each image individually
         return [caption_single_image(client, img_str) for img_str in image_strings]

+
 def extract_captions(file_path):
     captions = []
     with open(file_path, 'r') as file:
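For reference, a minimal usage sketch of the refactored caption.py helpers (not part of the commit; the file names are placeholders, and it assumes TOGETHER_API_KEY is set in the environment):

# Hypothetical usage sketch for the refactored caption.py.
from PIL import Image
from caption import caption_images

# Callers still pass PIL images; caption.py itself no longer imports PIL or dotenv.
images = [Image.open(p).convert("RGB") for p in ["hero_001.png", "hero_002.png"]]

# Individual mode: one request per image, each returning a single "tr1gger ..." line.
captions = caption_images(images)

# Batch mode: one request for the whole category, parsed by process_batch_response.
batch_captions = caption_images(images, category="hero", batch_mode=True)

print("\n".join(captions))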
main.py CHANGED

@@ -6,34 +6,38 @@ from pathlib import Path
 from PIL import Image
 from caption import caption_images

+
 def is_image_file(filename):
     """Check if a file is an allowed image type."""
     allowed_extensions = ['.png', '.jpg', '.jpeg', '.webp']
     return any(filename.lower().endswith(ext) for ext in allowed_extensions)

+
 def is_unsupported_image(filename):
     """Check if a file is an image but not of an allowed type."""
     unsupported_extensions = ['.bmp', '.gif', '.tiff', '.tif', '.ico', '.svg']
     return any(filename.lower().endswith(ext) for ext in unsupported_extensions)

+
 def is_text_file(filename):
     """Check if a file is a text file."""
     return filename.lower().endswith('.txt')

+
 def validate_input_directory(input_dir):
     """Validate that the input directory only contains allowed image formats."""
     input_path = Path(input_dir)
-
+
     unsupported_files = []
     text_files = []
-
+
     for file_path in input_path.iterdir():
         if file_path.is_file():
             if is_unsupported_image(file_path.name):
                 unsupported_files.append(file_path.name)
             elif is_text_file(file_path.name):
                 text_files.append(file_path.name)
-
+
     if unsupported_files:
         print("Error: Unsupported image formats detected.")
         print("Only .png, .jpg, .jpeg, and .webp files are allowed.")

@@ -41,13 +45,14 @@ def validate_input_directory(input_dir):
         for file in unsupported_files:
             print(f" - {file}")
         sys.exit(1)
-
+
     if text_files:
         print("Warning: Text files detected in the input directory.")
         print("The following text files will be overwritten:")
         for file in text_files:
             print(f" - {file}")

+
 def collect_images_by_category(input_path):
     """Collect all valid images and group them by category."""
     images_by_category = {}

@@ -56,28 +61,27 @@ def collect_images_by_category(input_path):
     for file_path in input_path.iterdir():
         if file_path.is_file() and is_image_file(file_path.name):
             try:
-                # Load the image
                 image = Image.open(file_path).convert("RGB")
-
+
                 # Determine the category from the filename
                 category = file_path.stem.rsplit('_', 1)[0]
-
+
                 # Add image to the appropriate category
                 if category not in images_by_category:
                     images_by_category[category] = []
                     image_paths_by_category[category] = []
-
+
                 images_by_category[category].append(image)
                 image_paths_by_category[category].append(file_path)
             except Exception as e:
                 print(f"Error loading {file_path.name}: {e}")
-
+
     return images_by_category, image_paths_by_category

+
 def process_by_category(images_by_category, image_paths_by_category, input_path, output_path):
     """Process images in batches by category."""
     processed_count = 0
-
     for category, images in images_by_category.items():
         image_paths = image_paths_by_category[category]
         try:

@@ -87,35 +91,31 @@ def process_by_category(images_by_category, image_paths_by_category, input_path,
             processed_count += len(images)
         except Exception as e:
             print(f"Error generating captions for category '{category}': {e}")
-
     return processed_count

+
 def process_all_at_once(images_by_category, image_paths_by_category, input_path, output_path):
     """Process all images at once."""
     all_images = [img for imgs in images_by_category.values() for img in imgs]
     all_image_paths = [path for paths in image_paths_by_category.values() for path in paths]
     processed_count = 0
-
     try:
         captions = caption_images(all_images, batch_mode=False)
         write_captions(all_image_paths, captions, input_path, output_path)
         processed_count += len(all_images)
     except Exception as e:
         print(f"Error generating captions: {e}")
-
     return processed_count

+
 def process_images(input_dir, output_dir, fix_outfit=False, batch_images=False):
     """Process all images in the input directory and generate captions."""
     input_path = Path(input_dir)
     output_path = Path(output_dir) if output_dir else input_path
-
-    # Validate the input directory first
+
     validate_input_directory(input_dir)
-
-    # Create output directory if it doesn't exist
     os.makedirs(output_path, exist_ok=True)
-
+
     # Collect images by category
     images_by_category, image_paths_by_category = collect_images_by_category(input_path)

@@ -127,7 +127,6 @@ def process_images(input_dir, output_dir, fix_outfit=False, batch_images=False):
         print("No valid images found to process.")
         return

-    # Process images based on batch setting
     if batch_images:
         processed_count = process_by_category(images_by_category, image_paths_by_category, input_path, output_path)
     else:

@@ -135,6 +134,7 @@ def process_images(input_dir, output_dir, fix_outfit=False, batch_images=False):

     print(f"\nProcessing complete. {processed_count} images were captioned.")

+
 def write_captions(image_paths, captions, input_path, output_path):
     """Helper function to write captions to files."""
     for file_path, caption in zip(image_paths, captions):

@@ -143,37 +143,33 @@ def write_captions(image_paths, captions, input_path, output_path):
             caption_filename = file_path.stem + ".txt"
             caption_path = input_path / caption_filename

-            # Write caption to file
             with open(caption_path, 'w', encoding='utf-8') as f:
                 f.write(caption)

             # If output directory is different from input, copy files
             if output_path != input_path:
-                # Copy image to output directory
                 shutil.copy2(file_path, output_path / file_path.name)
-                # Copy caption to output directory
                 shutil.copy2(caption_path, output_path / caption_filename)
-
             print(f"Processed {file_path.name} → {caption_filename}")
         except Exception as e:
             print(f"Error processing {file_path.name}: {e}")

+
 def main():
     parser = argparse.ArgumentParser(description='Generate captions for images using GPT-4o.')
     parser.add_argument('--input', type=str, required=True, help='Directory containing images')
     parser.add_argument('--output', type=str, help='Directory to save images and captions (defaults to input directory)')
     parser.add_argument('--fix_outfit', action='store_true', help='Flag to indicate if character has one outfit')
     parser.add_argument('--batch_images', action='store_true', help='Flag to indicate if images should be processed in batches')
-
+
     args = parser.parse_args()
-
-    # Validate input directory
+
     if not os.path.isdir(args.input):
         print(f"Error: Input directory '{args.input}' does not exist.")
         return
-
-    # Process images
+
     process_images(args.input, args.output, args.fix_outfit, args.batch_images)

+
 if __name__ == "__main__":
-    main()
+    main()
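As a usage note, main() still funnels everything through process_images after parsing its four flags; the sketch below mirrors that call (not part of the diff; the directory paths are placeholders and TOGETHER_API_KEY must be set):

# Hypothetical call mirroring `python main.py --input ./images --output ./dataset --batch_images`.
from main import process_images

process_images(
    input_dir="./images",    # only .png/.jpg/.jpeg/.webp files are accepted
    output_dir="./dataset",  # optional; defaults to the input directory
    fix_outfit=False,        # mirrors the --fix_outfit flag
    batch_images=True,       # --batch_images: group by filename prefix and caption per category
)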
prompt.py CHANGED

@@ -1,23 +1,18 @@
 import os
 import argparse
 from pathlib import Path
-from caption import
+from caption import get_system_prompt, get_together_client, extract_captions
+

 def optimize_prompt(user_prompt, captions_dir=None, captions_list=None):
-    """
-    Optimize a user prompt to follow the same format as training captions.
+    """Optimize a user prompt to follow the same format as training captions.

     Args:
-        user_prompt (str): The simple user prompt to optimize
+        user_prompt (str): The simple user prompt to optimize
         captions_dir (str, optional): Directory containing caption .txt files
         captions_list (list, optional): List of captions to use instead of loading from files
-
-    Returns:
-        str: The optimized prompt following the training format
     """
-    # Get captions either from directory or provided list
     all_captions = []
-
     if captions_list:
         all_captions = captions_list
     elif captions_dir:

@@ -26,19 +21,18 @@ def optimize_prompt(user_prompt, captions_dir=None, captions_list=None):
         for file_path in captions_path.glob("*.txt"):
             captions = extract_captions(file_path)
             all_captions.extend(captions)
-
+
     if not all_captions:
-        raise ValueError("
-
+        raise ValueError("Please provide either caption files or a list of captions!")
+
     # Concatenate all captions with newlines
     captions_text = "\n".join(all_captions)
-
+
     client = get_together_client()
-
     messages = [
-        {"role": "system", "content":
+        {"role": "system", "content": get_system_prompt()},
         {
-            "role": "user",
+            "role": "user",
             "content": (
                 f"These are all of the captions used to train the LoRA:\n\n"
                 f"{captions_text}\n\n"

@@ -47,36 +41,34 @@ def optimize_prompt(user_prompt, captions_dir=None, captions_list=None):
             )
         }
     ]
-
+
     response = client.chat.completions.create(
         model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
         messages=messages
     )
-
+
     optimized_prompt = response.choices[0].message.content.strip()
     return optimized_prompt

+
 def main():
     parser = argparse.ArgumentParser(description='Optimize prompts based on existing captions.')
     parser.add_argument('--prompt', type=str, required=True, help='User prompt to optimize')
-    parser.add_argument('--captions', type=str, help='Directory containing caption .txt files')
-
+    parser.add_argument('--captions', type=str, required=True,help='Directory containing caption .txt files')
+
     args = parser.parse_args()
-
-    if not args.captions:
-        print("Error: --captions is required.")
-        return
     if not os.path.isdir(args.captions):
         print(f"Error: Captions directory '{args.captions}' does not exist.")
         return
-
+
     try:
         optimized_prompt = optimize_prompt(args.prompt, args.captions)
         print("\nOptimized Prompt:")
         print(optimized_prompt)
-
+
     except Exception as e:
         print(f"Error optimizing prompt: {e}")

+
 if __name__ == "__main__":
     main()
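For reference, optimize_prompt can also be called directly with an in-memory list of captions via captions_list, which the CLI's required --captions directory flag does not expose; a minimal sketch (not part of the commit; the prompt and caption text are made-up examples, and TOGETHER_API_KEY must be set):

# Hypothetical programmatic use of prompt.py's optimize_prompt.
from prompt import optimize_prompt

training_captions = [
    "tr1gger photorealistic, long trench coat and combat boots, walking, determined",
]

# The captions are joined with newlines and sent alongside the shared system prompt.
print(optimize_prompt("a woman reading in a cafe", captions_list=training_captions))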