Spaces:
Running
Running
File size: 25,339 Bytes
dc6215b e3a461e 69f7712 49415e1 dc6215b 49415e1 dc6215b e3a461e 49415e1 e3a461e dc6215b e3a461e dc6215b e3a461e 89d29bd e3a461e dc6215b e3a461e dc6215b 49415e1 e3a461e 49415e1 dc6215b 49415e1 5d95766 49415e1 5d95766 7b3611f 5d95766 cd217ad 5d95766 49415e1 69f7712 5d95766 49415e1 a28e4db 49415e1 dc6215b 49415e1 dc6215b 5d95766 dc6215b 5d95766 dc6215b 5d95766 dc6215b 5d95766 49415e1 69f7712 49415e1 dc6215b 5d95766 dc6215b 49415e1 dc6215b 49415e1 dc6215b 69f7712 dc6215b 5d95766 dc6215b 5d95766 dc6215b 69f7712 dc6215b 5d95766 dc6215b 5d95766 dc6215b 0f8d917 0aa8547 0f8d917 0aa8547 0f8d917 49415e1 713c829 49415e1 69f7712 49415e1 69f7712 49415e1 69f7712 0f8d917 69f7712 0f8d917 69f7712 49415e1 dc6215b 49415e1 dc6215b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 |
import gradio as gr
import os
import zipfile
from io import BytesIO
import time
import tempfile
from pathlib import Path
import shutil
from main import process_images
from prompt import optimize_prompt
# Upper bound on how many uploaded images get a UI row and are sent for captioning.
MAX_IMAGES: int = 30
# ------- File Operations -------
def create_download_file(image_paths, captions):
    """Bundle images and their captions into an in-memory zip archive.

    Each image is stored under its original filename, and its caption under the
    same base name with a ``.txt`` extension (``cat_1.jpg`` -> ``cat_1.jpg`` +
    ``cat_1.txt``).

    Args:
        image_paths: Paths of the image files to include.
        captions: Caption strings aligned index-by-index with ``image_paths``.
            Surplus entries on either side are ignored (``zip`` truncates).

    Returns:
        The raw bytes of the zip archive.
    """
    buffer = BytesIO()
    with zipfile.ZipFile(buffer, 'w') as archive:
        for image_path, caption in zip(image_paths, captions):
            file_name = os.path.basename(image_path)
            base_name = os.path.splitext(file_name)[0]
            # Keep the real extension: the image bytes are copied verbatim, so
            # forcing ".png" would mislabel JPEG/WebP/etc. files.
            archive.write(image_path, arcname=file_name)
            archive.writestr(f"{base_name}.txt", caption)
    return buffer.getvalue()
def process_uploaded_images(image_paths, batch_by_category=False):
    """Caption images by running ``main.process_images`` on temporary copies.

    The images are copied (keeping their filenames) into a scratch input
    directory, ``process_images`` writes ``<base_name>.txt`` caption files into
    a scratch output directory, and those captions are read back.

    Args:
        image_paths: Paths of the images to caption.
        batch_by_category: Forwarded to ``process_images`` as ``batch_images``.

    Returns:
        A list of caption strings aligned with ``image_paths``; an entry is
        ``""`` when no caption file was produced for that image.
    """
    with tempfile.TemporaryDirectory() as temp_input_dir, tempfile.TemporaryDirectory() as temp_output_dir:
        input_dir = Path(temp_input_dir)
        output_dir = Path(temp_output_dir)

        # Keep original filenames: the category grouping is derived from them.
        # NOTE: two uploads with the same basename overwrite each other here.
        for path in image_paths:
            shutil.copy2(path, input_dir / os.path.basename(path))

        process_images(temp_input_dir, temp_output_dir, batch_images=batch_by_category)

        captions = []
        for orig_path in image_paths:
            base_name = os.path.splitext(os.path.basename(orig_path))[0]
            caption_path = output_dir / f"{base_name}.txt"
            if caption_path.is_file():
                captions.append(caption_path.read_text(encoding='utf-8'))
            else:
                captions.append("")
        return captions
