Spaces:
Running
Running
Rishi Desai
committed on
Commit
·
ca96bd8
1
Parent(s):
a446ad0
added batching by category
Browse files
main.py
CHANGED
@@ -50,7 +50,7 @@ def validate_input_directory(input_dir):
|
|
50 |
print(f" - {file}")
|
51 |
sys.exit(1)
|
52 |
|
53 |
-
def process_images(input_dir, output_dir, fix_outfit=False):
|
54 |
"""Process all images in the input directory and generate captions."""
|
55 |
input_path = Path(input_dir)
|
56 |
output_path = Path(output_dir) if output_dir else input_path
|
@@ -64,9 +64,9 @@ def process_images(input_dir, output_dir, fix_outfit=False):
|
|
64 |
# Track the number of processed images
|
65 |
processed_count = 0
|
66 |
|
67 |
-
# Collect all images into a
|
68 |
-
|
69 |
-
|
70 |
|
71 |
# Get all files in the input directory
|
72 |
for file_path in input_path.iterdir():
|
@@ -74,26 +74,54 @@ def process_images(input_dir, output_dir, fix_outfit=False):
|
|
74 |
try:
|
75 |
# Load the image
|
76 |
image = Image.open(file_path).convert("RGB")
|
77 |
-
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
except Exception as e:
|
80 |
print(f"Error loading {file_path.name}: {e}")
|
81 |
|
82 |
# Log the number of images found
|
83 |
-
|
|
|
84 |
|
85 |
-
if not
|
86 |
print("No valid images found to process.")
|
87 |
return
|
88 |
|
89 |
-
#
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
|
96 |
-
|
|
|
|
|
|
|
97 |
for file_path, caption in zip(image_paths, captions):
|
98 |
try:
|
99 |
# Create caption file path (same name but with .txt extension)
|
@@ -111,18 +139,16 @@ def process_images(input_dir, output_dir, fix_outfit=False):
|
|
111 |
# Copy caption to output directory
|
112 |
shutil.copy2(caption_path, output_path / caption_filename)
|
113 |
|
114 |
-
processed_count += 1
|
115 |
print(f"Processed {file_path.name} → {caption_filename}")
|
116 |
except Exception as e:
|
117 |
print(f"Error processing {file_path.name}: {e}")
|
118 |
|
119 |
-
print(f"\nProcessing complete. {processed_count} images were captioned.")
|
120 |
-
|
121 |
def main():
|
122 |
parser = argparse.ArgumentParser(description='Generate captions for images using GPT-4o.')
|
123 |
parser.add_argument('--input', type=str, required=True, help='Directory containing images')
|
124 |
parser.add_argument('--output', type=str, help='Directory to save images and captions (defaults to input directory)')
|
125 |
parser.add_argument('--fix_outfit', action='store_true', help='Flag to indicate if character has one outfit')
|
|
|
126 |
|
127 |
args = parser.parse_args()
|
128 |
|
@@ -132,7 +158,7 @@ def main():
|
|
132 |
return
|
133 |
|
134 |
# Process images
|
135 |
-
process_images(args.input, args.output, args.fix_outfit)
|
136 |
|
137 |
if __name__ == "__main__":
|
138 |
main()
|
|
|
50 |
print(f" - {file}")
|
51 |
sys.exit(1)
|
52 |
|
53 |
+
def process_images(input_dir, output_dir, fix_outfit=False, batch_images=False):
|
54 |
"""Process all images in the input directory and generate captions."""
|
55 |
input_path = Path(input_dir)
|
56 |
output_path = Path(output_dir) if output_dir else input_path
|
|
|
64 |
# Track the number of processed images
|
65 |
processed_count = 0
|
66 |
|
67 |
+
# Collect all images into a dictionary grouped by category
|
68 |
+
images_by_category = {}
|
69 |
+
image_paths_by_category = {}
|
70 |
|
71 |
# Get all files in the input directory
|
72 |
for file_path in input_path.iterdir():
|
|
|
74 |
try:
|
75 |
# Load the image
|
76 |
image = Image.open(file_path).convert("RGB")
|
77 |
+
|
78 |
+
# Determine the category from the filename
|
79 |
+
category = file_path.stem.rsplit('_', 1)[0]
|
80 |
+
|
81 |
+
# Add image to the appropriate category
|
82 |
+
if category not in images_by_category:
|
83 |
+
images_by_category[category] = []
|
84 |
+
image_paths_by_category[category] = []
|
85 |
+
|
86 |
+
images_by_category[category].append(image)
|
87 |
+
image_paths_by_category[category].append(file_path)
|
88 |
except Exception as e:
|
89 |
print(f"Error loading {file_path.name}: {e}")
|
90 |
|
91 |
# Log the number of images found
|
92 |
+
total_images = sum(len(images) for images in images_by_category.values())
|
93 |
+
print(f"Found {total_images} images to process.")
|
94 |
|
95 |
+
if not total_images:
|
96 |
print("No valid images found to process.")
|
97 |
return
|
98 |
|
99 |
+
# Process images by category if batch_images is True
|
100 |
+
if batch_images:
|
101 |
+
for category, images in images_by_category.items():
|
102 |
+
image_paths = image_paths_by_category[category]
|
103 |
+
try:
|
104 |
+
# Generate captions for the entire category
|
105 |
+
captions = caption_images(images)
|
106 |
+
write_captions(image_paths, captions, input_path, output_path)
|
107 |
+
processed_count += len(images)
|
108 |
+
except Exception as e:
|
109 |
+
print(f"Error generating captions for category '{category}': {e}")
|
110 |
+
else:
|
111 |
+
# Process all images at once if batch_images is False
|
112 |
+
all_images = [img for imgs in images_by_category.values() for img in imgs]
|
113 |
+
all_image_paths = [path for paths in image_paths_by_category.values() for path in paths]
|
114 |
+
try:
|
115 |
+
captions = caption_images(all_images)
|
116 |
+
write_captions(all_image_paths, captions, input_path, output_path)
|
117 |
+
processed_count += len(all_images)
|
118 |
+
except Exception as e:
|
119 |
+
print(f"Error generating captions: {e}")
|
120 |
|
121 |
+
print(f"\nProcessing complete. {processed_count} images were captioned.")
|
122 |
+
|
123 |
+
def write_captions(image_paths, captions, input_path, output_path):
|
124 |
+
"""Helper function to write captions to files."""
|
125 |
for file_path, caption in zip(image_paths, captions):
|
126 |
try:
|
127 |
# Create caption file path (same name but with .txt extension)
|
|
|
139 |
# Copy caption to output directory
|
140 |
shutil.copy2(caption_path, output_path / caption_filename)
|
141 |
|
|
|
142 |
print(f"Processed {file_path.name} → {caption_filename}")
|
143 |
except Exception as e:
|
144 |
print(f"Error processing {file_path.name}: {e}")
|
145 |
|
|
|
|
|
146 |
def main():
|
147 |
parser = argparse.ArgumentParser(description='Generate captions for images using GPT-4o.')
|
148 |
parser.add_argument('--input', type=str, required=True, help='Directory containing images')
|
149 |
parser.add_argument('--output', type=str, help='Directory to save images and captions (defaults to input directory)')
|
150 |
parser.add_argument('--fix_outfit', action='store_true', help='Flag to indicate if character has one outfit')
|
151 |
+
parser.add_argument('--batch_images', action='store_true', help='Flag to indicate if images should be processed in batches')
|
152 |
|
153 |
args = parser.parse_args()
|
154 |
|
|
|
158 |
return
|
159 |
|
160 |
# Process images
|
161 |
+
process_images(args.input, args.output, args.fix_outfit, args.batch_images)
|
162 |
|
163 |
if __name__ == "__main__":
|
164 |
main()
|