Rishi Desai committed on
Commit
5d95766
·
1 Parent(s): 5872681

improved ergonomics of demo

Browse files
Files changed (2) hide show
  1. caption.py +1 -1
  2. demo.py +138 -58
caption.py CHANGED
@@ -6,7 +6,7 @@ from together import Together
6
  TRIGGER_WORD = "tr1gger"
7
 
8
  def get_system_prompt():
9
- return """Automated Image Captioning (for LoRA Training)
10
 
11
  Role: You are an expert AI captioning system generating precise, structured descriptions for character images optimized for LoRA model training in Stable Diffusion and Flux.1-dev.
12
 
 
6
  TRIGGER_WORD = "tr1gger"
7
 
8
  def get_system_prompt():
9
+ return f"""Automated Image Captioning (for LoRA Training)
10
 
11
  Role: You are an expert AI captioning system generating precise, structured descriptions for character images optimized for LoRA model training in Stable Diffusion and Flux.1-dev.
12
 
demo.py CHANGED
@@ -2,16 +2,11 @@ import gradio as gr
2
  import os
3
  import zipfile
4
  from io import BytesIO
5
- import PIL.Image
6
  import time
7
  import tempfile
8
- from main import process_images, collect_images_by_category, write_captions # Import the CLI functions
9
- from dotenv import load_dotenv
10
  from pathlib import Path
11
-
12
- # Load environment variables
13
- load_dotenv()
14
-
15
  # Maximum number of images
16
  MAX_IMAGES = 30
17
 
@@ -86,7 +81,6 @@ def process_uploaded_images(image_paths, batch_by_category=False):
86
  category_paths = image_paths_by_category[category]
87
  print(f"Processing category '{category}' with {len(images)} images")
88
  # Use the same code path as CLI
89
- from caption import caption_images
90
  category_captions = caption_images(images, category=category, batch_mode=True)
91
  print(f"Generated {len(category_captions)} captions for category '{category}'")
92
  print("Category captions:", category_captions) # Debug print category captions
@@ -100,8 +94,6 @@ def process_uploaded_images(image_paths, batch_by_category=False):
100
  captions[idx] = caption
101
  break
102
  else:
103
- # Process all images at once
104
- from caption import caption_images
105
  print(f"Processing all {len(all_images)} images at once")
106
  all_captions = caption_images(all_images, batch_mode=False)
107
  print(f"Generated {len(all_captions)} captions")
@@ -121,38 +113,122 @@ def process_uploaded_images(image_paths, batch_by_category=False):
121
 
122
  # Main Gradio interface
123
  with gr.Blocks() as demo:
124
- gr.Markdown("# Image Autocaptioner")
125
 
126
  # Store uploaded images
127
  stored_image_paths = gr.State([])
128
- batch_by_category = gr.State(True) # State to track if batch by category is enabled
129
 
130
- # Upload component
131
  with gr.Row():
132
- with gr.Column(scale=2):
133
- gr.Markdown("### Upload your images")
 
 
 
134
  image_upload = gr.File(
135
  file_count="multiple",
136
  label="Drop your files here",
137
  file_types=["image"],
138
- type="filepath"
 
 
139
  )
140
 
141
- with gr.Column(scale=1):
142
- autocaption_btn = gr.Button("Autocaption Images", variant="primary", interactive=False)
143
- status_text = gr.Markdown("Upload images to begin", visible=True)
 
 
 
 
 
 
144
 
