File size: 6,731 Bytes
c9dac35
 
 
 
 
 
 
 
ab00f6b
c9dac35
 
 
 
 
ab00f6b
c9dac35
 
 
 
 
ab00f6b
c9dac35
 
 
 
ab00f6b
c9dac35
 
 
ab00f6b
c9dac35
 
ab00f6b
c9dac35
 
 
 
 
 
ab00f6b
c9dac35
 
 
 
 
 
 
ab00f6b
c9dac35
dc6215b
 
c9dac35
 
 
ab00f6b
dc6215b
 
ca96bd8
 
c9dac35
 
 
 
 
ab00f6b
ca96bd8
 
ab00f6b
ca96bd8
 
 
 
ab00f6b
ca96bd8
 
c9dac35
 
ab00f6b
dc6215b
 
ab00f6b
dc6215b
 
 
 
 
 
 
 
 
 
 
 
 
 
ab00f6b
dc6215b
 
 
 
 
 
 
 
 
 
 
 
 
ab00f6b
89d29bd
dc6215b
 
 
ab00f6b
dc6215b
 
ab00f6b
dc6215b
 
c9dac35
 
ca96bd8
 
c9dac35
ca96bd8
c9dac35
 
 
ca96bd8
dc6215b
ca96bd8
dc6215b
c9dac35
ca96bd8
 
ab00f6b
ca96bd8
 
c9dac35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ab00f6b
c9dac35
 
 
 
ca96bd8
ab00f6b
c9dac35
ab00f6b
c9dac35
 
 
ab00f6b
89d29bd
c9dac35
ab00f6b
c9dac35
ab00f6b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import os
import argparse
import shutil
import sys
from pathlib import Path
from PIL import Image
from caption import caption_images


def is_image_file(filename):
    """Return True when *filename* ends with a supported image extension.

    The check is case-insensitive and looks only at the name's suffix; the
    accepted extensions are .png, .jpg, .jpeg and .webp.
    """
    return filename.lower().endswith(('.png', '.jpg', '.jpeg', '.webp'))


def is_unsupported_image(filename):
    """Return True when *filename* looks like an image in a rejected format.

    Matches (case-insensitively) the suffixes .bmp, .gif, .tiff, .tif, .ico
    and .svg, which are recognised as images but not accepted for captioning.
    """
    rejected_suffixes = ('.bmp', '.gif', '.tiff', '.tif', '.ico', '.svg')
    return filename.lower().endswith(rejected_suffixes)


def is_text_file(filename):
    """Return True when *filename* ends with ``.txt`` (any letter case)."""
    lowered = filename.lower()
    return lowered[-4:] == '.txt'


def validate_input_directory(input_dir):
    """Check *input_dir* before any captioning work starts.

    Only regular files directly inside *input_dir* are inspected.

    - If any file has an image extension that is not accepted (per
      ``is_unsupported_image``), an error listing those files is printed
      and the process exits with status 1.
    - Otherwise every ``.txt`` file found is listed in a warning stating
      that it will be overwritten.
    """
    names = [entry.name for entry in Path(input_dir).iterdir() if entry.is_file()]

    rejected = [name for name in names if is_unsupported_image(name)]
    # Mirrors an if/elif split: a name already rejected is never also
    # reported as a text file.
    existing_text = [
        name for name in names
        if not is_unsupported_image(name) and is_text_file(name)
    ]

    if rejected:
        report = [
            "Error: Unsupported image formats detected.",
            "Only .png, .jpg, .jpeg, and .webp files are allowed.",
            "The following files are not supported:",
        ]
        report.extend(f"  - {name}" for name in rejected)
        for line in report:
            print(line)
        sys.exit(1)

    if existing_text:
        print("Warning: Text files detected in the input directory.")
        print("The following text files will be overwritten:")
        for name in existing_text:
            print(f"  - {name}")


def collect_images_by_category(input_path):
    """Load every supported image in *input_path* and group it by category.

    The category of a file is its stem up to the last underscore, e.g.
    ``dog_01.png`` -> ``"dog"``. A stem without an underscore is its own
    category. Files are visited in sorted order, so the grouping and the
    per-category ordering are deterministic between runs.

    Images that fail to load are reported and skipped.

    Args:
        input_path: Directory to scan (a ``Path`` or path string). Only
            regular files directly inside it are considered.

    Returns:
        A pair ``(images_by_category, image_paths_by_category)``. Both dicts
        map a category to parallel lists: the RGB ``PIL.Image`` objects and
        their source file paths.
    """
    images_by_category = {}
    image_paths_by_category = {}

    # iterdir() yields entries in arbitrary, filesystem-dependent order.
    for file_path in sorted(Path(input_path).iterdir()):
        if not (file_path.is_file() and is_image_file(file_path.name)):
            continue

        try:
            # The context manager releases the file handle, which matters for
            # multi-frame files such as animated WebP/APNG. convert() returns
            # an independent, fully loaded image, so it remains valid after
            # the source is closed.
            with Image.open(file_path) as source:
                image = source.convert("RGB")
        except Exception as e:
            # Best effort: one unreadable file should not abort the whole run.
            print(f"Error loading {file_path.name}: {e}")
            continue

        category = file_path.stem.rsplit('_', 1)[0]
        images_by_category.setdefault(category, []).append(image)
        image_paths_by_category.setdefault(category, []).append(file_path)

    return images_by_category, image_paths_by_category


