Spaces:
Running
Running
import os | |
import argparse | |
import shutil | |
import sys | |
from pathlib import Path | |
from PIL import Image | |
from caption import caption_images | |
def is_image_file(filename):
    """Return True when *filename* has a supported image extension.

    Supported extensions are .png, .jpg, .jpeg and .webp; matching is
    case-insensitive and looks only at the end of the name.
    """
    lowered = filename.lower()
    return lowered.endswith(('.png', '.jpg', '.jpeg', '.webp'))
def is_unsupported_image(filename):
    """Return True for image files whose format this tool rejects.

    Recognized-but-rejected extensions (case-insensitive): .bmp, .gif,
    .tiff, .tif, .ico and .svg.
    """
    rejected = ('.bmp', '.gif', '.tiff', '.tif', '.ico', '.svg')
    return filename.lower().endswith(rejected)
def is_text_file(filename):
    """Return True if *filename* ends with a .txt extension (case-insensitive)."""
    extension = '.txt'
    return filename.lower()[-len(extension):] == extension
def validate_input_directory(input_dir):
    """Scan *input_dir* and abort if it contains rejected image formats.

    Only regular files directly inside the directory are inspected (no
    recursion). If any file is flagged by ``is_unsupported_image`` its name
    is printed and the process exits with status 1. Otherwise any ``.txt``
    files present are listed as a warning that they will be overwritten.
    """
    rejected = []
    existing_txt = []
    for entry in Path(input_dir).iterdir():
        if not entry.is_file():
            continue
        name = entry.name
        if is_unsupported_image(name):
            rejected.append(name)
        elif is_text_file(name):
            existing_txt.append(name)

    if rejected:
        print("Error: Unsupported image formats detected.")
        print("Only .png, .jpg, .jpeg, and .webp files are allowed.")
        print("The following files are not supported:")
        for name in rejected:
            print(f" - {name}")
        sys.exit(1)

    if existing_txt:
        print("Warning: Text files detected in the input directory.")
        print("The following text files will be overwritten:")
        for name in existing_txt:
            print(f" - {name}")
def collect_images_by_category(input_path):
    """Load every supported image in *input_path* and group it by category.

    The category is the file stem with its last ``_suffix`` removed
    (``"cat_01.png"`` -> ``"cat"``); a stem without ``_`` is its own
    category. Files are visited in sorted order so grouping and batch order
    are the same on every run and platform.

    Args:
        input_path: ``pathlib.Path`` of the directory to scan (not recursive).

    Returns:
        A pair ``(images_by_category, image_paths_by_category)`` of dicts
        keyed by category. The lists under the same key are index-aligned:
        the i-th RGB image belongs to the i-th path.

    Files that fail to load are reported and skipped.
    """
    images_by_category = {}
    image_paths_by_category = {}
    # iterdir() order is filesystem-dependent; sort for reproducible batches.
    for file_path in sorted(input_path.iterdir()):
        if not (file_path.is_file() and is_image_file(file_path.name)):
            continue
        try:
            # The context manager releases the source file handle right away;
            # convert() returns an independent, fully loaded image.
            with Image.open(file_path) as opened:
                image = opened.convert("RGB")
        except Exception as e:  # best effort: report and skip unreadable files
            print(f"Error loading {file_path.name}: {e}")
            continue
        category = file_path.stem.rsplit('_', 1)[0]
        images_by_category.setdefault(category, []).append(image)
        image_paths_by_category.setdefault(category, []).append(file_path)
    return images_by_category, image_paths_by_category
def process_by_category(images_by_category, image_paths_by_category, input_path, output_path):
    """Caption images one category at a time and write the caption files.

    For each category, ``caption_images`` is called once in batch mode with
    the category name, and its result is passed to ``write_captions``. A
    failure in one category is reported and the remaining categories are
    still processed.

    Args:
        images_by_category: Dict of category -> list of loaded images.
        image_paths_by_category: Dict of category -> list of image paths,
            index-aligned with ``images_by_category``.
        input_path: Directory the caption files are written into.
        output_path: Directory images and captions are copied to when it
            differs from ``input_path``.

    Returns:
        The number of images paired with a caption. If the captioner returns
        a different number of captions than images, only the pairs that
        ``write_captions`` can form (via ``zip``) are counted. Per-file write
        errors reported by ``write_captions`` are not subtracted.
    """
    processed_count = 0
    for category, images in images_by_category.items():
        image_paths = image_paths_by_category[category]
        try:
            # Materialize so the result can be measured even if it is an iterator.
            captions = list(caption_images(images, category=category, batch_mode=True))
            if len(captions) != len(image_paths):
                print(
                    f"Warning: received {len(captions)} captions for "
                    f"{len(image_paths)} images in category '{category}'."
                )
            write_captions(image_paths, captions, input_path, output_path)
        except Exception as e:
            print(f"Error generating captions for category '{category}': {e}")
            continue
        processed_count += min(len(captions), len(image_paths))
    return processed_count