# ------- UI Interaction Functions -------
def load_captioning(files):
    """Validate the uploaded files and compute the initial captioning UI state.

    Args:
        files: Uploaded file paths from the ``gr.File`` component
            (``type="filepath"``), or ``None``/empty when the upload is cleared.

    Returns:
        A tuple of: the accepted image paths (list of str), then updates for
        captioning_area, caption_btn, download_btn, download_output and
        status_text, followed by ``MAX_IMAGES`` row-visibility updates.
    """
    # NOTE(review): the upload hint text lists only png/jpg/jpeg/webp, but this
    # filter also accepts gif/bmp — confirm which set process_images supports.
    image_extensions = ('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp')

    def _hidden_state(message):
        # Hide the captioning area and all rows, disable both buttons.
        return (
            [],
            gr.update(visible=False),
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(visible=False),
            gr.update(value=message),
            *[gr.update(visible=False) for _ in range(MAX_IMAGES)],
        )

    if not files:
        return _hidden_state("Upload images to begin")

    image_paths = [f for f in files if f.lower().endswith(image_extensions)]
    if not image_paths:
        gr.Warning("Please upload at least one image")
        return _hidden_state("No valid images found")

    if len(image_paths) > MAX_IMAGES:
        gr.Warning(f"Only the first {MAX_IMAGES} images will be processed")
        image_paths = image_paths[:MAX_IMAGES]

    row_updates = [gr.update(visible=i < len(image_paths)) for i in range(MAX_IMAGES)]

    return (
        image_paths,  # stored_image_paths
        gr.update(visible=True),  # captioning_area
        gr.update(interactive=True),  # caption_btn
        gr.update(interactive=False),  # download_btn - disabled until captioning is done
        gr.update(visible=False),  # download_output
        gr.update(value=f"{len(image_paths)} images ready for captioning"),  # status_text
        *row_updates,  # image_rows
    )
def update_images(image_paths):
    """Return one value update per image slot.

    Slots that have a corresponding path show that image; the remaining
    slots (up to ``MAX_IMAGES``) are cleared with ``None``.
    """
    print(f"Updating images with paths: {image_paths}")
    filled = len(image_paths)
    return [
        gr.update(value=image_paths[slot] if slot < filled else None)
        for slot in range(MAX_IMAGES)
    ]
def update_caption_labels(image_paths):
    """Return one label update per caption slot.

    Each used slot is labelled with its image's filename; unused slots (up to
    ``MAX_IMAGES``) get an empty label.
    """
    labels = [os.path.basename(p) for p in image_paths[:MAX_IMAGES]]
    labels += [""] * (MAX_IMAGES - len(labels))
    return [gr.update(label=text) for text in labels]
def run_captioning(image_paths, batch_category):
    """Generate captions for the stored images and build the UI updates.

    Args:
        image_paths: Paths of the images to caption (the stored upload list).
        batch_category: Forwarded to ``process_uploaded_images`` as
            ``batch_by_category``.

    Returns:
        A list of ``MAX_IMAGES`` caption-textbox value updates followed by one
        status-text update.
    """
    if not image_paths:
        return [gr.update(value="") for _ in range(MAX_IMAGES)] + [gr.update(value="No images to process")]

    try:
        print(f"Starting captioning for {len(image_paths)} images, batch_by_category={batch_category}")
        captions = process_uploaded_images(image_paths, batch_category)
    except Exception as e:
        # UI boundary: report the failure instead of crashing the handler.
        # gr.Error only has an effect when *raised*, and raising here would
        # skip the textbox/status updates below, so show a warning toast.
        print(f"Error in captioning: {e}")
        gr.Warning(f"Captioning failed: {e}")
        captions = [""] * len(image_paths)
        status = gr.update(value=f"❌ Error: {e}")
    else:
        valid_captions = sum(1 for c in captions if c and c.strip())
        print(f"Generated {valid_captions} valid captions out of {len(captions)} images")
        if valid_captions < len(captions):
            gr.Warning(f"{len(captions) - valid_captions} images could not be captioned properly")
            status = gr.update(value=f"✅ Captioning complete - {valid_captions}/{len(captions)} successful")
        else:
            gr.Info("Captioning complete!")
            status = gr.update(value="✅ Captioning complete")
        print("Sample captions:", captions[:2])

    # Only fill a textbox when we actually have a caption for that slot.
    caption_updates = [
        gr.update(value=captions[i] if i < len(captions) and captions[i] else "")
        for i in range(MAX_IMAGES)
    ]
    print(f"Returning {len(caption_updates)} caption updates")
    return caption_updates + [status]
def update_batch_setting(value):
    """Mirror the checkbox value into the batch-by-category state, unchanged."""
    enabled = value
    return enabled