def process_by_category(images_by_category, image_paths_by_category, input_path, output_path):
    """Caption each category as its own batch and write the captions.

    For every category, ``caption_images`` is called once in batch mode with
    that category's images. The resulting captions are passed to
    ``write_captions``. If either call raises, the error is printed and the
    next category is processed.

    Returns:
        The total number of images in the categories whose caption and
        write calls returned without raising.
    """
    total = 0
    for category in images_by_category:
        batch = images_by_category[category]
        paths = image_paths_by_category[category]
        try:
            batch_captions = caption_images(batch, category=category, batch_mode=True)
            write_captions(paths, batch_captions, input_path, output_path)
        except Exception as err:
            print(f"Error generating captions for category '{category}': {err}")
        else:
            total += len(batch)
    return total


def process_all_at_once(images_by_category, image_paths_by_category, input_path, output_path):
    """Caption every collected image in a single, non-batched call.

    The per-category lists are flattened in dict order, so the i-th image
    lines up with the i-th path. The captions are then handed to
    ``write_captions``.

    Returns:
        The number of images submitted, or 0 if captioning or writing raised
        (the error is printed).
    """
    flat_images = []
    for group in images_by_category.values():
        flat_images.extend(group)

    flat_paths = []
    for group in image_paths_by_category.values():
        flat_paths.extend(group)

    try:
        captions = caption_images(flat_images, batch_mode=False)
        write_captions(flat_paths, captions, input_path, output_path)
    except Exception as e:
        print(f"Error generating captions: {e}")
        return 0
    return len(flat_images)


def process_images(input_dir, output_dir, batch_images=False):
    """Validate *input_dir*, caption its images and print a summary.

    Args:
        input_dir: Directory containing the source images.
        output_dir: Directory that should receive the images and captions.
            When falsy, the input directory is used. The directory is
            created if missing.
        batch_images: If true, each filename category is captioned as a
            separate batch (``process_by_category``). Otherwise all images
            are captioned in one call (``process_all_at_once``).
    """
    input_path = Path(input_dir)
    output_path = input_path if not output_dir else Path(output_dir)

    # May terminate the process if unsupported image formats are present.
    validate_input_directory(input_dir)
    os.makedirs(output_path, exist_ok=True)

    images_by_category, image_paths_by_category = collect_images_by_category(input_path)

    total_images = 0
    for group in images_by_category.values():
        total_images += len(group)
    print(f"Found {total_images} images to process.")

    if not total_images:
        print("No valid images found to process.")
        return

    handler = process_by_category if batch_images else process_all_at_once
    processed_count = handler(images_by_category, image_paths_by_category, input_path, output_path)

    print(f"\nProcessing complete. {processed_count} images were captioned.")


def write_captions(image_paths, captions, input_path, output_path):
    """Write one ``.txt`` caption per image and mirror to *output_path* if needed.

    Each caption is written, UTF-8 encoded, to ``<input_path>/<image stem>.txt``.
    When *output_path* is a different directory from *input_path*, both the
    image and its caption file are then copied there with ``shutil.copy2``.
    A failure on one file is printed and does not stop the remaining files.

    Args:
        image_paths: Image ``Path`` objects, in the same order as *captions*.
        captions: Caption strings, one per image.
        input_path: Directory the caption files are written into.
        output_path: Destination directory for the copies. Ignored when it
            resolves to the same directory as *input_path*.

    Returns:
        The number of images whose caption (and copies, if any) were written
        successfully.
    """
    image_paths = list(image_paths)
    captions = list(captions)
    if len(image_paths) != len(captions):
        # zip() stops at the shorter sequence; make the mismatch visible
        # instead of silently dropping images or captions.
        print(
            f"Warning: received {len(captions)} captions for "
            f"{len(image_paths)} images; unmatched entries will be skipped."
        )

    input_dir = Path(input_path)
    output_dir = Path(output_path)
    # Compare resolved paths: a lexical comparison treats e.g. "imgs" and
    # "/abs/path/imgs" as different, and copying a file onto itself raises
    # shutil.SameFileError.
    mirror_to_output = output_dir.resolve() != input_dir.resolve()

    written = 0
    for file_path, caption in zip(image_paths, captions):
        # Caption file shares the image's name, with a .txt extension.
        caption_filename = file_path.stem + ".txt"
        caption_path = input_dir / caption_filename
        try:
            with open(caption_path, 'w', encoding='utf-8') as f:
                f.write(caption)

            if mirror_to_output:
                shutil.copy2(file_path, output_dir / file_path.name)
                shutil.copy2(caption_path, output_dir / caption_filename)
        except Exception as e:
            # Best effort per file: report and continue with the rest.
            print(f"Error processing {file_path.name}: {e}")
            continue

        print(f"Processed {file_path.name} -> {caption_filename}")
        written += 1

    return written


def main(argv=None):
    """Command-line entry point: parse arguments and caption a directory.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. Defaults to
            None, which makes argparse read the real command line.

    Exits with status 1 if ``--input`` does not name an existing directory.
    """
    parser = argparse.ArgumentParser(description='Generate captions for images using GPT-4o.')
    parser.add_argument('--input', type=str, required=True, help='Directory containing images')
    parser.add_argument('--output', type=str, help='Directory to save images and captions (defaults to input directory)')
    parser.add_argument('--batch_images', action='store_true', help='Flag to indicate if images should be processed in batches')

    args = parser.parse_args(argv)

    if not os.path.isdir(args.input):
        print(f"Error: Input directory '{args.input}' does not exist.")
        # Exit non-zero, as validate_input_directory does for its errors, so
        # shells and CI can detect the failure. A bare return reports success.
        sys.exit(1)

    process_images(args.input, args.output, args.batch_images)


# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()