"""Gradio front end for the image autocaptioner.

Users upload up to ``MAX_IMAGES`` images, generate captions through the same
code path as the CLI (``collect_images_by_category`` + ``caption_images``),
edit the captions in the browser, and download a zip of images and captions.
"""

import os
import shutil
import tempfile
import time
import zipfile
from io import BytesIO
from pathlib import Path

import gradio as gr
import PIL.Image
from dotenv import load_dotenv

from main import process_images, collect_images_by_category, write_captions  # Import the CLI functions

# Load environment variables
load_dotenv()

# Maximum number of images
MAX_IMAGES = 30

# File suffixes (lower-case) accepted as images by the upload handler
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp')


def create_download_file(image_paths, captions):
    """Build an in-memory zip holding each image next to its caption.

    Args:
        image_paths: Paths of the image files to include.
        captions: Caption strings, positionally aligned with ``image_paths``.

    Returns:
        The bytes of a zip archive containing, for every image,
        ``<name><ext>`` (the unmodified image bytes) and ``<name>.txt``.
    """
    zip_io = BytesIO()
    with zipfile.ZipFile(zip_io, 'w') as zip_file:
        for image_path, caption in zip(image_paths, captions):
            base_name, extension = os.path.splitext(os.path.basename(image_path))
            # The bytes are copied verbatim, so keep the real extension rather
            # than labelling e.g. a JPEG as ".png"; fall back to ".png" only
            # when the upload has no extension at all.
            img_name = f"{base_name}{extension or '.png'}"
            caption_name = f"{base_name}.txt"

            # Add image to zip
            with open(image_path, 'rb') as img_file:
                zip_file.writestr(img_name, img_file.read())

            # Add caption to zip
            zip_file.writestr(caption_name, caption)

    return zip_io.getvalue()


def process_uploaded_images(image_paths, batch_by_category=False):
    """Caption uploaded images using the same code path as the CLI.

    The images are copied into a temporary directory, grouped with
    ``collect_images_by_category`` and captioned with ``caption_images``.

    Args:
        image_paths: Paths of the uploaded images, in display order.
        batch_by_category: When true, caption each category separately in
            batch mode; otherwise caption all images in a single call.

    Returns:
        A list of captions with the same length and order as ``image_paths``.
        Images that could not be matched to a caption keep ``""``.

    Raises:
        Exception: Any error from copying, categorizing or captioning is
            logged and re-raised unchanged.
    """
    try:
        # Imported lazily so the UI can start without loading the captioner.
        from caption import caption_images

        print(f"Processing {len(image_paths)} images, batch_by_category={batch_by_category}")

        # Create a temporary directory to store the images
        with tempfile.TemporaryDirectory() as temp_dir:
            # Copy images to the temp directory, remembering where each went.
            # NOTE(review): uploads sharing a basename overwrite each other
            # here — confirm whether filenames may be made unique without
            # affecting how collect_images_by_category assigns categories.
            original_to_temp = {}
            for path in image_paths:
                temp_path = os.path.join(temp_dir, os.path.basename(path))
                shutil.copyfile(path, temp_path)
                original_to_temp[path] = temp_path
            print(f"Created {len(image_paths)} temporary files")

            # Process images using the CLI code path
            images_by_category, image_paths_by_category = collect_images_by_category(Path(temp_dir))
            print(f"Collected images into {len(images_by_category)} categories")

            # Index every categorized file once: temp path -> (category, position).
            # setdefault keeps the first match, like a linear search would.
            location_by_temp_path = {}
            for category, paths in image_paths_by_category.items():
                for position, categorized_path in enumerate(paths):
                    location_by_temp_path.setdefault(str(categorized_path), (category, position))

            # Gather the loaded images in the original upload order
            all_images = []
            found_indices = []  # positions in image_paths that have an image
            for index, path in enumerate(image_paths):
                location = location_by_temp_path.get(original_to_temp[path])
                if location is None:
                    print(f"Warning: Could not find image {path} in categorized data")
                    continue
                category, position = location
                all_images.append(images_by_category[category][position])
                found_indices.append(index)
            print(f"Collected {len(all_images)} images in correct order")

            captions = [""] * len(image_paths)  # Initialize with empty strings

            if batch_by_category:
                # temp path -> first position in image_paths that produced it
                index_by_temp_path = {}
                for index, path in enumerate(image_paths):
                    index_by_temp_path.setdefault(original_to_temp[path], index)

                # Process each category separately
                for category, images in images_by_category.items():
                    category_paths = image_paths_by_category[category]
                    print(f"Processing category '{category}' with {len(images)} images")
                    category_captions = caption_images(images, category=category, batch_mode=True)
                    print(f"Generated {len(category_captions)} captions for category '{category}'")
                    print("Category captions:", category_captions)  # Debug print category captions

                    # Map captions back to original positions
                    for temp_path, caption in zip(category_paths, category_captions):
                        index = index_by_temp_path.get(str(temp_path))
                        if index is not None:
                            captions[index] = caption
            else:
                # Process all images at once
                print(f"Processing all {len(all_images)} images at once")
                all_captions = caption_images(all_images, batch_mode=False)
                print(f"Generated {len(all_captions)} captions")
                print("All captions:", all_captions)  # Debug print all captions
                for index, caption in zip(found_indices, all_captions):
                    captions[index] = caption

            print(f"Returning {len(captions)} captions")
            print("Final captions:", captions)  # Debug print final captions
            return captions

    except Exception as e:
        print(f"Error in processing: {e}")
        raise