def create_zip_from_ui(image_paths, *captions_list):
    """Package the displayed images and their current captions into a zip file.

    Args:
        image_paths: Paths of the uploaded images, in display order.
        *captions_list: Current value of every caption textbox, in display
            order (one per UI slot, so there may be more boxes than images).

    Returns:
        Path of the written zip file, or ``None`` (after a warning toast) when
        no image has a non-empty caption.
    """
    # Pair each image with the textbox at the same index *before* filtering.
    # Otherwise an empty caption in the middle shifts every later caption onto
    # the wrong image. zip() also drops the textboxes of unused slots.
    pairs = [
        (path, caption)
        for path, caption in zip(image_paths or [], captions_list)
        if caption
    ]
    if not pairs:
        gr.Warning("No images to download")
        return None

    valid_image_paths = [path for path, _ in pairs]
    valid_captions = [caption for _, caption in pairs]
    zip_data = create_download_file(valid_image_paths, valid_captions)

    timestamp = time.strftime("%Y%m%d_%H%M%S")
    # A fresh directory per request, so two downloads created within the same
    # second cannot overwrite each other's file in the shared temp dir.
    zip_dir = tempfile.mkdtemp(prefix="image_captions_")
    zip_path = os.path.join(zip_dir, f"image_captions_{timestamp}.zip")
    with open(zip_path, "wb") as f:
        f.write(zip_data)
    return zip_path
def process_upload(files, image_rows, image_components, caption_components):
    """Compute every UI update triggered by a new upload.

    ``image_rows``, ``image_components`` and ``caption_components`` are
    accepted for call-site compatibility but not used. The returned updates
    are positional and are matched to components by the outputs list built
    in ``setup_event_handlers``.

    Returns:
        A flat list: image paths, five widget updates (captioning area,
        caption button, download button, download output, status), then
        ``MAX_IMAGES`` row, image and caption-label updates each.
    """
    paths, *ui_updates = load_captioning(files)
    return [
        paths,
        *ui_updates,
        *update_images(paths),
        *update_caption_labels(paths),
    ]
def on_captioning_start():
    """Show a busy status and disable the caption button while captioning runs.

    Returns:
        (status_text update, caption_btn update).
    """
    return gr.update(value="⏳ Processing captions... please wait"), gr.update(interactive=False)
def on_captioning_complete():
    """Mark captioning as finished and re-enable the caption/download buttons.

    Returns:
        (status_text update, caption_btn update, download_btn update).
    """
    return gr.update(value="✅ Captioning complete"), gr.update(interactive=True), gr.update(interactive=True)
# ------- UI Style Definitions -------
def get_css_styles():
    """Return the app's stylesheet wrapped in a ``<style>`` tag for ``gr.HTML``."""
    stylesheet = """
<style>
/* Unified styling for the two-column layout */
#left-column, #right-column {
padding: 10px;
align-self: flex-start;
}
/* Force columns to align at the top */
.gradio-row {
align-items: flex-start !important;
}
/* File upload styling */
.file-types-info {
margin-top: 0px;
font-size: 0.9em;
color: #666;
}
.file-upload-container {
width: 100%;
max-width: 100%;
}
.file-upload-container .file-preview {
max-height: 180px;
overflow-y: auto;
}
/* Image and caption rows styling */
.image-caption-row {
margin-bottom: 10px;
padding: 5px;
border-bottom: 1px solid #eee;
}
/* Make thumbnails same size */
.image-thumbnail {
height: 100%;
width: 100%;
object-fit: contain;
}
/* Center the image thumbnails */
#left-column, .image-caption-row > div:first-child {
display: flex;
justify-content: center;
align-items: center;
}
/* Ensure the image container itself is centered */
.image-thumbnail img, .image-thumbnail > div {
margin: 0 auto;
}
/* Caption text areas */
.caption-area {
height: 100%;
display: flex;
flex-direction: column;
}
/* Download section */
.download-section {
margin-top: 10px;
}
/* Category info */
.category-info {
font-size: 0.9em;
color: #555;
background-color: #f8f9fa;
padding: 8px;
border-radius: 4px;
margin-bottom: 10px;
border-left: 3px solid #4CAF50;
}
/* Tab styling */
.tabs {
margin-top: 20px;
}
/* Prompt optimization tab styling */
.optimization-status {
margin-top: 10px;
padding: 8px;
border-radius: 4px;
background-color: #f8f9fa;
}
/* Input/output boxes for prompt optimization */
.prompt-box {
margin-bottom: 15px;
}
/* Make optimize button stand out */
.optimize-btn {
margin-top: 10px;
margin-bottom: 15px;
}
</style>
"""
    return stylesheet
