Rishi Desai committed on
Commit
dc6215b
·
1 Parent(s): ca96bd8

init dump for demo

Browse files
Files changed (4) hide show
  1. caption.py +110 -44
  2. demo.py +378 -0
  3. main.py +54 -39
  4. requirements.txt +5 -0
caption.py CHANGED
@@ -10,7 +10,15 @@ load_dotenv()
10
  def get_prompt():
11
  return """Automated Image Captioning (for LoRA Training)
12
 
13
- Role: You are an expert AI captioning system generating precise, structured descriptions for AI-generated character images optimized for LoRA model training in Stable Diffusion and Flux.1-dev.
 
 
 
 
 
 
 
 
14
 
15
  General Guidelines:
16
  1. Prioritize Consistency – Maintain uniform descriptions across all images in a dataset. Avoid introducing variation in features that should remain constant (e.g., fixed traits like eye color, hair color, or markings that are inherently part of the concept and handled during model training).
@@ -28,8 +36,8 @@ Avoid Describing These Unless Variable Across Dataset or Uncertain from Concept:
28
  - Tattoos or markings if core to the concept
29
  - Known accessories that always appear (unless outfit-specific)
30
 
31
- Updated Caption Format:
32
- tr1gger [Style], [Notable Visual Features], [Clothing], [Pose], [Expression], [Lighting], [Camera Angle]
33
 
34
  Captioning Principles:
35
  - Emphasize visual variation and context-specific details (outfit, pose, lighting, expression, camera angle).
@@ -41,14 +49,20 @@ Captioning Principles:
41
  - Avoid mentioning real or fictional identities.
42
  - Always prefix with the trigger word "tr1gger."
43
 
44
- Examples:
45
- - tr1gger photorealistic, tactical vest and gloves, standing in profile, neutral expression, overcast lighting, side profile
46
- - tr1gger 3D-rendered, hooded cloak with digital pattern, seated cross-legged, calm expression, low ambient lighting, front view
47
- - tr1gger anime-style, school uniform with blue necktie, standing with arms behind back, gentle smile, soft daylight, three-quarter view
48
- - tr1gger photorealistic, long trench coat and combat boots, walking through rain-soaked street, determined expression, dramatic shadows, low-angle view
 
 
49
  """
50
 
51
- def caption_images(images):
 
 
 
 
52
  # Convert PIL images to base64 encoded strings
53
  image_strings = []
54
  for image in images:
@@ -66,30 +80,103 @@ def caption_images(images):
66
  client = Together(api_key=api_key)
67
  captions = []
68
 
69
- # Start a separate chat session for each image
70
- for img_str in image_strings:
 
 
 
 
 
 
 
 
 
71
  messages = [
72
  {"role": "system", "content": get_prompt()},
73
- {
74
- "role": "user",
75
- "content": [
76
- {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_str}"}},
77
- {"type": "text", "text": "Describe this image."}
78
- ]
79
- }
80
  ]
81
 
82
- # Request caption for the image using Llama 4 Maverick
83
  response = client.chat.completions.create(
84
  model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
85
  messages=messages
86
  )
87
 
88
- # Extract caption from the response
89
  full_response = response.choices[0].message.content.strip()
90
- # Post-process to extract only the caption part
91
- caption = next((line for line in full_response.splitlines() if line.startswith("tr1gger")), "")
92
- captions.append(caption)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
  return captions
95
 
@@ -100,24 +187,3 @@ def extract_captions(file_path):
100
  if line.startswith("tr1gger"):
101
  captions.append(line.strip())
102
  return captions
103
-
104
- # Example usage
105
- if __name__ == "__main__":
106
- if not os.environ.get("TOGETHER_API_KEY"):
107
- print("Please update the environment with your Together AI API key.")
108
- exit(1)
109
-
110
- # Load images
111
- image_paths = ['input/daenyrs_hd.jpg', 'input/girl_body.png']
112
- images = [Image.open(path).convert("RGB") for path in image_paths]
113
-
114
- # Generate captions
115
- captions = caption_images(images)
116
- for i, caption in enumerate(captions):
117
- print(f"Generated Caption for Image {i+1}: {caption}")
118
-
119
- # Extract captions from a file
120
- file_path = 'post_girl/multiview_0.txt'
121
- extracted_captions = extract_captions(file_path)
122
- for caption in extracted_captions:
123
- print(caption)
 