# Main Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Image Autocaptioner")

    # Store uploaded images
    stored_image_paths = gr.State([])
    batch_by_category = gr.State(True)  # State to track if batch by category is enabled

    # Upload component
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("### Upload your images")
            image_upload = gr.File(
                file_count="multiple",
                label="Drop your files here",
                file_types=["image"],
                type="filepath"
            )

        with gr.Column(scale=1):
            autocaption_btn = gr.Button("Autocaption Images", variant="primary", interactive=False)
            status_text = gr.Markdown("Upload images to begin", visible=True)

            # Advanced settings dropdown
            with gr.Accordion("Advanced", open=False):
                batch_category_checkbox = gr.Checkbox(
                    label="Batch by category",
                    value=True,
                    info="Group similar images together when processing"
                )

    # Create a container for the captioning area (initially hidden)
    with gr.Column(visible=False) as captioning_area:
        gr.Markdown("### Your images and captions")

        # Create individual image and caption rows. Gradio needs a fixed set
        # of components, so MAX_IMAGES rows are created up front and toggled.
        image_rows = []
        image_components = []
        caption_components = []

        for i in range(MAX_IMAGES):
            with gr.Row(visible=False) as img_row:
                image_rows.append(img_row)

                img = gr.Image(
                    label=f"Image {i+1}",
                    type="filepath",
                    show_label=False,
                    height=200,
                    width=200,
                    scale=1
                )
                image_components.append(img)

                caption = gr.Textbox(
                    label=f"Caption {i+1}",
                    lines=3,
                    scale=2
                )
                caption_components.append(caption)

        # Add download button
        download_btn = gr.Button("Download Images with Captions", variant="secondary", interactive=False)
        download_output = gr.File(label="Download Zip", visible=False)

    def _empty_upload_state(status_message):
        """Return the upload outputs for "nothing usable was uploaded"."""
        return (
            [],                                # stored_image_paths
            gr.update(visible=False),          # captioning_area
            gr.update(interactive=False),      # autocaption_btn
            gr.update(interactive=False),      # download_btn
            gr.update(visible=False),          # download_output
            gr.update(value=status_message),   # status_text
            *[gr.update(visible=False) for _ in range(MAX_IMAGES)],  # image_rows
        )

    def load_captioning(files):
        """Validate uploaded files and compute the resulting UI state.

        Args:
            files: File paths from the upload component (may be empty/None).

        Returns:
            A tuple of ``(image_paths, captioning_area, autocaption_btn,
            download_btn, download_output, status_text, *image_rows)`` values.
        """
        if not files:
            return _empty_upload_state("Upload images to begin")

        # Filter to only keep image files
        image_paths = [f for f in files if f.lower().endswith(IMAGE_EXTENSIONS)]

        if not image_paths:
            gr.Warning("Please upload at least one image")
            return _empty_upload_state("No valid images found")

        if len(image_paths) > MAX_IMAGES:
            gr.Warning(f"Only the first {MAX_IMAGES} images will be processed")
            image_paths = image_paths[:MAX_IMAGES]

        # Show one row per uploaded image, hide the rest
        row_updates = [gr.update(visible=i < len(image_paths)) for i in range(MAX_IMAGES)]

        return (
            image_paths,  # stored_image_paths
            gr.update(visible=True),  # captioning_area
            gr.update(interactive=True),  # autocaption_btn
            gr.update(interactive=True),  # download_btn
            gr.update(visible=False),  # download_output
            gr.update(value=f"{len(image_paths)} images ready for captioning"),  # status_text
            *row_updates  # image_rows
        )

    def update_images(image_paths):
        """Return one update per image slot, clearing slots with no image."""
        print(f"Updating images with paths: {image_paths}")
        return [
            gr.update(value=image_paths[i] if i < len(image_paths) else None)
            for i in range(MAX_IMAGES)
        ]

    def update_caption_labels(image_paths):
        """Return one update per caption box, labelled with the image filename."""
        return [
            gr.update(label=os.path.basename(image_paths[i]) if i < len(image_paths) else "")
            for i in range(MAX_IMAGES)
        ]

    def run_captioning(image_paths, batch_category):
        """Generate captions for the images using the CLI code path.

        Args:
            image_paths: Stored paths of the uploaded images.
            batch_category: Whether to caption each category as a batch.

        Returns:
            ``MAX_IMAGES`` caption textbox updates followed by one status
            update. On failure every used textbox shows the error message.
        """
        if not image_paths:
            return [gr.update(value="") for _ in range(MAX_IMAGES)] + [gr.update(value="No images to process")]

        try:
            print(f"Starting captioning for {len(image_paths)} images")
            captions = process_uploaded_images(image_paths, batch_category)
            print(f"Generated {len(captions)} captions")
            print("Sample captions:", captions[:2])  # Debug print first two captions

            gr.Info("Captioning complete!")
            status = gr.update(value="✅ Captioning complete")
        except Exception as e:
            print(f"Error in captioning: {str(e)}")
            # gr.Error only shows when raised, which would abort this handler
            # before the textboxes/status are updated; gr.Warning displays a
            # toast immediately and lets the handler finish.
            gr.Warning(f"Captioning failed: {str(e)}")
            captions = [f"Error: {str(e)}" for _ in image_paths]
            status = gr.update(value=f"❌ Error: {str(e)}")

        # Update caption textboxes
        caption_updates = [
            gr.update(value=captions[i] if i < len(captions) else "")
            for i in range(MAX_IMAGES)
        ]

        print(f"Returning {len(caption_updates)} caption updates")
        return caption_updates + [status]

    def update_batch_setting(value):
        """Update the batch by category setting"""
        return value

    def create_zip_from_ui(image_paths, *captions_list):
        """Create a zip file from the current images and captions in the UI.

        Args:
            image_paths: Stored paths of the uploaded images.
            *captions_list: Current text of every caption box (all
                ``MAX_IMAGES`` of them, including unused ones).

        Returns:
            Path to the zip written in the system temp directory, or ``None``
            (after a warning) when no image has a caption yet.
        """
        # Pair each image with its own caption *before* dropping empties, so
        # a blank caption in the middle cannot shift later captions onto the
        # wrong image. zip() also discards the unused trailing textboxes.
        captioned = [
            (path, caption)
            for path, caption in zip(image_paths or [], captions_list)
            if caption
        ]

        if not captioned:
            gr.Warning("No images to download")
            return None

        valid_image_paths = [path for path, _ in captioned]
        valid_captions = [caption for _, caption in captioned]

        # Create zip file
        zip_data = create_download_file(valid_image_paths, valid_captions)
        timestamp = time.strftime("%Y%m%d_%H%M%S")

        # Create a temporary file to store the zip
        temp_dir = tempfile.gettempdir()
        zip_filename = f"image_captions_{timestamp}.zip"
        zip_path = os.path.join(temp_dir, zip_filename)

        # Write the zip data to the temporary file
        with open(zip_path, "wb") as f:
            f.write(zip_data)

        # Return the path to the temporary file
        return zip_path

    # Update the upload_outputs
    upload_outputs = [
        stored_image_paths,
        captioning_area,
        autocaption_btn,
        download_btn,
        download_output,
        status_text,
        *image_rows
    ]

    # Update both paths and images in a single flow
    def process_upload(files):
        """Handle an upload: compute visibility, image and label updates."""
        # First get paths and visibility updates
        (
            image_paths,
            captioning_update,
            autocaption_update,
            download_btn_update,
            download_output_update,
            status_update,
            *row_updates
        ) = load_captioning(files)

        # Then get image updates
        image_updates = update_images(image_paths)

        # Update caption labels with filenames
        caption_label_updates = update_caption_labels(image_paths)

        # Return all updates together, in the order of combined_outputs
        return [
            image_paths,
            captioning_update,
            autocaption_update,
            download_btn_update,
            download_output_update,
            status_update,
        ] + row_updates + image_updates + caption_label_updates

    # Combined outputs for both functions
    combined_outputs = upload_outputs + image_components + caption_components

    image_upload.change(
        process_upload,
        inputs=[image_upload],
        outputs=combined_outputs
    )

    # Set up batch category checkbox
    batch_category_checkbox.change(
        update_batch_setting,
        inputs=[batch_category_checkbox],
        outputs=[batch_by_category]
    )

    # Manage the captioning status
    def on_captioning_start():
        """Show a progress message and disable the button while running."""
        return gr.update(value="⏳ Processing captions... please wait"), gr.update(interactive=False)

    def on_captioning_complete():
        """Re-enable the caption button.

        The status text is deliberately left alone: run_captioning already
        set it to either the success or the error message.
        """
        return gr.update(interactive=True)

    # Set up captioning button
    autocaption_btn.click(
        on_captioning_start,
        inputs=None,
        outputs=[status_text, autocaption_btn]
    ).success(
        run_captioning,
        inputs=[stored_image_paths, batch_by_category],
        outputs=caption_components + [status_text]
    ).then(
        # .then (not .success) so the button is re-enabled even if the
        # captioning step raised unexpectedly.
        on_captioning_complete,
        inputs=None,
        outputs=[autocaption_btn]
    )

    # Set up download button
    download_btn.click(
        create_zip_from_ui,
        inputs=[stored_image_paths] + caption_components,
        outputs=[download_output]
    ).then(
        lambda: gr.update(visible=True),
        inputs=None,
        outputs=[download_output]
    )

if __name__ == "__main__":
    demo.launch(share=True)