# ------- UI Component Creation -------
def create_upload_area():
    """Build the left-hand upload column.

    Returns:
        (column, file_upload_component). The file component yields a list of
        file paths (``type="filepath"``, ``file_count="multiple"``).
    """
    with gr.Column(scale=1, elem_id="left-column") as column:
        gr.Markdown("### Upload your images", elem_id="upload-heading")
        # NOTE(review): load_captioning also accepts .gif/.bmp; keep this hint in sync.
        gr.Markdown(
            "Only .png, .jpg, .jpeg, and .webp files are supported",
            elem_id="file-types-info",
            elem_classes="file-types-info",
        )
        uploader = gr.File(
            label="Drop your files here",
            file_count="multiple",
            file_types=["image"],
            type="filepath",
            height=220,
            elem_classes="file-upload-container",
        )
    return column, uploader
def create_config_area():
    """Build the right-hand configuration column.

    Returns:
        (column, batch_category_checkbox, caption_btn, download_btn,
        download_output, status_text).
    """
    with gr.Column(scale=1.5, elem_id="right-column") as column:
        gr.Markdown("### Configuration")
        batch_checkbox = gr.Checkbox(
            value=False,
            label="Batch process by category",
            info="Caption similar images together",
        )
        gr.Markdown("""
**Note about categorization:**
- Images are grouped by the part of the filename before the last underscore
- For example: "character_pose_1.png" and "character_pose_2.png" share the category "character_pose"
- When using "Batch process by category", similar images are captioned together for more consistent results
""", elem_classes=["category-info"])
        # Both buttons start disabled; the upload and captioning handlers enable them.
        caption_button = gr.Button("Caption Images", variant="primary", interactive=False)
        download_button = gr.Button("Download Images + Captions", variant="secondary", interactive=False)
        zip_output = gr.File(label="Download Zip", visible=False)
        status = gr.Markdown("Upload images to begin", visible=True)
    return column, batch_checkbox, caption_button, download_button, zip_output, status
def create_captioning_area():
    """Build the initially hidden grid of image/caption rows.

    Returns:
        (captioning_area, image_rows, image_components, caption_components),
        where each list holds ``MAX_IMAGES`` entries in display order.
    """
    rows, images, captions = [], [], []
    with gr.Column(visible=False) as captioning_area:
        # Headings laid out with the same 1 : 1.5 column ratio as the rows below.
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Your Images", elem_id="images-heading")
            with gr.Column(scale=1.5):
                gr.Markdown("### Your Captions", elem_id="captions-heading")

        for slot in range(MAX_IMAGES):
            number = slot + 1
            with gr.Row(visible=False, elem_classes=["image-caption-row"]) as row:
                rows.append(row)
                with gr.Column(scale=1):
                    images.append(gr.Image(
                        label=f"Image {number}",
                        type="filepath",
                        show_label=False,
                        height=200,
                        width=200,
                        elem_classes=["image-thumbnail"],
                    ))
                with gr.Column(scale=1.5):
                    captions.append(gr.Textbox(
                        label=f"Caption {number}",
                        lines=3,
                        elem_classes=["caption-area"],
                    ))
    return captioning_area, rows, images, captions