10
  def get_prompt():
11
  return """Automated Image Captioning (for LoRA Training)
12
 
13
+ Role: You are an expert AI captioning system generating precise, structured descriptions for character images optimized for LoRA model training in Stable Diffusion and Flux.1-dev.
14
+
15
+ IMPORTANT: You MUST follow these rules EXACTLY:
16
+ 1. EVERY caption MUST start with the word "tr1gger" (exactly like this, no variations)
17
+ 2. You MUST use the exact format: tr1gger [Style], [Notable Visual Features], [Clothing], [Pose], [Expression], [Lighting], [Camera Angle]
18
+ 3. DO NOT include any additional text, explanations, or formatting
19
+ 4. DO NOT use bullet points, lists, or any other formatting
20
+ 5. DO NOT include any text before or after the caption
21
+ 6. If you don't follow this format exactly, the caption will be rejected
22
 
23
  General Guidelines:
24
  1. Prioritize Consistency – Maintain uniform descriptions across all images in a dataset. Avoid introducing variation in features that should remain constant (e.g., fixed traits like eye color, hair color, or markings that are inherently part of the concept and handled during model training).
 
36
  - Tattoos or markings if core to the concept
37
  - Known accessories that always appear (unless outfit-specific)
38
 
39
+ Caption Format (MUST FOLLOW EXACTLY):
40
+ tr1gger [Style], [Notable Visual Features], [Clothing], [Pose], [Expression/Mood], [Background/Setting], [Lighting], [Camera Angle]
41
 
42
  Captioning Principles:
43
  - Emphasize visual variation and context-specific details (outfit, pose, lighting, expression, camera angle).
 
49
  - Avoid mentioning real or fictional identities.
50
  - Always prefix with the trigger word "tr1gger."
51
 
52
+ Examples (MUST FOLLOW THIS EXACT FORMAT):
53
+ tr1gger photorealistic, combat gear, tactical vest and gloves, standing in profile, neutral, empty room, overcast lighting, side profile
54
+ tr1gger 3D-rendered, digital patterns, hooded cloak, seated cross-legged, calm, meditation chamber, low ambient lighting, front view
55
+ tr1gger anime-style, school uniform with blue necktie, standing with arms behind back, gentle smile, classroom, soft daylight, three-quarter view
56
+ tr1gger photorealistic, long trench coat and combat boots, walking, determined, rain-soaked street, dramatic shadows, low-angle view
57
+
58
+ REMEMBER: Your response must be a single line starting with "tr1gger" and following the exact format above. No additional text, formatting, or explanations are allowed.
59
  """
60
 
61
class CaptioningError(Exception):
    """Raised when the captioning pipeline cannot produce valid captions.

    The message typically embeds the raw model response so callers can
    diagnose why parsing or format validation failed.
    """
64
+
65
+ def caption_images(images, category=None, batch_mode=False):
66
  # Convert PIL images to base64 encoded strings
67
  image_strings = []
68
  for image in images:
 
80
  client = Together(api_key=api_key)
81
  captions = []
82
 
83
+ # If batch_mode is True, process all images in a single API call
84
+ if batch_mode and category:
85
+ # Create a content array with all images
86
+ content = [{"type": "text", "text": f"Here is the batch of images for {category}. Please caption each image on a separate line, starting each caption with 'tr1gger'."}]
87
+
88
+ # Add all images to the content array
89
+ for i, img_str in enumerate(image_strings):
90
+ content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_str}"}})
91
+ content.append({"type": "text", "text": f"Image {i+1}"})
92
+
93
+ # Send the batch request
94
  messages = [
95
  {"role": "system", "content": get_prompt()},
96
+ {"role": "user", "content": content}
 
 
 
 
 
 
97
  ]
98
 
 
99
  response = client.chat.completions.create(
100
  model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
101
  messages=messages
102
  )
103
 
104
+ # Parse the response to extract captions for each image
105
  full_response = response.choices[0].message.content.strip()
