Rishi Desai committed on
Commit
ca96bd8
·
1 Parent(s): a446ad0

added batching by category

Browse files
Files changed (1) hide show
  1. main.py +45 -19
main.py CHANGED
@@ -50,7 +50,7 @@ def validate_input_directory(input_dir):
50
  print(f" - {file}")
51
  sys.exit(1)
52
 
53
- def process_images(input_dir, output_dir, fix_outfit=False):
54
  """Process all images in the input directory and generate captions."""
55
  input_path = Path(input_dir)
56
  output_path = Path(output_dir) if output_dir else input_path
@@ -64,9 +64,9 @@ def process_images(input_dir, output_dir, fix_outfit=False):
64
  # Track the number of processed images
65
  processed_count = 0
66
 
67
- # Collect all images into a list
68
- images = []
69
- image_paths = []
70
 
71
  # Get all files in the input directory
72
  for file_path in input_path.iterdir():
@@ -74,26 +74,54 @@ def process_images(input_dir, output_dir, fix_outfit=False):
74
  try:
75
  # Load the image
76
  image = Image.open(file_path).convert("RGB")
77
- images.append(image)
78
- image_paths.append(file_path)
 
 
 
 
 
 
 
 
 
79
  except Exception as e:
80
  print(f"Error loading {file_path.name}: {e}")
81
 
82
  # Log the number of images found
83
- print(f"Found {len(images)} images to process.")
 
84
 
85
- if not images:
86
  print("No valid images found to process.")
87
  return
88
 
89
- # Generate captions for all images
90
- try:
91
- captions = caption_images(images)
92
- except Exception as e:
93
- print(f"Error generating captions: {e}")
94
- return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
- # Write captions to files
 
 
 
97
  for file_path, caption in zip(image_paths, captions):
98
  try:
99
  # Create caption file path (same name but with .txt extension)
@@ -111,18 +139,16 @@ def process_images(input_dir, output_dir, fix_outfit=False):
111
  # Copy caption to output directory
112
  shutil.copy2(caption_path, output_path / caption_filename)
113
 
114
- processed_count += 1
115
  print(f"Processed {file_path.name} → {caption_filename}")
116
  except Exception as e:
117
  print(f"Error processing {file_path.name}: {e}")
118
 
119
- print(f"\nProcessing complete. {processed_count} images were captioned.")
120
-
121
  def main():
122
  parser = argparse.ArgumentParser(description='Generate captions for images using GPT-4o.')
123
  parser.add_argument('--input', type=str, required=True, help='Directory containing images')
124
  parser.add_argument('--output', type=str, help='Directory to save images and captions (defaults to input directory)')
125
  parser.add_argument('--fix_outfit', action='store_true', help='Flag to indicate if character has one outfit')
 
126
 
127
  args = parser.parse_args()
128
 
@@ -132,7 +158,7 @@ def main():
132
  return
133
 
134
  # Process images
135
- process_images(args.input, args.output, args.fix_outfit)
136
 
137
  if __name__ == "__main__":
138
  main()
 
50
  print(f" - {file}")
51
  sys.exit(1)
52
 
53
+ def process_images(input_dir, output_dir, fix_outfit=False, batch_images=False):
54
  """Process all images in the input directory and generate captions."""
55
  input_path = Path(input_dir)
56
  output_path = Path(output_dir) if output_dir else input_path
 
64
  # Track the number of processed images
65
  processed_count = 0
66
 
67
+ # Collect all images into a dictionary grouped by category
68
+ images_by_category = {}
69
+ image_paths_by_category = {}
70
 
71
  # Get all files in the input directory
72
  for file_path in input_path.iterdir():
 
74
  try:
75
  # Load the image
76
  image = Image.open(file_path).convert("RGB")
77
+
78
+ # Determine the category from the filename
79
+ category = file_path.stem.rsplit('_', 1)[0]
80
+
81
+ # Add image to the appropriate category
82
+ if category not in images_by_category:
83
+ images_by_category[category] = []
84
+ image_paths_by_category[category] = []
85
+
86
+ images_by_category[category].append(image)
87
+ image_paths_by_category[category].append(file_path)
88
  except Exception as e:
89
  print(f"Error loading {file_path.name}: {e}")
90
 
91
  # Log the number of images found
92
+ total_images = sum(len(images) for images in images_by_category.values())
93
+ print(f"Found {total_images} images to process.")
94
 
95
+ if not total_images:
96
  print("No valid images found to process.")
97
  return
98
 
99
+ # Process images by category if batch_images is True
100
+ if batch_images:
101
+ for category, images in images_by_category.items():
102
+ image_paths = image_paths_by_category[category]
103
+ try:
104
+ # Generate captions for the entire category
105
+ captions = caption_images(images)
106
+ write_captions(image_paths, captions, input_path, output_path)
107
+ processed_count += len(images)
108
+ except Exception as e:
109
+ print(f"Error generating captions for category '{category}': {e}")
110
+ else:
111
+ # Process all images at once if batch_images is False
112
+ all_images = [img for imgs in images_by_category.values() for img in imgs]
113
+ all_image_paths = [path for paths in image_paths_by_category.values() for path in paths]
114
+ try:
115
+ captions = caption_images(all_images)
116
+ write_captions(all_image_paths, captions, input_path, output_path)
117
+ processed_count += len(all_images)
118
+ except Exception as e:
119
+ print(f"Error generating captions: {e}")
120
 
121
+ print(f"\nProcessing complete. {processed_count} images were captioned.")
122
+
123
+ def write_captions(image_paths, captions, input_path, output_path):
124
+ """Helper function to write captions to files."""
125
  for file_path, caption in zip(image_paths, captions):
126
  try:
127
  # Create caption file path (same name but with .txt extension)
 
139
  # Copy caption to output directory
140
  shutil.copy2(caption_path, output_path / caption_filename)
141
 
 
142
  print(f"Processed {file_path.name} → {caption_filename}")
143
  except Exception as e:
144
  print(f"Error processing {file_path.name}: {e}")
145
 
 
 
146
  def main():
147
  parser = argparse.ArgumentParser(description='Generate captions for images using GPT-4o.')
148
  parser.add_argument('--input', type=str, required=True, help='Directory containing images')
149
  parser.add_argument('--output', type=str, help='Directory to save images and captions (defaults to input directory)')
150
  parser.add_argument('--fix_outfit', action='store_true', help='Flag to indicate if character has one outfit')
151
+ parser.add_argument('--batch_images', action='store_true', help='Flag to indicate if images should be processed in batches')
152
 
153
  args = parser.parse_args()
154
 
 
158
  return
159
 
160
  # Process images
161
+ process_images(args.input, args.output, args.fix_outfit, args.batch_images)
162
 
163
  if __name__ == "__main__":
164
  main()