def setup_event_handlers(
    image_upload, stored_image_paths, captioning_area, caption_btn, download_btn,
    download_output, status_text, image_rows, image_components, caption_components,
    batch_category_checkbox, batch_by_category, shared_captions=None
):
    """Wire all event handlers of the Image Captioning tab.

    Args:
        image_upload: The multi-file upload component.
        stored_image_paths: State holding the accepted image paths.
        captioning_area, image_rows, image_components, caption_components:
            Components from ``create_captioning_area``.
        caption_btn, download_btn, download_output, status_text,
        batch_category_checkbox: Components from ``create_config_area``.
        batch_by_category: State mirroring the checkbox.
        shared_captions: Optional state that receives the non-empty captions
            after a successful captioning run (used by the other tab).
    """
    # Order must match the list returned by process_upload.
    upload_outputs = [
        stored_image_paths,
        captioning_area,
        caption_btn,
        download_btn,
        download_output,
        status_text,
        *image_rows,
    ]
    combined_outputs = upload_outputs + image_components + caption_components

    image_upload.change(
        lambda files: process_upload(files, image_rows, image_components, caption_components),
        inputs=[image_upload],
        outputs=combined_outputs,
    )

    batch_category_checkbox.change(
        update_batch_setting,
        inputs=[batch_category_checkbox],
        outputs=[batch_by_category],
    )

    def _reenable_buttons():
        # Only touch the buttons: run_captioning has already written the final
        # status (including partial-success and error messages), and updating
        # status_text here would overwrite it.
        return gr.update(interactive=True), gr.update(interactive=True)

    caption_chain = caption_btn.click(
        on_captioning_start,
        inputs=None,
        outputs=[status_text, caption_btn],
    ).success(
        run_captioning,
        inputs=[stored_image_paths, batch_by_category],
        outputs=caption_components + [status_text],
    ).success(
        _reenable_buttons,
        inputs=None,
        outputs=[caption_btn, download_btn],
    )

    if shared_captions is not None:
        def extract_valid_captions(*caption_values):
            return [c for c in caption_values if c and c.strip()]

        caption_chain.success(
            extract_valid_captions,
            inputs=caption_components,
            outputs=[shared_captions],
        )

    def _zip_and_reveal(image_paths, *captions):
        zip_path = create_zip_from_ui(image_paths, *captions)
        if zip_path is None:
            # Nothing was packaged (create_zip_from_ui already warned), so
            # keep the download slot hidden instead of showing an empty file.
            return gr.update(value=None, visible=False)
        gr.Info("Click the Download button that appeared to save your zip file")
        return gr.update(value=zip_path, visible=True, elem_classes=["download-section"])

    download_btn.click(
        _zip_and_reveal,
        inputs=[stored_image_paths] + caption_components,
        outputs=[download_output],
    )
# ------- Prompt Optimization UI -------
def create_prompt_optimization_ui():
    """Create the two columns of the Prompt Optimization tab.

    Returns:
        (left_column, right_column, captions_upload, manual_captions,
        use_generated_captions, user_prompt, optimize_btn, optimized_prompt,
        optimization_status).
    """
    with gr.Column(scale=1) as left_column:
        # Left side: example captions the prompt should be styled after.
        gr.Markdown("### Upload Captions")
        gr.Markdown("Upload caption files (.txt) or enter captions manually", elem_classes="file-types-info")
        captions_upload = gr.File(
            file_count="multiple",
            label="Upload caption files",
            file_types=[".txt"],
            type="filepath",
            elem_classes="file-upload-container",
            height=220,
        )
        manual_captions = gr.Textbox(
            label="Or enter captions manually",
            lines=5,
            placeholder="Enter captions here, one per line",
            elem_classes="prompt-box",
        )
        # Copies the captions generated in the Image Captioning tab into the
        # manual-entry box above (wired in setup_prompt_optimization_handlers).
        use_generated_captions = gr.Button("Use Captions from Image Captioning Tab", variant="secondary")

    with gr.Column(scale=1) as right_column:
        # Right side: prompt input and optimized output.
        gr.Markdown("### Optimize Prompt")
        gr.Markdown("\n- Craft prompts that match the style of your training captions\n- Enter a simple prompt and receive an optimized version\n", elem_classes=["category-info"])
        user_prompt = gr.Textbox(
            label="Enter your prompt",
            lines=3,
            placeholder="Enter the prompt you want to optimize",
            elem_classes="prompt-box",
        )
        optimize_btn = gr.Button("Optimize Prompt", variant="primary", elem_classes="optimize-btn")
        optimized_prompt = gr.Textbox(
            label="Optimized Prompt",
            lines=5,
            placeholder="Optimized prompt will appear here",
            elem_classes="prompt-box",
        )
        optimization_status = gr.Markdown("Enter a prompt and upload captions to begin", elem_classes="optimization-status")

    return (
        left_column, right_column, captions_upload, manual_captions,
        use_generated_captions, user_prompt, optimize_btn,
        optimized_prompt, optimization_status,
    )