106
+ lines = full_response.splitlines()
107
+
108
+ # Extract captions from the response
109
+ image_count = len(image_strings)
110
+ captions = [""] * image_count # Initialize with empty strings
111
+
112
+ # Extract lines that start with or contain "tr1gger"
113
+ tr1gger_lines = [line for line in lines if "tr1gger" in line]
114
+
115
+ # Assign captions to images
116
+ for i in range(image_count):
117
+ if i < len(tr1gger_lines):
118
+ caption = tr1gger_lines[i]
119
+ # If caption contains but doesn't start with tr1gger, extract just that part
120
+ if not caption.startswith("tr1gger") and "tr1gger" in caption:
121
+ caption = caption[caption.index("tr1gger"):]
122
+ captions[i] = caption
123
+
124
+ # Check if all captions are empty or don't contain the trigger word
125
+ valid_captions = [c for c in captions if c and "tr1gger" in c]
126
+ if not valid_captions:
127
+ error_msg = "Failed to parse any valid captions from batch response. Response contained no lines with 'tr1gger'"
128
+ error_msg += f"\n\nActual response:\n{full_response}"
129
+ raise CaptioningError(error_msg)
130
+
131
+ # Check if some captions are missing
132
+ if len(valid_captions) < len(images):
133
+ missing_count = len(images) - len(valid_captions)
134
+ invalid_captions = [(i, c) for i, c in enumerate(captions) if not c or "tr1gger" not in c]
135
+ error_msg = f"Failed to parse captions for {missing_count} of {len(images)} images in batch mode"
136
+ error_msg += "\n\nMalformed captions:"
137
+ for idx, caption in invalid_captions:
138
+ error_msg += f"\nImage {idx+1}: '{caption}'"
139
+ raise CaptioningError(error_msg)
140
+ else:
141
+ # Original method: process each image separately
142
+ for img_str in image_strings:
143
+ messages = [
144
+ {"role": "system", "content": get_prompt()},
145
+ {
146
+ "role": "user",
147
+ "content": [
148
+ {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_str}"}},
149
+ {"type": "text", "text": "Describe this image."}
150
+ ]
151
+ }
152
+ ]
153
+
154
+ # Request caption for the image using Llama 4 Maverick
155
+ response = client.chat.completions.create(
156
+ model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
157
+ messages=messages
158
+ )
159
+
160
+ # Extract caption from the response
161
+ full_response = response.choices[0].message.content.strip()
162
+ # Post-process to extract only the caption part
163
+ caption = ""
164
+ for line in full_response.splitlines():
165
+ if "tr1gger" in line:
166
+ # If caption has a numbered prefix, extract just the caption part
167
+ if not line.startswith("tr1gger"):
168
+ caption = line[line.index("tr1gger"):]
169
+ else:
170
+ caption = line
171
+ break
172
+
173
+ # Check if caption is valid
174
+ if not caption:
175
+ error_msg = "Failed to extract a valid caption (containing 'tr1gger') from the response"
176
+ error_msg += f"\n\nActual response:\n{full_response}"
177
+ raise CaptioningError(error_msg)
178
+
179
+ captions.append(caption)
180
 
181
  return captions
182
 
 
187
  if line.startswith("tr1gger"):
188
  captions.append(line.strip())
189
  return captions
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
demo.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import zipfile
4
+ from io import BytesIO
5
+ import PIL.Image
6
+ import time
7
+ import tempfile
8
+ from main import process_images, collect_images_by_category, write_captions # Import the CLI functions
9
+ from dotenv import load_dotenv
10
+ from pathlib import Path
11
+
12
+ # Load environment variables
13
+ load_dotenv()
14
+
15
+ # Maximum number of images
16
+ MAX_IMAGES = 30
17
+
18
def create_download_file(image_paths, captions):
    """Bundle images and matching caption files into an in-memory zip.

    Each image is stored under its original filename — the bytes are copied
    verbatim, so the entry keeps its real extension (previously everything
    was renamed to ``.png`` even when the data was JPEG, producing
    mislabeled archive entries). A ``<basename>.txt`` file with the caption
    sits next to each image.

    Args:
        image_paths: Filesystem paths of the images to include.
        captions: Caption strings, parallel to ``image_paths``.

    Returns:
        The zip archive contents as ``bytes``.
    """
    zip_io = BytesIO()
    with zipfile.ZipFile(zip_io, 'w') as zip_file:
        for image_path, caption in zip(image_paths, captions):
            filename = os.path.basename(image_path)
            base_name, ext = os.path.splitext(filename)
            # Keep the real extension so the entry name matches the data;
            # fall back to .png only when the file has no extension at all.
            img_name = f"{base_name}{ext or '.png'}"
            caption_name = f"{base_name}.txt"

            # Copy the image bytes unchanged into the archive.
            with open(image_path, 'rb') as img_file:
                zip_file.writestr(img_name, img_file.read())

            # Caption shares the image's basename for LoRA-style pairing.
            zip_file.writestr(caption_name, caption)

    return zip_io.getvalue()