145
- # Advanced settings dropdown
146
- with gr.Accordion("Advanced", open=False):
147
- batch_category_checkbox = gr.Checkbox(
148
- label="Batch by category",
149
- value=True,
150
- info="Group similar images together when processing"
151
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
  # Create a container for the captioning area (initially hidden)
154
  with gr.Column(visible=False) as captioning_area:
155
- gr.Markdown("### Your images and captions")
 
 
 
 
 
156
 
157
  # Create individual image and caption rows
158
  image_rows = []
@@ -160,30 +236,30 @@ with gr.Blocks() as demo:
160
  caption_components = []
161
 
162
  for i in range(MAX_IMAGES):
163
- with gr.Row(visible=False) as img_row:
164
  image_rows.append(img_row)
165
 
166
- img = gr.Image(
167
- label=f"Image {i+1}",
168
- type="filepath",
169
- show_label=False,
170
- height=200,
171
- width=200,
172
- scale=1
173
- )
174
- image_components.append(img)
 
 
175
 
176
- caption = gr.Textbox(
177
- label=f"Caption {i+1}",
178
- lines=3,
179
- scale=2
180
- )
181
- caption_components.append(caption)
182
-
183
- # Add download button
184
- download_btn = gr.Button("Download Images with Captions", variant="secondary", interactive=False)
185
- download_output = gr.File(label="Download Zip", visible=False)
186
-
187
  def load_captioning(files):
188
  """Process uploaded images and show them in the UI"""
189
  if not files:
@@ -211,8 +287,8 @@ with gr.Blocks() as demo:
211
  return (
212
  image_paths, # stored_image_paths
213
  gr.update(visible=True), # captioning_area
214
- gr.update(interactive=True), # autocaption_btn
215
- gr.update(interactive=True), # download_btn
216
  gr.update(visible=False), # download_output
217
  gr.update(value=f"{len(image_paths)} images ready for captioning"), # status_text
218
  *row_updates # image_rows
@@ -304,7 +380,7 @@ with gr.Blocks() as demo:
304
  upload_outputs = [
305
  stored_image_paths,
306
  captioning_area,
307
- autocaption_btn,
308
  download_btn,
309
  download_output,
310
  status_text,
@@ -314,7 +390,7 @@ with gr.Blocks() as demo:
314
  # Update both paths and images in a single flow
315
  def process_upload(files):
316
  # First get paths and visibility updates
317
- image_paths, captioning_update, autocaption_update, download_btn_update, download_output_update, status_update, *row_updates = load_captioning(files)
318
 
319
  # Then get image updates
320
  image_updates = update_images(image_paths)
@@ -323,7 +399,7 @@ with gr.Blocks() as demo:
323
  caption_label_updates = update_caption_labels(image_paths)
324
 
325
  # Return all updates together
326
- return [image_paths, captioning_update, autocaption_update, download_btn_update, download_output_update, status_update] + row_updates + image_updates + caption_label_updates
327
 
328
  # Combined outputs for both functions
329
  combined_outputs = upload_outputs + image_components + caption_components
@@ -346,13 +422,13 @@ with gr.Blocks() as demo:
346
  return gr.update(value="⏳ Processing captions... please wait"), gr.update(interactive=False)
347
 
348
  def on_captioning_complete():
349
- return gr.update(value="✅ Captioning complete"), gr.update(interactive=True)
350
 
351
  # Set up captioning button
352
- autocaption_btn.click(
353
  on_captioning_start,
354
  inputs=None,
355
- outputs=[status_text, autocaption_btn]
356
  ).success(
357
  run_captioning,
358
  inputs=[stored_image_paths, batch_by_category],
@@ -360,7 +436,7 @@ with gr.Blocks() as demo:
360
  ).success(
361
  on_captioning_complete,
362
  inputs=None,
363
- outputs=[status_text, autocaption_btn]
364
  )
365
 
366
  # Set up download button
@@ -369,9 +445,13 @@ with gr.Blocks() as demo:
369
  inputs=[stored_image_paths] + caption_components,
370
  outputs=[download_output]
371
  ).then(
372
- lambda: gr.update(visible=True),
373
  inputs=None,
374
  outputs=[download_output]
 
 
 
 
375
  )
376
 
377
  if __name__ == "__main__":
 
2
  import os
3
  import zipfile
4
  from io import BytesIO
 
5
  import time
6
  import tempfile
7
+ from main import collect_images_by_category
 
8
  from pathlib import Path
9
+ from caption import caption_images
 
 
 
10
  # Maximum number of images
11
  MAX_IMAGES = 30
12
 
 
81
  category_paths = image_paths_by_category[category]
82
  print(f"Processing category '{category}' with {len(images)} images")
83
  # Use the same code path as CLI
 
84
  category_captions = caption_images(images, category=category, batch_mode=True)
85
  print(f"Generated {len(category_captions)} captions for category '{category}'")
86
  print("Category captions:", category_captions) # Debug print category captions
 
94
  captions[idx] = caption
95
  break
96
  else:
 
 
97
  print(f"Processing all {len(all_images)} images at once")
98
  all_captions = caption_images(all_images, batch_mode=False)
99
  print(f"Generated {len(all_captions)} captions")
 
113
 
114
  # Main Gradio interface
115
  with gr.Blocks() as demo:
116
+ gr.Markdown("# Image Auto-captioner for LoRA Training")
117
 
118
  # Store uploaded images
119
  stored_image_paths = gr.State([])
120
+ batch_by_category = gr.State(False) # State to track if batch by category is enabled
121
 
122
+ # Create a two-column layout for the entire interface
123
  with gr.Row():
124
+ # Left column for images/upload
125
+ with gr.Column(scale=1, elem_id="left-column"):
126
+ # Upload area
127
+ gr.Markdown("### Upload your images", elem_id="upload-heading")
128
+ gr.Markdown("Only .png, .jpg, .jpeg, and .webp files are supported", elem_id="file-types-info", elem_classes="file-types-info")
129
  image_upload = gr.File(
130
  file_count="multiple",
131
  label="Drop your files here",
132
  file_types=["image"],
133
+ type="filepath",
134
+ height=220,
135
+ elem_classes="file-upload-container",
136
  )
137
 
138
+ # Right column for configuration and captions
139
+ with gr.Column(scale=1.5, elem_id="right-column"):
140
+ # Configuration area
141
+ gr.Markdown("### Configuration")
142
+ batch_category_checkbox = gr.Checkbox(
143
+ label="Batch by category",
144
+ value=False,
145
+ info="Caption similar images together"
146
+ )
147
 
148
+ caption_btn = gr.Button("Caption Images", variant="primary", interactive=False)
149
+ download_btn = gr.Button("Download Images + Captions", variant="secondary", interactive=False)
150
+ download_output = gr.File(label="Download Zip", visible=False)
151
+ status_text = gr.Markdown("Upload images to begin", visible=True)
152
+
153
+ # Add unified CSS for the layout
154
+ gr.HTML("""
155
+ <style>
156
+ /* Unified styling for the two-column layout */
157
+ #left-column, #right-column {
158
+ padding: 10px;
159
+ align-self: flex-start;
160
+ }
161
+
162
+ /* Force columns to align at the top */
163
+ .gradio-row {
164
+ align-items: flex-start !important;
165
+ }
166
+
167
+ /* File upload styling */
168
+ .file-types-info {
169
+ margin-top: -10px;
170
+ font-size: 0.9em;
171
+ color: #666;
172
+ }
173
+
174
+ .file-upload-container {
175
+ width: 100%;
176
+ max-width: 100%;
177
+ }
178
+
179
+ .file-upload-container .file-preview {
180
+ max-height: 180px;
181
+ overflow-y: auto;
182
+ }
183
+
184
+ /* Image and caption rows styling */
185
+ .image-caption-row {
186
+ margin-bottom: 10px;
187
+ padding: 5px;
188
+ border-bottom: 1px solid #eee;
189
+ }
190
+
191
+ /* Make thumbnails same size */
192
+ .image-thumbnail {
193
+ height: 200px;
194
+ width: 200px;
195
+ object-fit: cover;
196
+ }
197
+
198
+ /* Center the image thumbnails */
199
+ #left-column, .image-caption-row > div:first-child {
200
+ display: flex;
201
+ justify-content: center;
202
+ align-items: center;
203
+ }
204
+
205
+ /* Ensure the image container itself is centered */
206
+ .image-thumbnail img, .image-thumbnail > div {
207
+ margin: 0 auto;
208
+ }
209
+
210
+ /* Caption text areas */
211
+ .caption-area {
212
+ height: 100%;
213
+ display: flex;
214
+ flex-direction: column;
215
+ }
216
+
217
+ /* Download section */
218
+ .download-section {
219
+ margin-top: 10px;
220
+ }
221
+ </style>
222
+ """)
223
 
224
  # Create a container for the captioning area (initially hidden)
225
  with gr.Column(visible=False) as captioning_area:
226
+ # Replace the single heading with a row containing two headings
227
+ with gr.Row():
228
+ with gr.Column(scale=1):
229
+ gr.Markdown("### Your Images", elem_id="images-heading")
230
+ with gr.Column(scale=1.5):
231
+ gr.Markdown("### Your Captions", elem_id="captions-heading")
232
 
233
  # Create individual image and caption rows
234
  image_rows = []
 
236
  caption_components = []
237
 
238
  for i in range(MAX_IMAGES):
239
+ with gr.Row(visible=False, elem_classes=["image-caption-row"]) as img_row:
240
  image_rows.append(img_row)
241
 
242
+ # Left column for image
243
+ with gr.Column(scale=1):
244
+ img = gr.Image(
245
+ label=f"Image {i+1}",
246
+ type="filepath",
247
+ show_label=False,
248
+ height=200,
249
+ width=200,
250
+ elem_classes=["image-thumbnail"]
251
+ )
252
+ image_components.append(img)
253
 
254
+ # Right column for caption
255
+ with gr.Column(scale=1.5):
256
+ caption = gr.Textbox(
257
+ label=f"Caption {i+1}",
258
+ lines=3,
259
+ elem_classes=["caption-area"]
260
+ )
261
+ caption_components.append(caption)
262
+
 
 
263
  def load_captioning(files):
264
  """Process uploaded images and show them in the UI"""
265
  if not files:
 
287
  return (
288
  image_paths, # stored_image_paths
289
  gr.update(visible=True), # captioning_area
290
+ gr.update(interactive=True), # caption_btn
291
+ gr.update(interactive=False), # download_btn - initially disabled until captioning is done
292
  gr.update(visible=False), # download_output
293
  gr.update(value=f"{len(image_paths)} images ready for captioning"), # status_text
294
  *row_updates # image_rows
 
380
  upload_outputs = [
381
  stored_image_paths,
382
  captioning_area,
383
+ caption_btn,
384
  download_btn,
385
  download_output,
386
  status_text,
 
390
  # Update both paths and images in a single flow
391
  def process_upload(files):
392
  # First get paths and visibility updates
393
+ image_paths, captioning_update, caption_btn_update, download_btn_update, download_output_update, status_update, *row_updates = load_captioning(files)
394
 
395
  # Then get image updates
396
  image_updates = update_images(image_paths)
 
399
  caption_label_updates = update_caption_labels(image_paths)
400
 
401
  # Return all updates together
402
+ return [image_paths, captioning_update, caption_btn_update, download_btn_update, download_output_update, status_update] + row_updates + image_updates + caption_label_updates
403
 
404
  # Combined outputs for both functions
405
  combined_outputs = upload_outputs + image_components + caption_components
 
422
  return gr.update(value="⏳ Processing captions... please wait"), gr.update(interactive=False)
423
 
424
  def on_captioning_complete():
425
+ return gr.update(value="✅ Captioning complete"), gr.update(interactive=True), gr.update(interactive=True)
426
 
427
  # Set up captioning button
428
+ caption_btn.click(
429
  on_captioning_start,
430
  inputs=None,
431
+ outputs=[status_text, caption_btn]
432
  ).success(
433
  run_captioning,
434
  inputs=[stored_image_paths, batch_by_category],
 
436
  ).success(
437
  on_captioning_complete,
438
  inputs=None,
439
+ outputs=[status_text, caption_btn, download_btn]
440
  )
441
 
442
  # Set up download button
 
445
  inputs=[stored_image_paths] + caption_components,
446
  outputs=[download_output]
447
  ).then(
448
+ lambda: gr.update(visible=True, elem_classes=["download-section"]),
449
  inputs=None,
450
  outputs=[download_output]
451
+ ).then(
452
+ lambda: gr.Info("Click the Download button that appeared to save your zip file"),
453
+ inputs=None,
454
+ outputs=None
455
  )
456
 
457
  if __name__ == "__main__":