def run_optimization(prompt, caption_files, manual_caption_text):
    """Rewrite ``prompt`` in the style of a set of example captions.

    Manually entered captions (one per line) take priority. Otherwise each
    uploaded ``.txt`` file contributes its whole stripped contents as one
    caption.

    Args:
        prompt: The user's prompt.
        caption_files: Uploaded caption file paths, or ``None``.
        manual_caption_text: Text from the manual-entry box, or ``None``.

    Returns:
        (optimized_prompt, status_message). ``optimized_prompt`` is ``""``
        whenever the status reports a problem.
    """
    if not prompt or not prompt.strip():
        return "", "Please enter a prompt to optimize"

    caption_list = []
    if manual_caption_text and manual_caption_text.strip():
        caption_list = [line.strip() for line in manual_caption_text.splitlines() if line.strip()]
    elif caption_files:
        for file_path in caption_files:
            if not (os.path.isfile(file_path) and file_path.lower().endswith('.txt')):
                continue
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    content = f.read().strip()
            except (OSError, UnicodeDecodeError) as e:
                # Report unreadable/non-UTF-8 files instead of crashing the handler.
                return "", f"❌ Could not read caption file {os.path.basename(file_path)}: {e}"
            if content:
                caption_list.append(content)

    if not caption_list:
        return "", "Please upload caption files or enter captions manually"

    try:
        result = optimize_prompt(prompt, captions_list=caption_list)
    except Exception as e:
        # UI boundary: surface any failure from prompt.optimize_prompt in the status line.
        return "", f"❌ Error optimizing prompt: {e}"
    return result, "✅ Prompt optimization complete"
def setup_prompt_optimization_handlers(
    captions_upload, manual_captions, use_generated_captions,
    user_prompt, optimize_btn, optimized_prompt,
    optimization_status, shared_captions
):
    """Wire the event handlers of the Prompt Optimization tab.

    Args:
        captions_upload, manual_captions, use_generated_captions, user_prompt,
        optimize_btn, optimized_prompt, optimization_status: Components from
            ``create_prompt_optimization_ui``.
        shared_captions: State holding the captions produced by the Image
            Captioning tab.
    """
    def fill_with_shared_captions(captions_list):
        """Return the shared captions joined one per line for the manual box."""
        if not captions_list:
            # Don't write the hint into the textbox: run_optimization would
            # then treat it as a caption. Show a toast and leave the box as is.
            gr.Warning("No captions available. Generate captions in the Image Captioning tab first.")
            return gr.update()
        return "\n".join(captions_list)

    use_generated_captions.click(
        fill_with_shared_captions,
        inputs=[shared_captions],
        outputs=[manual_captions],
    )

    optimize_btn.click(
        run_optimization,
        inputs=[user_prompt, captions_upload, manual_captions],
        outputs=[optimized_prompt, optimization_status],
    )
# ------- Main Application -------
def build_ui():
    """Assemble the two-tab Gradio app and return it without launching.

    Returns:
        The ``gr.Blocks`` application.
    """
    with gr.Blocks() as app:
        gr.Markdown("# Image Auto-captioner for LoRA Training")
        gr.Markdown("""Check out the [code](https://github.com/RishiDesai/LoRACaptioner)
and see my [blog post](https://rishidesai.github.io/posts/character-lora/) for more information.""")

        # Captions produced in the captioning tab and reused in the optimization tab.
        shared_captions = gr.State([])

        with gr.Tabs():
            with gr.TabItem("Image Captioning"):
                stored_image_paths = gr.State([])
                batch_by_category = gr.State(False)

                # Upload column on the left, configuration column on the right.
                with gr.Row():
                    image_upload = create_upload_area()[1]
                    (
                        _config_column, batch_category_checkbox, caption_btn,
                        download_btn, download_output, status_text,
                    ) = create_config_area()

                captioning_area, image_rows, image_components, caption_components = create_captioning_area()

                setup_event_handlers(
                    image_upload, stored_image_paths, captioning_area, caption_btn, download_btn,
                    download_output, status_text, image_rows, image_components, caption_components,
                    batch_category_checkbox, batch_by_category, shared_captions,
                )

            with gr.TabItem("Prompt Optimization"):
                with gr.Row():
                    (
                        _left_column, _right_column, captions_upload, manual_captions,
                        use_generated_captions, user_prompt, optimize_btn,
                        optimized_prompt, optimization_status,
                    ) = create_prompt_optimization_ui()

                setup_prompt_optimization_handlers(
                    captions_upload, manual_captions, use_generated_captions,
                    user_prompt, optimize_btn, optimized_prompt,
                    optimization_status, shared_captions,
                )

        gr.HTML(get_css_styles())

    return app
# Launch the app when run directly (not on import).
if __name__ == "__main__":
    demo = build_ui()
    # share=True asks Gradio to create a public share link in addition to the local server.
    demo.launch(share=True)
|