Rishi Desai commited on
Commit
49415e1
·
1 Parent(s): 5d95766

reworking the demo

Browse files
Files changed (2) hide show
  1. caption.py +2 -1
  2. demo.py +440 -238
caption.py CHANGED
@@ -3,6 +3,7 @@ import io
3
  import os
4
  from together import Together
5
 
 
6
  TRIGGER_WORD = "tr1gger"
7
 
8
  def get_system_prompt():
@@ -140,7 +141,7 @@ def caption_image_batch(client, image_strings, category):
140
  {"role": "user", "content": content}
141
  ]
142
  response = client.chat.completions.create(
143
- model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
144
  messages=messages
145
  )
146
  return process_batch_response(response, image_strings)
 
3
  import os
4
  from together import Together
5
 
6
+ MODEL = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"
7
  TRIGGER_WORD = "tr1gger"
8
 
9
  def get_system_prompt():
 
141
  {"role": "user", "content": content}
142
  ]
143
  response = client.chat.completions.create(
144
+ model=MODEL,
145
  messages=messages
146
  )
147
  return process_batch_response(response, image_strings)
demo.py CHANGED
@@ -4,12 +4,16 @@ import zipfile
4
  from io import BytesIO
5
  import time
6
  import tempfile
7
- from main import collect_images_by_category
8
  from pathlib import Path
9
  from caption import caption_images
 
 
10
  # Maximum number of images
11
  MAX_IMAGES = 30
12
 
 
 
13
  def create_download_file(image_paths, captions):
14
  """Create a zip file with images and their captions"""
15
  zip_io = BytesIO()
@@ -29,129 +33,362 @@ def create_download_file(image_paths, captions):
29
 
30
  return zip_io.getvalue()
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  def process_uploaded_images(image_paths, batch_by_category=False):
33
  """Process uploaded images using the same code path as CLI"""
34
  try:
 
 
35
  print(f"Processing {len(image_paths)} images, batch_by_category={batch_by_category}")
36
- # Create a temporary directory to store the images
 
37
  with tempfile.TemporaryDirectory() as temp_dir:
38
- # Copy images to temp directory and maintain original order
39
- temp_image_paths = []
40
- original_to_temp = {} # Map original paths to temp paths
41
- for path in image_paths:
42
- filename = os.path.basename(path)
43
- temp_path = os.path.join(temp_dir, filename)
44
- with open(path, 'rb') as src, open(temp_path, 'wb') as dst:
45
- dst.write(src.read())
46
- temp_image_paths.append(temp_path)
47
- original_to_temp[path] = temp_path
48
-
49
- print(f"Created {len(temp_image_paths)} temporary files")
50
 
51
- # Convert temp_dir to Path object for collect_images_by_category
52
  temp_dir_path = Path(temp_dir)
53
 
54
- # Process images using the CLI code path
 
 
 
 
 
55
  images_by_category, image_paths_by_category = collect_images_by_category(temp_dir_path)
 
 
56
  print(f"Collected images into {len(images_by_category)} categories")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
- # Get all images and paths in the correct order
59
  all_images = []
60
  all_image_paths = []
61
- for path in image_paths: # Use original order
62
- temp_path = original_to_temp[path]
63
- found = False
64
- for category, paths in image_paths_by_category.items():
65
- if temp_path in [str(p) for p in paths]: # Convert Path objects to strings for comparison
66
- idx = [str(p) for p in paths].index(temp_path)
67
- all_images.append(images_by_category[category][idx])
68
- all_image_paths.append(path) # Use original path
69
- found = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  break
71
- if not found:
72
- print(f"Warning: Could not find image {path} in categorized data")
73
 
74
- print(f"Collected {len(all_images)} images in correct order")
75
 
76
  # Process based on batch setting
77
- if batch_by_category:
78
- # Process each category separately
79
- captions = [""] * len(image_paths) # Initialize with empty strings
80
- for category, images in images_by_category.items():
81
- category_paths = image_paths_by_category[category]
82
- print(f"Processing category '{category}' with {len(images)} images")
83
- # Use the same code path as CLI
84
- category_captions = caption_images(images, category=category, batch_mode=True)
85
- print(f"Generated {len(category_captions)} captions for category '{category}'")
86
- print("Category captions:", category_captions) # Debug print category captions
87
-
88
- # Map captions back to original paths
89
- for temp_path, caption in zip(category_paths, category_captions):
90
- temp_path_str = str(temp_path)
91
- for orig_path, orig_temp in original_to_temp.items():
92
- if orig_temp == temp_path_str:
93
- idx = image_paths.index(orig_path)
94
- captions[idx] = caption
95
- break
96
  else:
