LoRACaptioner / main.py
Rishi Desai
removing outfit cmd line arg
89d29bd
import os
import argparse
import shutil
import sys
from pathlib import Path
from PIL import Image
from caption import caption_images
def is_image_file(filename):
    """Return True when *filename* ends with a supported image extension.

    The check is case-insensitive and purely name-based; the file's
    contents are never inspected. Supported extensions are .png, .jpg,
    .jpeg and .webp.
    """
    supported_suffixes = ('.png', '.jpg', '.jpeg', '.webp')
    # str.endswith accepts a tuple and succeeds if any suffix matches.
    return filename.lower().endswith(supported_suffixes)
def is_unsupported_image(filename):
    """Return True when *filename* looks like an image in a rejected format.

    The check is case-insensitive and name-based only. Rejected
    extensions are .bmp, .gif, .tiff, .tif, .ico and .svg.
    """
    rejected_suffixes = ('.bmp', '.gif', '.tiff', '.tif', '.ico', '.svg')
    # str.endswith accepts a tuple and succeeds if any suffix matches.
    return filename.lower().endswith(rejected_suffixes)
def is_text_file(filename):
    """Return True when *filename* ends with ``.txt`` (case-insensitive)."""
    normalized_name = filename.lower()
    return normalized_name.endswith('.txt')
def validate_input_directory(input_dir):
    """Check the direct contents of *input_dir* before captioning starts.

    Only regular files directly inside the directory are examined.

    * If any file has an unsupported image extension, the offending names
      are listed and the process terminates via ``sys.exit(1)``.
    * If any ``.txt`` files are present, a warning listing them is
      printed and the function returns normally.
    """
    rejected_names = []
    existing_text_names = []

    for entry in Path(input_dir).iterdir():
        if not entry.is_file():
            continue
        name = entry.name
        if is_unsupported_image(name):
            rejected_names.append(name)
        elif is_text_file(name):
            existing_text_names.append(name)

    if rejected_names:
        print("Error: Unsupported image formats detected.")
        print("Only .png, .jpg, .jpeg, and .webp files are allowed.")
        print("The following files are not supported:")
        for name in rejected_names:
            print(f" - {name}")
        sys.exit(1)

    if existing_text_names:
        print("Warning: Text files detected in the input directory.")
        print("The following text files will be overwritten:")
        for name in existing_text_names:
            print(f" - {name}")
def collect_images_by_category(input_path):
    """Load every supported image in *input_path*, grouped by category.

    The category of a file is its stem with the last ``_``-separated
    component removed (``"cat_01.png"`` -> ``"cat"``); a stem without an
    underscore is its own category.

    Args:
        input_path: ``pathlib.Path`` of the directory to scan. Only regular
            files directly inside it are considered.

    Returns:
        A ``(images_by_category, image_paths_by_category)`` tuple of dicts
        keyed by category. The two lists under a given key are index-aligned:
        the RGB-converted image at position *i* was loaded from the path at
        position *i*.

    Files that fail to load are reported on stdout and skipped.
    """
    images_by_category = {}
    image_paths_by_category = {}

    for file_path in input_path.iterdir():
        if not (file_path.is_file() and is_image_file(file_path.name)):
            continue

        try:
            # Image.open keeps the underlying file open until the image is
            # closed. convert() yields a separate in-memory image, so the
            # source handle can be released immediately instead of leaking
            # one open file per input image.
            with Image.open(file_path) as source_image:
                image = source_image.convert("RGB")
        except Exception as e:
            print(f"Error loading {file_path.name}: {e}")
            continue

        # Determine the category from the filename.
        category = file_path.stem.rsplit('_', 1)[0]
        images_by_category.setdefault(category, []).append(image)
        image_paths_by_category.setdefault(category, []).append(file_path)

    return images_by_category, image_paths_by_category
def process_by_category(images_by_category, image_paths_by_category, input_path, output_path):
    """Caption each category as its own batch and write the results.

    For every category, ``caption_images`` is called once in batch mode
    with that category's images, and the captions are handed to
    ``write_captions``. An exception raised for one category is printed
    and does not stop the remaining categories.

    Returns:
        The number of images in the categories that completed without
        raising.
    """
    captioned_total = 0

    for category in images_by_category:
        category_images = images_by_category[category]
        category_paths = image_paths_by_category[category]
        try:
            # One captioning request covers the whole category.
            category_captions = caption_images(
                category_images, category=category, batch_mode=True
            )
            write_captions(category_paths, category_captions, input_path, output_path)
            captioned_total += len(category_images)
        except Exception as error:
            print(f"Error generating captions for category '{category}': {error}")

    return captioned_total