def process_all_at_once(images_by_category, image_paths_by_category, input_path, output_path):
    """Caption every collected image with a single ``caption_images`` call.

    Images from all categories are flattened into one list and captioned
    with ``batch_mode=False``; the captions are then written by
    ``write_captions``.

    Args:
        images_by_category: Dict of category -> list of loaded images.
        image_paths_by_category: Dict of category -> list of image paths,
            index-aligned with ``images_by_category``.
        input_path: Directory the caption files are written into.
        output_path: Directory images and captions are copied to when it
            differs from ``input_path``.

    Returns:
        The number of images paired with a caption (0 if captioning failed).
        If the captioner returns a different number of captions than images,
        only the pairs ``write_captions`` can form (via ``zip``) are counted.
    """
    all_images = []
    all_image_paths = []
    # Flatten both dicts by the same key so each image stays paired with its
    # own path regardless of how the two dicts were ordered.
    for category, images in images_by_category.items():
        all_images.extend(images)
        all_image_paths.extend(image_paths_by_category[category])

    try:
        # Materialize so the result can be measured even if it is an iterator.
        captions = list(caption_images(all_images, batch_mode=False))
        if len(captions) != len(all_image_paths):
            print(
                f"Warning: received {len(captions)} captions for "
                f"{len(all_image_paths)} images."
            )
        write_captions(all_image_paths, captions, input_path, output_path)
    except Exception as e:
        print(f"Error generating captions: {e}")
        return 0
    return min(len(captions), len(all_image_paths))
def process_images(input_dir, output_dir, batch_images=False):
    """Run the full captioning pipeline over *input_dir*.

    Steps:
    - Validate the directory; this exits on unsupported formats.
    - Ensure the output directory exists.
    - Load the images and group them by category.
    - Caption them either per category (``batch_images=True``) or all at once.
    - Report how many images were captioned.

    Args:
        input_dir: Directory containing the source images.
        output_dir: Destination directory; a falsy value means "use input_dir".
        batch_images: Caption per category instead of in one call.
    """
    input_path = Path(input_dir)
    if output_dir:
        output_path = Path(output_dir)
    else:
        output_path = input_path

    validate_input_directory(input_dir)
    os.makedirs(output_path, exist_ok=True)

    images_by_category, image_paths_by_category = collect_images_by_category(input_path)

    total_images = 0
    for group in images_by_category.values():
        total_images += len(group)
    print(f"Found {total_images} images to process.")
    if total_images == 0:
        print("No valid images found to process.")
        return

    runner = process_by_category if batch_images else process_all_at_once
    processed_count = runner(images_by_category, image_paths_by_category, input_path, output_path)
    print(f"\nProcessing complete. {processed_count} images were captioned.")
def write_captions(image_paths, captions, input_path, output_path):
    """Write a ``.txt`` caption next to each image and optionally copy both.

    For each ``(image, caption)`` pair, a UTF-8 file ``<image stem>.txt`` is
    written into ``input_path``. When ``output_path`` is a different
    directory, the image and its caption file are then copied there with
    ``shutil.copy2``.

    Args:
        image_paths: ``Path`` objects of the images.
        captions: Caption strings, index-aligned with ``image_paths``. Pairs
            are formed with ``zip``, so any surplus on either side is skipped.
        input_path: ``Path`` of the directory the caption files go into.
        output_path: ``Path`` of the destination directory.

    Errors for a single image are printed and the loop continues.
    """
    # Compare resolved directories, not Path objects. The same directory
    # spelled differently (relative vs. absolute, "dir/sub/..") compares
    # unequal, and copying a file onto itself raises shutil.SameFileError.
    try:
        same_dir = Path(output_path).resolve() == Path(input_path).resolve()
    except (OSError, RuntimeError):
        same_dir = output_path == input_path

    for file_path, caption in zip(image_paths, captions):
        try:
            # Caption file shares the image's stem, with a .txt extension.
            caption_filename = file_path.stem + ".txt"
            caption_path = input_path / caption_filename
            with open(caption_path, 'w', encoding='utf-8') as f:
                f.write(caption)
            # Mirror image + caption only when the destination really differs.
            if not same_dir:
                shutil.copy2(file_path, output_path / file_path.name)
                shutil.copy2(caption_path, output_path / caption_filename)
            print(f"Processed {file_path.name} → {caption_filename}")
        except Exception as e:
            print(f"Error processing {file_path.name}: {e}")
def main():
    """Command-line entry point: parse arguments and caption a directory.

    Exits with status 1 when ``--input`` is not an existing directory.
    """
    parser = argparse.ArgumentParser(description='Generate captions for images using GPT-4o.')
    parser.add_argument('--input', type=str, required=True, help='Directory containing images')
    parser.add_argument('--output', type=str, help='Directory to save images and captions (defaults to input directory)')
    parser.add_argument('--batch_images', action='store_true', help='Flag to indicate if images should be processed in batches')
    args = parser.parse_args()

    if not os.path.isdir(args.input):
        print(f"Error: Input directory '{args.input}' does not exist.")
        # Non-zero status so shells/CI detect the failure; matches the
        # sys.exit(1) used for unsupported image formats.
        sys.exit(1)

    process_images(args.input, args.output, args.batch_images)


if __name__ == "__main__":
    main()