36
+
37
def process_uploaded_images(image_paths, batch_by_category=False):
    """Caption uploaded images through the same pipeline the CLI uses.

    Copies the uploads into a temporary directory, groups them via
    collect_images_by_category(), captions them either per category
    (batch_by_category=True) or all at once, and maps the results back to
    the original upload order.

    Args:
        image_paths: Filesystem paths of the uploaded images, in UI order.
        batch_by_category: When True, send one batched request per category.

    Returns:
        A list of caption strings parallel to image_paths ("" for any image
        that could not be matched back to a category).

    Raises:
        Exception: re-raises whatever the captioning pipeline raised, after
        logging it.
    """
    try:
        print(f"Processing {len(image_paths)} images, batch_by_category={batch_by_category}")
        # Create a temporary directory to store the images
        with tempfile.TemporaryDirectory() as temp_dir:
            # Copy images to temp directory and maintain original order
            temp_image_paths = []
            original_to_temp = {}  # Map original paths to temp paths
            # NOTE(review): duplicate basenames would overwrite each other in
            # temp_dir — confirm the upload widget guarantees unique filenames.
            for path in image_paths:
                filename = os.path.basename(path)
                temp_path = os.path.join(temp_dir, filename)
                with open(path, 'rb') as src, open(temp_path, 'wb') as dst:
                    dst.write(src.read())
                temp_image_paths.append(temp_path)
                original_to_temp[path] = temp_path

            print(f"Created {len(temp_image_paths)} temporary files")

            # Convert temp_dir to Path object for collect_images_by_category
            temp_dir_path = Path(temp_dir)

            # Process images using the CLI code path
            images_by_category, image_paths_by_category = collect_images_by_category(temp_dir_path)
            print(f"Collected images into {len(images_by_category)} categories")

            # Rebuild a flat image/path list in the caller's original order,
            # since categorization may have reordered/grouped the files.
            all_images = []
            all_image_paths = []
            for path in image_paths:  # Use original order
                temp_path = original_to_temp[path]
                found = False
                for category, paths in image_paths_by_category.items():
                    if temp_path in [str(p) for p in paths]:  # Convert Path objects to strings for comparison
                        idx = [str(p) for p in paths].index(temp_path)
                        all_images.append(images_by_category[category][idx])
                        all_image_paths.append(path)  # Use original path
                        found = True
                        break
                if not found:
                    print(f"Warning: Could not find image {path} in categorized data")

            print(f"Collected {len(all_images)} images in correct order")

            # Process based on batch setting
            if batch_by_category:
                # Process each category separately with one batched request each
                captions = [""] * len(image_paths)  # Initialize with empty strings
                for category, images in images_by_category.items():
                    category_paths = image_paths_by_category[category]
                    print(f"Processing category '{category}' with {len(images)} images")
                    # Use the same code path as CLI
                    from caption import caption_images
                    category_captions = caption_images(images, category=category, batch_mode=True)
                    print(f"Generated {len(category_captions)} captions for category '{category}'")
                    print("Category captions:", category_captions)  # Debug print category captions

                    # Map each category caption back to its original upload slot
                    for temp_path, caption in zip(category_paths, category_captions):
                        temp_path_str = str(temp_path)
                        for orig_path, orig_temp in original_to_temp.items():
                            if orig_temp == temp_path_str:
                                idx = image_paths.index(orig_path)
                                captions[idx] = caption
                                break
            else:
                # Process all images at once (one request per image downstream)
                from caption import caption_images
                print(f"Processing all {len(all_images)} images at once")
                all_captions = caption_images(all_images, batch_mode=False)
                print(f"Generated {len(all_captions)} captions")
                print("All captions:", all_captions)  # Debug print all captions
                captions = [""] * len(image_paths)
                for path, caption in zip(all_image_paths, all_captions):
                    idx = image_paths.index(path)
                    captions[idx] = caption

            print(f"Returning {len(captions)} captions")
            print("Final captions:", captions)  # Debug print final captions
            return captions

    except Exception as e:
        print(f"Error in processing: {e}")
        raise