def process_all_at_once(images_by_category, image_paths_by_category, input_path, output_path):
    """Caption every collected image in a single non-batch call.

    The per-category lists are flattened (in dict order) into one list of
    images and one matching list of paths, passed to ``caption_images``
    with ``batch_mode=False``, and the captions are written with
    ``write_captions``.

    Returns:
        The total number of images when the call completes without
        raising, otherwise 0 (the exception is printed).
    """
    flat_images = []
    for category_images in images_by_category.values():
        flat_images.extend(category_images)

    flat_paths = []
    for category_paths in image_paths_by_category.values():
        flat_paths.extend(category_paths)

    try:
        generated = caption_images(flat_images, batch_mode=False)
        write_captions(flat_paths, generated, input_path, output_path)
        return len(flat_images)
    except Exception as error:
        print(f"Error generating captions: {error}")
        return 0
def process_images(input_dir, output_dir, batch_images=False):
    """Validate, load, caption and save every image found in *input_dir*.

    Args:
        input_dir: Directory containing the images.
        output_dir: Directory for images and captions; a falsy value means
            the input directory is used.
        batch_images: When true, images are captioned one category at a
            time; otherwise all images go through a single call.
    """
    source_dir = Path(input_dir)
    if output_dir:
        target_dir = Path(output_dir)
    else:
        target_dir = source_dir

    # Validation may terminate the process, so it runs before the output
    # directory is created.
    validate_input_directory(input_dir)
    os.makedirs(target_dir, exist_ok=True)

    grouped_images, grouped_paths = collect_images_by_category(source_dir)

    image_total = sum(map(len, grouped_images.values()))
    print(f"Found {image_total} images to process.")
    if image_total == 0:
        print("No valid images found to process.")
        return

    strategy = process_by_category if batch_images else process_all_at_once
    captioned = strategy(grouped_images, grouped_paths, source_dir, target_dir)

    print(f"\nProcessing complete. {captioned} images were captioned.")
def write_captions(image_paths, captions, input_path, output_path):
    """Write one ``.txt`` caption per image and mirror files to the output.

    Each caption is written next to its image in *input_path* as
    ``<image stem>.txt`` (UTF-8). When *output_path* refers to a different
    directory, both the image and its caption file are then copied there.

    Args:
        image_paths: Paths of the captioned images.
        captions: Caption strings, paired with *image_paths* by position.
        input_path: Directory in which the caption files are created.
        output_path: Directory that receives copies of image and caption.

    A failure on one image is printed and does not stop the remaining ones.
    """
    # Compare resolved paths so an alias of the input directory (e.g.
    # "imgs/../imgs" or a symlink) is not mistaken for a separate output
    # directory. Copying a file onto itself raises shutil.SameFileError,
    # which would be reported as a failure for every image.
    copy_to_output = Path(output_path).resolve() != Path(input_path).resolve()

    for file_path, caption in zip(image_paths, captions):
        try:
            # Caption file shares the image's name but with a .txt extension.
            caption_filename = file_path.stem + ".txt"
            caption_path = input_path / caption_filename
            with open(caption_path, 'w', encoding='utf-8') as f:
                f.write(caption)

            if copy_to_output:
                shutil.copy2(file_path, output_path / file_path.name)
                shutil.copy2(caption_path, output_path / caption_filename)

            # The two names were previously printed run together.
            print(f"Processed {file_path.name} -> {caption_filename}")
        except Exception as e:
            print(f"Error processing {file_path.name}: {e}")
def main():
    """Parse command-line arguments and caption the images they point to.

    Options:
        --input: required directory containing the images.
        --output: optional directory for images and captions; the input
            directory is used when omitted.
        --batch_images: flag selecting per-category batch processing.

    Exits with status 1 when the input directory does not exist.
    """
    parser = argparse.ArgumentParser(description='Generate captions for images using GPT-4o.')
    parser.add_argument('--input', type=str, required=True, help='Directory containing images')
    parser.add_argument('--output', type=str, help='Directory to save images and captions (defaults to input directory)')
    parser.add_argument('--batch_images', action='store_true', help='Flag to indicate if images should be processed in batches')
    args = parser.parse_args()

    if not os.path.isdir(args.input):
        print(f"Error: Input directory '{args.input}' does not exist.")
        # Exit non-zero so calling shells and scripts can detect the
        # failure, in line with validate_input_directory's sys.exit(1).
        sys.exit(1)

    process_images(args.input, args.output, args.batch_images)


# Run the CLI only when this file is executed as a script.
if __name__ == "__main__":
    main()