97
- print(f"Processing all {len(all_images)} images at once")
98
- all_captions = caption_images(all_images, batch_mode=False)
99
- print(f"Generated {len(all_captions)} captions")
100
- print("All captions:", all_captions) # Debug print all captions
101
- captions = [""] * len(image_paths)
102
- for path, caption in zip(all_image_paths, all_captions):
103
- idx = image_paths.index(path)
104
- captions[idx] = caption
105
 
106
  print(f"Returning {len(captions)} captions")
107
- print("Final captions:", captions) # Debug print final captions
108
  return captions
109
 
110
  except Exception as e:
111
  print(f"Error in processing: {e}")
112
  raise
113
 
114
- # Main Gradio interface
115
- with gr.Blocks() as demo:
116
- gr.Markdown("# Image Auto-captioner for LoRA Training")
117
-
118
- # Store uploaded images
119
- stored_image_paths = gr.State([])
120
- batch_by_category = gr.State(False) # State to track if batch by category is enabled
121
-
122
- # Create a two-column layout for the entire interface
123
- with gr.Row():
124
- # Left column for images/upload
125
- with gr.Column(scale=1, elem_id="left-column"):
126
- # Upload area
127
- gr.Markdown("### Upload your images", elem_id="upload-heading")
128
- gr.Markdown("Only .png, .jpg, .jpeg, and .webp files are supported", elem_id="file-types-info", elem_classes="file-types-info")
129
- image_upload = gr.File(
130
- file_count="multiple",
131
- label="Drop your files here",
132
- file_types=["image"],
133
- type="filepath",
134
- height=220,
135
- elem_classes="file-upload-container",
136
- )
137
-
138
- # Right column for configuration and captions
139
- with gr.Column(scale=1.5, elem_id="right-column"):
140
- # Configuration area
141
- gr.Markdown("### Configuration")
142
- batch_category_checkbox = gr.Checkbox(
143
- label="Batch by category",
144
- value=False,
145
- info="Caption similar images together"
146
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
- caption_btn = gr.Button("Caption Images", variant="primary", interactive=False)
149
- download_btn = gr.Button("Download Images + Captions", variant="secondary", interactive=False)
150
- download_output = gr.File(label="Download Zip", visible=False)
151
- status_text = gr.Markdown("Upload images to begin", visible=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
- # Add unified CSS for the layout
154
- gr.HTML("""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  <style>
156
  /* Unified styling for the two-column layout */
157
  #left-column, #right-column {
@@ -218,10 +455,68 @@ with gr.Blocks() as demo:
218
  .download-section {
219
  margin-top: 10px;
220
  }
 
 
 
 
 
 
 
 
 
 
 
221
  </style>
222
- """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
 
224
- # Create a container for the captioning area (initially hidden)
 
 
 
225
  with gr.Column(visible=False) as captioning_area:
226
  # Replace the single heading with a row containing two headings
227
  with gr.Row():
@@ -260,123 +555,15 @@ with gr.Blocks() as demo:
260
  )
261
  caption_components.append(caption)
262
 
263
- def load_captioning(files):
264
- """Process uploaded images and show them in the UI"""
265
- if not files:
266
- return [], gr.update(visible=False), gr.update(interactive=False), gr.update(interactive=False), gr.update(visible=False), gr.update(value="Upload images to begin"), *[gr.update(visible=False) for _ in range(MAX_IMAGES)]
267
-
268
- # Filter to only keep image files
269
- image_paths = [f for f in files if f.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp'))]
270
-
271
- if not image_paths or len(image_paths) < 1:
272
- gr.Warning(f"Please upload at least one image")
273
- return [], gr.update(visible=False), gr.update(interactive=False), gr.update(interactive=False), gr.update(visible=False), gr.update(value="No valid images found"), *[gr.update(visible=False) for _ in range(MAX_IMAGES)]
274
-
275
- if len(image_paths) > MAX_IMAGES:
276
- gr.Warning(f"Only the first {MAX_IMAGES} images will be processed")
277
- image_paths = image_paths[:MAX_IMAGES]
278
-
279
- # Update row visibility
280
- row_updates = []
281
- for i in range(MAX_IMAGES):
282
- if i < len(image_paths):
283
- row_updates.append(gr.update(visible=True))
284
- else:
285
- row_updates.append(gr.update(visible=False))
286
-
287
- return (
288
- image_paths, # stored_image_paths
289
- gr.update(visible=True), # captioning_area
290
- gr.update(interactive=True), # caption_btn
291
- gr.update(interactive=False), # download_btn - initially disabled until captioning is done
292
- gr.update(visible=False), # download_output
293
- gr.update(value=f"{len(image_paths)} images ready for captioning"), # status_text
294
- *row_updates # image_rows
295
- )
296
-
297
- def update_images(image_paths):
298
- """Update the image components with the uploaded images"""
299
- print(f"Updating images with paths: {image_paths}")
300
- updates = []
301
- for i in range(MAX_IMAGES):
302
- if i < len(image_paths):
303
- updates.append(gr.update(value=image_paths[i]))
304
- else:
305
- updates.append(gr.update(value=None))
306
- return updates
307
-
308
- def update_caption_labels(image_paths):
309
- """Update caption labels to include the image filename"""
310
- updates = []
311
- for i in range(MAX_IMAGES):
312
- if i < len(image_paths):
313
- filename = os.path.basename(image_paths[i])
314
- updates.append(gr.update(label=filename))
315
- else:
316
- updates.append(gr.update(label=""))
317
- return updates
318
-
319
- def run_captioning(image_paths, batch_category):
320
- """Generate captions for the images using the CLI code path"""
321
- if not image_paths:
322
- return [gr.update(value="") for _ in range(MAX_IMAGES)] + [gr.update(value="No images to process")]
323
-
324
- try:
325
- print(f"Starting captioning for {len(image_paths)} images")
326
- captions = process_uploaded_images(image_paths, batch_category)
327
- print(f"Generated {len(captions)} captions")
328
- print("Sample captions:", captions[:2]) # Debug print first two captions
329
-
330
- gr.Info("Captioning complete!")
331
- status = gr.update(value="✅ Captioning complete")
332
- except Exception as e:
333
- print(f"Error in captioning: {str(e)}")
334
- gr.Error(f"Captioning failed: {str(e)}")
335
- captions = [f"Error: {str(e)}" for _ in image_paths]
336
- status = gr.update(value=f"❌ Error: {str(e)}")
337
-
338
- # Update caption textboxes
339
- caption_updates = []
340
- for i in range(MAX_IMAGES):
341
- if i < len(captions):
342
- caption_updates.append(gr.update(value=captions[i]))
343
- else:
344
- caption_updates.append(gr.update(value=""))
345
-
346
- print(f"Returning {len(caption_updates)} caption updates")
347
- return caption_updates + [status]
348
-
349
- def update_batch_setting(value):
350
- """Update the batch by category setting"""
351
- return value
352
-
353
- def create_zip_from_ui(image_paths, *captions_list):
354
- """Create a zip file from the current images and captions in the UI"""
355
- # Filter out empty captions for non-existent images
356
- valid_captions = [cap for i, cap in enumerate(captions_list) if i < len(image_paths) and cap]
357
- valid_image_paths = image_paths[:len(valid_captions)]
358
-
359
- if not valid_image_paths:
360
- gr.Warning("No images to download")
361
- return None
362
-
363
- # Create zip file
364
- zip_data = create_download_file(valid_image_paths, valid_captions)
365
- timestamp = time.strftime("%Y%m%d_%H%M%S")
366
-
367
- # Create a temporary file to store the zip
368
- temp_dir = tempfile.gettempdir()
369
- zip_filename = f"image_captions_{timestamp}.zip"
370
- zip_path = os.path.join(temp_dir, zip_filename)
371
-
372
- # Write the zip data to the temporary file
373
- with open(zip_path, "wb") as f:
374
- f.write(zip_data)
375
-
376
- # Return the path to the temporary file
377
- return zip_path
378
-
379
- # Update the upload_outputs
380
  upload_outputs = [
381
  stored_image_paths,
382
  captioning_area,
@@ -387,25 +574,11 @@ with gr.Blocks() as demo:
387
  *image_rows
388
  ]
389
 
390
- # Update both paths and images in a single flow
391
- def process_upload(files):
392
- # First get paths and visibility updates
393
- image_paths, captioning_update, caption_btn_update, download_btn_update, download_output_update, status_update, *row_updates = load_captioning(files)
394
-
395
- # Then get image updates
396
- image_updates = update_images(image_paths)
397
-
398
- # Update caption labels with filenames
399
- caption_label_updates = update_caption_labels(image_paths)
400
-
401
- # Return all updates together
402
- return [image_paths, captioning_update, caption_btn_update, download_btn_update, download_output_update, status_update] + row_updates + image_updates + caption_label_updates
403
-
404
- # Combined outputs for both functions
405
  combined_outputs = upload_outputs + image_components + caption_components
406
 
 
407
  image_upload.change(
408
- process_upload,
409
  inputs=[image_upload],
410
  outputs=combined_outputs
411
  )
@@ -417,13 +590,6 @@ with gr.Blocks() as demo:
417
  outputs=[batch_by_category]
418
  )
419
 
420
- # Manage the captioning status
421
- def on_captioning_start():
422
- return gr.update(value="⏳ Processing captions... please wait"), gr.update(interactive=False)
423
-
424
- def on_captioning_complete():
425
- return gr.update(value="✅ Captioning complete"), gr.update(interactive=True), gr.update(interactive=True)
426
-
427
  # Set up captioning button
428
  caption_btn.click(
429
  on_captioning_start,
@@ -454,5 +620,41 @@ with gr.Blocks() as demo:
454
  outputs=None
455
  )
456
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
457
  if __name__ == "__main__":
 
458
  demo.launch(share=True)
 
4
  from io import BytesIO
5
  import time
6
  import tempfile
7
+ from main import collect_images_by_category, is_image_file
8
  from pathlib import Path
9
  from caption import caption_images
10
+ from PIL import Image
11
+
12
  # Maximum number of images
13
  MAX_IMAGES = 30
14
 
15
+ # ------- File Operations -------
16
+
17
  def create_download_file(image_paths, captions):
18
  """Create a zip file with images and their captions"""
19
  zip_io = BytesIO()
 
33
 
34
  return zip_io.getvalue()
35
 
36
+ def save_images_to_temp(image_paths, temp_dir):
37
+ """Copy images to temporary directory and return mapping"""
38
+ temp_image_paths = []
39
+ original_to_temp = {} # Map original paths to temp paths
40
+
41
+ for path in image_paths:
42
+ # Keep original filename to preserve categorization
43
+ filename = os.path.basename(path)
44
+ temp_path = os.path.join(temp_dir, filename)
45
+
46
+ # Ensure we're using consistent path formats
47
+ orig_path_str = str(path)
48
+ temp_path_str = str(temp_path)
49
+
50
+ with open(path, 'rb') as src, open(temp_path, 'wb') as dst:
51
+ dst.write(src.read())
52
+
53
+ temp_image_paths.append(temp_path_str)
54
+ original_to_temp[orig_path_str] = temp_path_str
55
+ print(f"Copied {orig_path_str} to {temp_path_str}")
56
+
57
+ print(f"Created {len(temp_image_paths)} temporary files")
58
+ return temp_image_paths, original_to_temp
59
+
60
+ def process_by_category(images_by_category, image_paths_by_category, image_paths, original_to_temp):
61
+ """Process images by category and map captions back to original images"""
62
+ captions = [""] * len(image_paths) # Initialize with empty strings
63
+
64
+ # Create a mapping from temp path to index in the original image_paths
65
+ temp_to_original_idx = {}
66
+ for i, orig_path in enumerate(image_paths):
67
+ if orig_path in original_to_temp:
68
+ temp_to_original_idx[original_to_temp[orig_path]] = i
69
+
70
+ print(f"Created mapping for {len(temp_to_original_idx)} images")
71
+
72
+ for category, images in images_by_category.items():
73
+ category_paths = image_paths_by_category[category]
74
+ print(f"Processing category '{category}' with {len(images)} images")
75
+
76
+ # Create mapping of image to its position in the category
77
+ category_image_map = {}
78
+ for i, (img, path) in enumerate(zip(images, category_paths)):
79
+ category_image_map[str(path)] = i
80
+
81
+ try:
82
+ # Use the same code path as CLI
83
+ category_captions = caption_images(images, category=category, batch_mode=True)
84
+ print(f"Generated {len(category_captions)} captions for category '{category}'")
85
+
86
+ # Map captions back to original paths using our direct mapping
87
+ for i, temp_path in enumerate(category_paths):
88
+ temp_path_str = str(temp_path)
89
+ if i < len(category_captions) and temp_path_str in temp_to_original_idx:
90
+ original_idx = temp_to_original_idx[temp_path_str]
91
+ captions[original_idx] = category_captions[i]
92
+ except Exception as e:
93
+ print(f"Error processing category '{category}': {e}")
94
+ # Fall back to individual processing for this category
95
+ try:
96
+ print(f"Falling back to individual processing for category '{category}'")
97
+ for i, img in enumerate(images):
98
+ if i >= len(category_paths):
99
+ continue
100
+ temp_path = category_paths[i]
101
+ temp_path_str = str(temp_path)
102
+
103
+ try:
104
+ single_captions = caption_images([img], batch_mode=False)
105
+ if single_captions and len(single_captions) > 0:
106
+ if temp_path_str in temp_to_original_idx:
107
+ original_idx = temp_to_original_idx[temp_path_str]
108
+ captions[original_idx] = single_captions[0]
109
+ except Exception as inner_e:
110
+ print(f"Error processing individual image {i} in '{category}': {inner_e}")
111
+ except Exception as fallback_e:
112
+ print(f"Error in fallback processing for '{category}': {fallback_e}")
113
+
114
+ return captions
115
+
116
+ def process_all_images(all_images, all_image_paths, image_paths):
117
+ """Process all images at once without categorization"""
118
+ print(f"Processing all {len(all_images)} images at once")
119
+
120
+ # Initialize empty captions list
121
+ captions = [""] * len(image_paths) # Initialize with empty strings for all original paths
122
+
123
+ # If there are no images, return empty captions
124
+ if not all_images:
125
+ print("No images to process, returning empty captions")
126
+ return captions
127
+
128
+ # Create a mapping from temp paths to original indexes for efficient lookup
129
+ path_to_idx = {str(path): i for i, path in enumerate(image_paths)}
130
+
131
+ try:
132
+ all_captions = caption_images(all_images, batch_mode=False)
133
+ print(f"Generated {len(all_captions)} captions")
134
+
135
+ # Map captions to the right images using the prepared image_paths
136
+ for i, (path, caption) in enumerate(zip(all_image_paths, all_captions)):
137
+ if i < len(all_captions) and path in path_to_idx:
138
+ idx = path_to_idx[path]
139
+ captions[idx] = caption
140
+ except Exception as e:
141
+ print(f"Error generating captions: {e}")
142
+
143
+ return captions
144
+
145
  def process_uploaded_images(image_paths, batch_by_category=False):
146
  """Process uploaded images using the same code path as CLI"""
147
  try:
148
+ # Convert all image paths to strings for consistency
149
+ image_paths = [str(path) for path in image_paths]
150
  print(f"Processing {len(image_paths)} images, batch_by_category={batch_by_category}")
151
+
152
+ # Create temporary directory with images
153
  with tempfile.TemporaryDirectory() as temp_dir:
154
+ # Save images to temp directory
155
+ temp_image_paths, original_to_temp = save_images_to_temp(image_paths, temp_dir)
 
 
 
 
 
 
 
 
 
 
156
 
157
+ # Use Path object for consistency with main.py
158
  temp_dir_path = Path(temp_dir)
159
 
160
+ # List files in temp directory for debugging
161
+ print(f"Files in temp directory {temp_dir}:")
162
+ for f in temp_dir_path.iterdir():
163
+ print(f" - {f} (is_file: {f.is_file()}, is_image: {is_image_file(f.name)})")
164
+
165
+ # Collect images by category using the function from main.py
166
  images_by_category, image_paths_by_category = collect_images_by_category(temp_dir_path)
167
+
168
+ # Print categories and counts for debugging
169
  print(f"Collected images into {len(images_by_category)} categories")
170
+ for category, images in images_by_category.items():
171
+ print(f" - Category '{category}': {len(images)} images")
172
+
173
+ # Check if we actually have images to process
174
+ total_images = sum(len(images) for images in images_by_category.values())
175
+ if total_images == 0:
176
+ print("No images were properly categorized. Adding all images directly.")
177
+ # Add all images directly without categorization
178
+ default_category = "default"
179
+ images_by_category[default_category] = []
180
+ image_paths_by_category[default_category] = []
181
+
182
+ for path in image_paths:
183
+ path_str = str(path)
184
+ try:
185
+ if path_str in original_to_temp:
186
+ temp_path = original_to_temp[path_str]
187
+ temp_path_obj = Path(temp_path)
188
+ img = Image.open(temp_path).convert("RGB")
189
+ images_by_category[default_category].append(img)
190
+ image_paths_by_category[default_category].append(temp_path_obj)
191
+ except Exception as e:
192
+ print(f"Error loading image {path}: {e}")
193
 
194
+ # Map back to original paths for consistent ordering
195
  all_images = []
196
  all_image_paths = []
197
+
198
+ # Create reverse mapping for lookup
199
+ temp_to_orig = {v: k for k, v in original_to_temp.items()}
200
+
201
+ # Go through each category and map back to original
202
+ for category in images_by_category:
203
+ for i, temp_path in enumerate(image_paths_by_category[category]):
204
+ temp_path_str = str(temp_path)
205
+ if temp_path_str in temp_to_orig:
206
+ orig_path = temp_to_orig[temp_path_str]
207
+ if i < len(images_by_category[category]):
208
+ all_images.append(images_by_category[category][i])
209
+ all_image_paths.append(orig_path)
210
+
211
+ # Ensure we maintain original order
212
+ ordered_images = []
213
+ ordered_paths = []
214
+
215
+ for orig_path in image_paths:
216
+ path_str = str(orig_path)
217
+ for i, path in enumerate(all_image_paths):
218
+ if path == path_str and i < len(all_images):
219
+ ordered_images.append(all_images[i])
220
+ ordered_paths.append(path)
221
  break
 
 
222
 
223
+ print(f"Collected {len(ordered_images)} images in correct order")
224
 
225
  # Process based on batch setting
226
+ if batch_by_category and len(images_by_category) > 0:
227
+ captions = process_by_category(images_by_category, image_paths_by_category, image_paths, original_to_temp)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  else:
229
+ # Use our own function for non-batch mode since it needs to map back to UI
230
+ captions = process_all_images(ordered_images, ordered_paths, image_paths)
 
 
 
 
 
 
231
 
232
  print(f"Returning {len(captions)} captions")
 
233
  return captions
234
 
235
  except Exception as e:
236
  print(f"Error in processing: {e}")
237
  raise
238
 
239
+ # ------- UI Interaction Functions -------
240
+
241
+ def load_captioning(files):
242
+ """Process uploaded images and show them in the UI"""
243
+ if not files:
244
+ return [], gr.update(visible=False), gr.update(interactive=False), gr.update(interactive=False), gr.update(visible=False), gr.update(value="Upload images to begin"), *[gr.update(visible=False) for _ in range(MAX_IMAGES)]
245
+
246
+ # Filter to only keep image files
247
+ image_paths = [f for f in files if f.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp'))]
248
+
249
+ if not image_paths or len(image_paths) < 1:
250
+ gr.Warning(f"Please upload at least one image")
251
+ return [], gr.update(visible=False), gr.update(interactive=False), gr.update(interactive=False), gr.update(visible=False), gr.update(value="No valid images found"), *[gr.update(visible=False) for _ in range(MAX_IMAGES)]
252
+
253
+ if len(image_paths) > MAX_IMAGES:
254
+ gr.Warning(f"Only the first {MAX_IMAGES} images will be processed")
255
+ image_paths = image_paths[:MAX_IMAGES]
256
+
257
+ # Update row visibility
258
+ row_updates = []
259
+ for i in range(MAX_IMAGES):
260
+ if i < len(image_paths):
261
+ row_updates.append(gr.update(visible=True))
262
+ else:
263
+ row_updates.append(gr.update(visible=False))
264
+
265
+ return (
266
+ image_paths, # stored_image_paths
267
+ gr.update(visible=True), # captioning_area
268
+ gr.update(interactive=True), # caption_btn
269
+ gr.update(interactive=False), # download_btn - initially disabled until captioning is done
270
+ gr.update(visible=False), # download_output
271
+ gr.update(value=f"{len(image_paths)} images ready for captioning"), # status_text
272
+ *row_updates # image_rows
273
+ )
274
+
275
+ def update_images(image_paths):
276
+ """Update the image components with the uploaded images"""
277
+ print(f"Updating images with paths: {image_paths}")
278
+ updates = []
279
+ for i in range(MAX_IMAGES):
280
+ if i < len(image_paths):
281
+ updates.append(gr.update(value=image_paths[i]))
282
+ else:
283
+ updates.append(gr.update(value=None))
284
+ return updates
285
+
286
+ def update_caption_labels(image_paths):
287
+ """Update caption labels to include the image filename"""
288
+ updates = []
289
+ for i in range(MAX_IMAGES):
290
+ if i < len(image_paths):
291
+ filename = os.path.basename(image_paths[i])
292
+ updates.append(gr.update(label=filename))
293
+ else:
294
+ updates.append(gr.update(label=""))
295
+ return updates
296
+
297
+ def run_captioning(image_paths, batch_category):
298
+ """Generate captions for the images using the CLI code path"""
299
+ if not image_paths:
300
+ return [gr.update(value="") for _ in range(MAX_IMAGES)] + [gr.update(value="No images to process")]
301
 
302
+ try:
303
+ print(f"Starting captioning for {len(image_paths)} images, batch_by_category={batch_category}")
304
+ captions = process_uploaded_images(image_paths, batch_category)
305
+
306
+ # Count valid captions
307
+ valid_captions = sum(1 for c in captions if c and c.strip())
308
+ print(f"Generated {valid_captions} valid captions out of {len(captions)} images")
309
+
310
+ if valid_captions < len(captions):
311
+ gr.Warning(f"{len(captions) - valid_captions} images could not be captioned properly")
312
+ status = gr.update(value=f"✅ Captioning complete - {valid_captions}/{len(captions)} successful")
313
+ else:
314
+ gr.Info("Captioning complete!")
315
+ status = gr.update(value="✅ Captioning complete")
316
+
317
+ print("Sample captions:", captions[:2] if len(captions) >= 2 else captions)
318
+ except Exception as e:
319
+ print(f"Error in captioning: {str(e)}")
320
+ gr.Error(f"Captioning failed: {str(e)}")
321
+ captions = [""] * len(image_paths) # Use empty strings
322
+ status = gr.update(value=f"❌ Error: {str(e)}")
323
+
324
+ # Update caption textboxes
325
+ caption_updates = []
326
+ for i in range(MAX_IMAGES):
327
+ if i < len(captions) and captions[i]: # Only set value if we have a valid caption
328
+ caption_updates.append(gr.update(value=captions[i]))
329
+ else:
330
+ caption_updates.append(gr.update(value=""))
331
+
332
+ print(f"Returning {len(caption_updates)} caption updates")
333
+ return caption_updates + [status]
334
+
335
+ def update_batch_setting(value):
336
+ """Update the batch by category setting"""
337
+ return value
338
+
339
+ def create_zip_from_ui(image_paths, *captions_list):
340
+ """Create a zip file from the current images and captions in the UI"""
341
+ # Filter out empty captions for non-existent images
342
+ valid_captions = [cap for i, cap in enumerate(captions_list) if i < len(image_paths) and cap]
343
+ valid_image_paths = image_paths[:len(valid_captions)]
344
+
345
+ if not valid_image_paths:
346
+ gr.Warning("No images to download")
347
+ return None
348
+
349
+ # Create zip file
350
+ zip_data = create_download_file(valid_image_paths, valid_captions)
351
+ timestamp = time.strftime("%Y%m%d_%H%M%S")
352
+
353
+ # Create a temporary file to store the zip
354
+ temp_dir = tempfile.gettempdir()
355
+ zip_filename = f"image_captions_{timestamp}.zip"
356
+ zip_path = os.path.join(temp_dir, zip_filename)
357
+
358
+ # Write the zip data to the temporary file
359
+ with open(zip_path, "wb") as f:
360
+ f.write(zip_data)
361
+
362
+ # Return the path to the temporary file
363
+ return zip_path
364
+
365
+ def process_upload(files, image_rows, image_components, caption_components):
366
+ """Process uploaded files and update UI components"""
367
+ # First get paths and visibility updates
368
+ image_paths, captioning_update, caption_btn_update, download_btn_update, download_output_update, status_update, *row_updates = load_captioning(files)
369
+
370
+ # Then get image updates
371
+ image_updates = update_images(image_paths)
372
 
373
+ # Update caption labels with filenames
374
+ caption_label_updates = update_caption_labels(image_paths)
375
+
376
+ # Return all updates together
377
+ return [image_paths, captioning_update, caption_btn_update, download_btn_update, download_output_update, status_update] + row_updates + image_updates + caption_label_updates
378
+
379
+ def on_captioning_start():
380
+ """Update UI when captioning starts"""
381
+ return gr.update(value="⏳ Processing captions... please wait"), gr.update(interactive=False)
382
+
383
+ def on_captioning_complete():
384
+ """Update UI when captioning completes"""
385
+ return gr.update(value="✅ Captioning complete"), gr.update(interactive=True), gr.update(interactive=True)
386
+
387
+ # ------- UI Style Definitions -------
388
+
389
+ def get_css_styles():
390
+ """Return CSS styles for the UI"""
391
+ return """
392
  <style>
393
  /* Unified styling for the two-column layout */
394
  #left-column, #right-column {
 
455
  .download-section {
456
  margin-top: 10px;
457
  }
458
+
459
+ /* Category info */
460
+ .category-info {
461
+ font-size: 0.9em;
462
+ color: #555;
463
+ background-color: #f8f9fa;
464
+ padding: 8px;
465
+ border-radius: 4px;
466
+ margin-bottom: 10px;
467
+ border-left: 3px solid #4CAF50;
468
+ }
469
  </style>
470
+ """
471
+
472
+ # ------- UI Component Creation -------
473
+
474
+ def create_upload_area():
475
+ """Create the upload area components"""
476
+ # Left column for images/upload
477
+ with gr.Column(scale=1, elem_id="left-column") as upload_column:
478
+ # Upload area
479
+ gr.Markdown("### Upload your images", elem_id="upload-heading")
480
+ gr.Markdown("Only .png, .jpg, .jpeg, and .webp files are supported", elem_id="file-types-info", elem_classes="file-types-info")
481
+ image_upload = gr.File(
482
+ file_count="multiple",
483
+ label="Drop your files here",
484
+ file_types=["image"],
485
+ type="filepath",
486
+ height=220,
487
+ elem_classes="file-upload-container",
488
+ )
489
+
490
+ return upload_column, image_upload
491
+
492
+ def create_config_area():
493
+ """Create the configuration area components"""
494
+ # Right column for configuration and captions
495
+ with gr.Column(scale=1.5, elem_id="right-column") as config_column:
496
+ # Configuration area
497
+ gr.Markdown("### Configuration")
498
+ batch_category_checkbox = gr.Checkbox(
499
+ label="Batch process by category",
500
+ value=False,
501
+ info="Caption similar images together"
502
+ )
503
+
504
+ gr.Markdown("""
505
+ **Note about categorization:**
506
+ - Images are grouped by the part of the filename before the last underscore
507
+ - For example: "character_pose_01.png" and "character_pose_02.png" share the category "character_pose"
508
+ - When using "Batch by category", similar images are captioned together for more consistent results
509
+ """, elem_classes=["category-info"])
510
+
511
+ caption_btn = gr.Button("Caption Images", variant="primary", interactive=False)
512
+ download_btn = gr.Button("Download Images + Captions", variant="secondary", interactive=False)
513
+ download_output = gr.File(label="Download Zip", visible=False)
514
+ status_text = gr.Markdown("Upload images to begin", visible=True)
515
 
516
+ return config_column, batch_category_checkbox, caption_btn, download_btn, download_output, status_text
517
+
518
+ def create_captioning_area():
519
+ """Create the captioning area components"""
520
  with gr.Column(visible=False) as captioning_area:
521
  # Replace the single heading with a row containing two headings
522
  with gr.Row():
 
555
  )
556
  caption_components.append(caption)
557
 
558
+ return captioning_area, image_rows, image_components, caption_components
559
+
560
+ def setup_event_handlers(
561
+ image_upload, stored_image_paths, captioning_area, caption_btn, download_btn,
562
+ download_output, status_text, image_rows, image_components, caption_components,
563
+ batch_category_checkbox, batch_by_category
564
+ ):
565
+ """Set up all event handlers for the UI"""
566
+ # Combined outputs for the upload function
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
567
  upload_outputs = [
568
  stored_image_paths,
569
  captioning_area,
 
574
  *image_rows
575
  ]
576
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
577
  combined_outputs = upload_outputs + image_components + caption_components
578
 
579
+ # Set up upload handler
580
  image_upload.change(
581
+ lambda files: process_upload(files, image_rows, image_components, caption_components),
582
  inputs=[image_upload],
583
  outputs=combined_outputs
584
  )
 
590
  outputs=[batch_by_category]
591
  )
592
 
 
 
 
 
 
 
 
593
  # Set up captioning button
594
  caption_btn.click(
595
  on_captioning_start,
 
620
  outputs=None
621
  )
622
 
623
+ # ------- Main Application -------
624
+
625
+ def build_ui():
626
+ """Build and return the Gradio interface"""
627
+ with gr.Blocks() as demo:
628
+ gr.Markdown("# Image Auto-captioner for LoRA Training")
629
+
630
+ # Store uploaded images
631
+ stored_image_paths = gr.State([])
632
+ batch_by_category = gr.State(False) # State to track if batch by category is enabled
633
+
634
+ # Create a two-column layout for the entire interface
635
+ with gr.Row():
636
+ # Create upload area in left column
637
+ _, image_upload = create_upload_area()
638
+
639
+ # Create config area in right column
640
+ _, batch_category_checkbox, caption_btn, download_btn, download_output, status_text = create_config_area()
641
+
642
+ # Add CSS styling
643
+ gr.HTML(get_css_styles())
644
+
645
+ # Create captioning area (initially hidden)
646
+ captioning_area, image_rows, image_components, caption_components = create_captioning_area()
647
+
648
+ # Set up event handlers
649
+ setup_event_handlers(
650
+ image_upload, stored_image_paths, captioning_area, caption_btn, download_btn,
651
+ download_output, status_text, image_rows, image_components, caption_components,
652
+ batch_category_checkbox, batch_by_category
653
+ )
654
+
655
+ return demo
656
+
657
+ # Launch the app when run directly
658
  if __name__ == "__main__":
659
+ demo = build_ui()
660
  demo.launch(share=True)