121
+
122
+ # Main Gradio interface
123
+ with gr.Blocks() as demo:
124
+ gr.Markdown("# Image Autocaptioner")
125
+
126
+ # Store uploaded images
127
+ stored_image_paths = gr.State([])
128
+ batch_by_category = gr.State(True) # State to track if batch by category is enabled
129
+
130
+ # Upload component
131
+ with gr.Row():
132
+ with gr.Column(scale=2):
133
+ gr.Markdown("### Upload your images")
134
+ image_upload = gr.File(
135
+ file_count="multiple",
136
+ label="Drop your files here",
137
+ file_types=["image"],
138
+ type="filepath"
139
+ )
140
+
141
+ with gr.Column(scale=1):
142
+ autocaption_btn = gr.Button("Autocaption Images", variant="primary", interactive=False)
143
+ status_text = gr.Markdown("Upload images to begin", visible=True)
144
+
145
+ # Advanced settings dropdown
146
+ with gr.Accordion("Advanced", open=False):
147
+ batch_category_checkbox = gr.Checkbox(
148
+ label="Batch by category",
149
+ value=True,
150
+ info="Group similar images together when processing"
151
+ )
152
+
153
+ # Create a container for the captioning area (initially hidden)
154
+ with gr.Column(visible=False) as captioning_area:
155
+ gr.Markdown("### Your images and captions")
156
+
157
+ # Create individual image and caption rows
158
+ image_rows = []
159
+ image_components = []
160
+ caption_components = []
161
+
162
+ for i in range(MAX_IMAGES):
163
+ with gr.Row(visible=False) as img_row:
164
+ image_rows.append(img_row)
165
+
166
+ img = gr.Image(
167
+ label=f"Image {i+1}",
168
+ type="filepath",
169
+ show_label=False,
170
+ height=200,
171
+ width=200,
172
+ scale=1
173
+ )
174
+ image_components.append(img)
175
+
176
+ caption = gr.Textbox(
177
+ label=f"Caption {i+1}",
178
+ lines=3,
179
+ scale=2
180
+ )
181
+ caption_components.append(caption)
182
+
183
+ # Add download button
184
+ download_btn = gr.Button("Download Images with Captions", variant="secondary", interactive=False)
185
+ download_output = gr.File(label="Download Zip", visible=False)
186
+
187
def load_captioning(files):
    """Validate uploaded files and prepare the captioning UI state.

    Returns the stored image paths followed by visibility/interactivity
    updates for the captioning area, the two buttons, the download output,
    the status text, and one visibility update per image row.
    """
    hidden_rows = [gr.update(visible=False) for _ in range(MAX_IMAGES)]

    if not files:
        return [], gr.update(visible=False), gr.update(interactive=False), gr.update(interactive=False), gr.update(visible=False), gr.update(value="Upload images to begin"), *hidden_rows

    # Keep only files whose extension looks like an image.
    image_paths = [f for f in files if f.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp'))]

    if not image_paths:
        gr.Warning(f"Please upload at least one image")
        return [], gr.update(visible=False), gr.update(interactive=False), gr.update(interactive=False), gr.update(visible=False), gr.update(value="No valid images found"), *hidden_rows

    if len(image_paths) > MAX_IMAGES:
        gr.Warning(f"Only the first {MAX_IMAGES} images will be processed")
        image_paths = image_paths[:MAX_IMAGES]

    # Show one row per accepted image; hide the unused rows.
    row_updates = [gr.update(visible=slot < len(image_paths)) for slot in range(MAX_IMAGES)]

    return (
        image_paths,                       # stored_image_paths
        gr.update(visible=True),           # captioning_area
        gr.update(interactive=True),       # autocaption_btn
        gr.update(interactive=True),       # download_btn
        gr.update(visible=False),          # download_output
        gr.update(value=f"{len(image_paths)} images ready for captioning"),  # status_text
        *row_updates,                      # image_rows
    )
220
+
221
def update_images(image_paths):
    """Fill the image widgets with the uploaded paths; clear unused slots."""
    print(f"Updating images with paths: {image_paths}")
    return [
        gr.update(value=image_paths[slot]) if slot < len(image_paths) else gr.update(value=None)
        for slot in range(MAX_IMAGES)
    ]
231
+
232
def update_caption_labels(image_paths):
    """Label each caption box with its image's filename (blank when unused)."""
    labels = [os.path.basename(p) for p in image_paths[:MAX_IMAGES]]
    labels += [""] * (MAX_IMAGES - len(labels))
    return [gr.update(label=label) for label in labels]
242
+
243
def run_captioning(image_paths, batch_category):
    """Caption the uploaded images via the shared CLI pipeline.

    Returns one textbox update per caption slot plus a status-text update.
    On failure the textboxes are filled with the error message instead.
    """
    if not image_paths:
        return [gr.update(value="") for _ in range(MAX_IMAGES)] + [gr.update(value="No images to process")]

    try:
        print(f"Starting captioning for {len(image_paths)} images")
        captions = process_uploaded_images(image_paths, batch_category)
        print(f"Generated {len(captions)} captions")
        print("Sample captions:", captions[:2])  # Debug print first two captions

        gr.Info("Captioning complete!")
        status = gr.update(value="✅ Captioning complete")
    except Exception as e:
        print(f"Error in captioning: {str(e)}")
        gr.Error(f"Captioning failed: {str(e)}")
        # Surface the error in every caption box so the user sees it inline.
        captions = [f"Error: {str(e)}" for _ in image_paths]
        status = gr.update(value=f"❌ Error: {str(e)}")

    # One textbox update per slot: captions first, blanks for unused slots.
    caption_updates = [
        gr.update(value=captions[slot]) if slot < len(captions) else gr.update(value="")
        for slot in range(MAX_IMAGES)
    ]

    print(f"Returning {len(caption_updates)} caption updates")
    return caption_updates + [status]
272
+
273
def update_batch_setting(value):
    """Pass the checkbox value straight through to the batch_by_category state."""
    return value
276
+
277
def create_zip_from_ui(image_paths, *captions_list):
    """Zip the current images with their captions and return the zip path.

    Bug fix: images with empty captions are now dropped *together with* their
    caption. Previously captions were filtered first and the image list was
    truncated to the same length, so one empty caption in the middle shifted
    every later caption onto the wrong image.

    Args:
        image_paths: Stored upload paths (UI state).
        *captions_list: Current textbox values, one per caption slot.

    Returns:
        Path of a timestamped .zip in the system temp directory, or None
        when there is nothing to download.
    """
    # Pair each image with its own caption, skipping empty captions and any
    # trailing caption slots beyond the number of uploaded images.
    pairs = [(path, caption) for path, caption in zip(image_paths, captions_list) if caption]

    if not pairs:
        gr.Warning("No images to download")
        return None

    valid_image_paths = [path for path, _ in pairs]
    valid_captions = [caption for _, caption in pairs]

    # Create zip file
    zip_data = create_download_file(valid_image_paths, valid_captions)
    timestamp = time.strftime("%Y%m%d_%H%M%S")

    # Write the archive to a uniquely named file in the system temp directory.
    temp_dir = tempfile.gettempdir()
    zip_filename = f"image_captions_{timestamp}.zip"
    zip_path = os.path.join(temp_dir, zip_filename)

    with open(zip_path, "wb") as f:
        f.write(zip_data)

    # Return the path to the temporary file
    return zip_path
302
+
303
+ # Update the upload_outputs
304
+ upload_outputs = [
305
+ stored_image_paths,
306
+ captioning_area,
307
+ autocaption_btn,
308
+ download_btn,
309
+ download_output,
310
+ status_text,
311
+ *image_rows
312
+ ]
313
+
314
+ # Update both paths and images in a single flow
315
def process_upload(files):
    """Single upload handler: combine path/visibility, image, and label updates.

    Output order matches combined_outputs: stored paths, area/button/output/
    status updates, row visibilities, image values, then caption labels.
    """
    # First element is the stored paths; the rest are the UI updates
    # (captioning_area, autocaption_btn, download_btn, download_output,
    # status_text, *row visibilities) in load_captioning's return order.
    image_paths, *ui_updates = load_captioning(files)

    return (
        [image_paths]
        + list(ui_updates)
        + update_images(image_paths)
        + update_caption_labels(image_paths)
    )
327
+
328
+ # Combined outputs for both functions
329
+ combined_outputs = upload_outputs + image_components + caption_components
330
+
331
+ image_upload.change(
332
+ process_upload,
333
+ inputs=[image_upload],
334
+ outputs=combined_outputs
335
+ )
336
+
337
+ # Set up batch category checkbox
338
+ batch_category_checkbox.change(
339
+ update_batch_setting,
340
+ inputs=[batch_category_checkbox],
341
+ outputs=[batch_by_category]
342
+ )
343
+
344
+ # Manage the captioning status
345
def on_captioning_start():
    """Show a busy status and disable the caption button while processing."""
    busy_status = gr.update(value="⏳ Processing captions... please wait")
    return busy_status, gr.update(interactive=False)
347
+
348
def on_captioning_complete():
    """Restore the status text and re-enable the caption button."""
    done_status = gr.update(value="✅ Captioning complete")
    return done_status, gr.update(interactive=True)
350
+
351
+ # Set up captioning button
352
+ autocaption_btn.click(
353
+ on_captioning_start,
354
+ inputs=None,
355
+ outputs=[status_text, autocaption_btn]
356
+ ).success(
357
+ run_captioning,
358
+ inputs=[stored_image_paths, batch_by_category],
359
+ outputs=caption_components + [status_text]
360
+ ).success(
361
+ on_captioning_complete,
362
+ inputs=None,
363
+ outputs=[status_text, autocaption_btn]
364
+ )
365
+
366
+ # Set up download button
367
+ download_btn.click(
368
+ create_zip_from_ui,
369
+ inputs=[stored_image_paths] + caption_components,
370
+ outputs=[download_output]
371
+ ).then(
372
+ lambda: gr.update(visible=True),
373
+ inputs=None,
374
+ outputs=[download_output]
375
+ )
376
+
377
+ if __name__ == "__main__":
378
+ demo.launch(share=True)
main.py CHANGED
@@ -43,32 +43,16 @@ def validate_input_directory(input_dir):
43
  sys.exit(1)
44
 
45
  if text_files:
46
- print("Error: Text files detected in the input directory.")
47
- print("The input directory should only contain image files to prevent overwriting existing text files.")
48
- print("The following text files were found:")
49
  for file in text_files:
50
  print(f" - {file}")
51
- sys.exit(1)
52
 
53
- def process_images(input_dir, output_dir, fix_outfit=False, batch_images=False):
54
- """Process all images in the input directory and generate captions."""
55
- input_path = Path(input_dir)
56
- output_path = Path(output_dir) if output_dir else input_path
57
-
58
- # Validate the input directory first
59
- validate_input_directory(input_dir)
60
-
61
- # Create output directory if it doesn't exist
62
- os.makedirs(output_path, exist_ok=True)
63
-
64
- # Track the number of processed images
65
- processed_count = 0
66
-
67
- # Collect all images into a dictionary grouped by category
68
  images_by_category = {}
69
  image_paths_by_category = {}
70
 
71
- # Get all files in the input directory
72
  for file_path in input_path.iterdir():
73
  if file_path.is_file() and is_image_file(file_path.name):
74
  try:
@@ -87,6 +71,53 @@ def process_images(input_dir, output_dir, fix_outfit=False, batch_images=False):
87
  image_paths_by_category[category].append(file_path)
88
  except Exception as e:
89
  print(f"Error loading {file_path.name}: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  # Log the number of images found
92
  total_images = sum(len(images) for images in images_by_category.values())
@@ -96,27 +127,11 @@ def process_images(input_dir, output_dir, fix_outfit=False, batch_images=False):
96
  print("No valid images found to process.")
97
  return
98
 
99
- # Process images by category if batch_images is True
100
  if batch_images:
101
- for category, images in images_by_category.items():
102
- image_paths = image_paths_by_category[category]
103
- try:
104
- # Generate captions for the entire category
105
- captions = caption_images(images)
106
- write_captions(image_paths, captions, input_path, output_path)
107
- processed_count += len(images)
108
- except Exception as e:
109
- print(f"Error generating captions for category '{category}': {e}")
110
  else:
111
- # Process all images at once if batch_images is False
112
- all_images = [img for imgs in images_by_category.values() for img in imgs]
113
- all_image_paths = [path for paths in image_paths_by_category.values() for path in paths]
114
- try:
115
- captions = caption_images(all_images)
116
- write_captions(all_image_paths, captions, input_path, output_path)
117
- processed_count += len(all_images)
118
- except Exception as e:
119
- print(f"Error generating captions: {e}")
120
 
121
  print(f"\nProcessing complete. {processed_count} images were captioned.")
122
 
 
43
  sys.exit(1)
44
 
45
  if text_files:
46
+ print("Warning: Text files detected in the input directory.")
47
+ print("The following text files will be overwritten:")
 
48
  for file in text_files:
49
  print(f" - {file}")
 
50
 
51
+ def collect_images_by_category(input_path):
52
+ """Collect all valid images and group them by category."""
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  images_by_category = {}
54
  image_paths_by_category = {}
55
 
 
56
  for file_path in input_path.iterdir():
57
  if file_path.is_file() and is_image_file(file_path.name):
58
  try:
 
71
  image_paths_by_category[category].append(file_path)
72
  except Exception as e:
73
  print(f"Error loading {file_path.name}: {e}")
74
+
75
+ return images_by_category, image_paths_by_category
76
+
77
def process_by_category(images_by_category, image_paths_by_category, input_path, output_path):
    """Caption each category as one batch and write the caption files.

    A failure in one category is reported and does not abort the others.
    Returns the number of images successfully captioned.
    """
    successes = 0

    for category in images_by_category:
        batch = images_by_category[category]
        batch_paths = image_paths_by_category[category]
        try:
            # One batched request per category keeps related images together.
            generated = caption_images(batch, category=category, batch_mode=True)
            write_captions(batch_paths, generated, input_path, output_path)
            successes += len(batch)
        except Exception as e:
            print(f"Error generating captions for category '{category}': {e}")

    return successes
92
+
93
def process_all_at_once(images_by_category, image_paths_by_category, input_path, output_path):
    """Flatten every category and caption all images in a single pass.

    Returns the number of images captioned (0 when the request fails).
    """
    flat_images = []
    flat_paths = []
    for imgs in images_by_category.values():
        flat_images.extend(imgs)
    for paths in image_paths_by_category.values():
        flat_paths.extend(paths)

    try:
        captions = caption_images(flat_images, batch_mode=False)
        write_captions(flat_paths, captions, input_path, output_path)
        return len(flat_images)
    except Exception as e:
        print(f"Error generating captions: {e}")
        return 0
107
+
108
+ def process_images(input_dir, output_dir, fix_outfit=False, batch_images=False):
109
+ """Process all images in the input directory and generate captions."""
110
+ input_path = Path(input_dir)
111
+ output_path = Path(output_dir) if output_dir else input_path
112
+
113
+ # Validate the input directory first
114
+ validate_input_directory(input_dir)
115
+
116
+ # Create output directory if it doesn't exist
117
+ os.makedirs(output_path, exist_ok=True)
118
+
119
+ # Collect images by category
120
+ images_by_category, image_paths_by_category = collect_images_by_category(input_path)
121
 
122
  # Log the number of images found
123
  total_images = sum(len(images) for images in images_by_category.values())
 
127
  print("No valid images found to process.")
128
  return
129
 
130
+ # Process images based on batch setting
131
  if batch_images:
132
+ processed_count = process_by_category(images_by_category, image_paths_by_category, input_path, output_path)
 
 
 
 
 
 
 
 
133
  else:
134
+ processed_count = process_all_at_once(images_by_category, image_paths_by_category, input_path, output_path)
 
 
 
 
 
 
 
 
135
 
136
  print(f"\nProcessing complete. {processed_count} images were captioned.")
137
 
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio==4.44.1
2
+ Pillow==10.0.0
3
+ pydantic>=2.0.0
4
+ together
5
+ fastapi>=0.100.0