Rishi Desai commited on
Commit
ab00f6b
·
1 Parent(s): 8373fd9

some clean up

Browse files
Files changed (3) hide show
  1. caption.py +42 -46
  2. main.py +25 -29
  3. prompt.py +18 -26
caption.py CHANGED
@@ -2,12 +2,9 @@ import base64
2
  import io
3
  import os
4
  from together import Together
5
- from PIL import Image
6
- from dotenv import load_dotenv
7
 
8
- load_dotenv()
9
 
10
- def get_prompt():
11
  return """Automated Image Captioning (for LoRA Training)
12
 
13
  Role: You are an expert AI captioning system generating precise, structured descriptions for character images optimized for LoRA model training in Stable Diffusion and Flux.1-dev.
@@ -58,10 +55,12 @@ tr1gger photorealistic, long trench coat and combat boots, walking, determined,
58
  REMEMBER: Your response must be a single line starting with "tr1gger" and following the exact format above. No additional text, formatting, or explanations are allowed.
59
  """
60
 
 
61
  class CaptioningError(Exception):
62
  """Exception raised for errors in the captioning process."""
63
  pass
64
 
 
65
  def images_to_base64(images):
66
  """Convert a list of PIL images to base64 encoded strings."""
67
  image_strings = []
@@ -72,15 +71,17 @@ def images_to_base64(images):
72
  image_strings.append(img_str)
73
  return image_strings
74
 
 
75
  def get_together_client():
76
  """Initialize and return the Together API client."""
77
  api_key = os.environ.get("TOGETHER_API_KEY")
78
  if not api_key:
79
- raise ValueError("TOGETHER_API_KEY is not set in the environment.")
80
  return Together(api_key=api_key)
81
 
82
- def extract_trigger_caption(line):
83
- """Extract 'tr1gger' caption from a line of text."""
 
84
  if "tr1gger" in line:
85
  # If caption doesn't start with tr1gger but contains it, extract just that part
86
  if not line.startswith("tr1gger"):
@@ -88,10 +89,11 @@ def extract_trigger_caption(line):
88
  return line
89
  return ""
90
 
 
91
  def caption_single_image(client, img_str):
92
  """Process and caption a single image."""
93
  messages = [
94
- {"role": "system", "content": get_prompt()},
95
  {
96
  "role": "user",
97
  "content": [
@@ -100,85 +102,81 @@ def caption_single_image(client, img_str):
100
  ]
101
  }
102
  ]
103
-
104
  # Request caption for the image using Llama 4 Maverick
105
  response = client.chat.completions.create(
106
  model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
107
  messages=messages
108
  )
109
-
110
- # Extract caption from the response
111
  full_response = response.choices[0].message.content.strip()
112
-
113
- # Look for the trigger line in the response
114
  caption = ""
115
  for line in full_response.splitlines():
116
- caption = extract_trigger_caption(line)
117
  if caption:
118
  break
119
-
120
- # Check if caption is valid
121
  if not caption:
122
  error_msg = "Failed to extract a valid caption (containing 'tr1gger') from the response"
123
  error_msg += f"\n\nActual response:\n{full_response}"
124
  raise CaptioningError(error_msg)
125
-
126
  return caption
127
 
128
- def caption_batch_images(client, image_strings, category):
 
129
  """Process and caption multiple images in a single batch request."""
130
  # Create a content array with all images
131
- content = [{"type": "text", "text": f"Here is the batch of images for {category}. Please caption each image on a separate line, starting each caption with 'tr1gger'."}]
132
-
133
- # Add all images to the content array
 
134
  for i, img_str in enumerate(image_strings):
135
  content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_str}"}})
136
- content.append({"type": "text", "text": f"Image {i+1}"})
137
-
138
- # Send the batch request
139
  messages = [
140
- {"role": "system", "content": get_prompt()},
141
  {"role": "user", "content": content}
142
  ]
143
-
144
  response = client.chat.completions.create(
145
  model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
146
  messages=messages
147
  )
148
-
149
  return process_batch_response(response, image_strings)
150
 
 
151
  def process_batch_response(response, image_strings):
152
  """Process the API response from a batch request and extract captions."""
153
- # Parse the response to extract captions for each image
154
  full_response = response.choices[0].message.content.strip()
155
  lines = full_response.splitlines()
156
-
157
  # Extract captions from the response
158
  image_count = len(image_strings)
159
- captions = [""] * image_count # Initialize with empty strings
160
-
161
  # Extract lines that start with or contain "tr1gger"
162
- tr1gger_lines = [line for line in lines if "tr1gger" in line]
163
-
164
  # Assign captions to images
165
  for i in range(image_count):
166
- if i < len(tr1gger_lines):
167
- caption = extract_trigger_caption(tr1gger_lines[i])
168
  captions[i] = caption
169
-
170
  validate_batch_captions(captions, image_count, full_response)
171
  return captions
172
 
 
173
  def validate_batch_captions(captions, image_count, full_response):
174
  """Validate captions extracted from a batch response."""
175
  # Check if all captions are empty or don't contain the trigger word
176
  valid_captions = [c for c in captions if c and "tr1gger" in c]
177
  if not valid_captions:
178
- error_msg = "Failed to parse any valid captions from batch response. Response contained no lines with 'tr1gger'"
179
  error_msg += f"\n\nActual response:\n{full_response}"
180
  raise CaptioningError(error_msg)
181
-
182
  # Check if some captions are missing
183
  if len(valid_captions) < image_count:
184
  missing_count = image_count - len(valid_captions)
@@ -186,24 +184,22 @@ def validate_batch_captions(captions, image_count, full_response):
186
  error_msg = f"Failed to parse captions for {missing_count} of {image_count} images in batch mode"
187
  error_msg += "\n\nMalformed captions:"
188
  for idx, caption in invalid_captions:
189
- error_msg += f"\nImage {idx+1}: '{caption}'"
190
  raise CaptioningError(error_msg)
191
 
 
192
  def caption_images(images, category=None, batch_mode=False):
193
  """Caption a list of images, either individually or in batch mode."""
194
- # Convert PIL images to base64 encoded strings
195
  image_strings = images_to_base64(images)
196
-
197
- # Initialize the API client
198
  client = get_together_client()
199
-
200
- # Process images based on the mode
201
  if batch_mode and category:
202
- return caption_batch_images(client, image_strings, category)
203
  else:
204
- # Process each image individually
205
  return [caption_single_image(client, img_str) for img_str in image_strings]
206
 
 
207
  def extract_captions(file_path):
208
  captions = []
209
  with open(file_path, 'r') as file:
 
2
  import io
3
  import os
4
  from together import Together
 
 
5
 
 
6
 
7
+ def get_system_prompt():
8
  return """Automated Image Captioning (for LoRA Training)
9
 
10
  Role: You are an expert AI captioning system generating precise, structured descriptions for character images optimized for LoRA model training in Stable Diffusion and Flux.1-dev.
 
55
  REMEMBER: Your response must be a single line starting with "tr1gger" and following the exact format above. No additional text, formatting, or explanations are allowed.
56
  """
57
 
58
+
59
  class CaptioningError(Exception):
60
  """Exception raised for errors in the captioning process."""
61
  pass
62
 
63
+
64
  def images_to_base64(images):
65
  """Convert a list of PIL images to base64 encoded strings."""
66
  image_strings = []
 
71
  image_strings.append(img_str)
72
  return image_strings
73
 
74
+
75
  def get_together_client():
76
  """Initialize and return the Together API client."""
77
  api_key = os.environ.get("TOGETHER_API_KEY")
78
  if not api_key:
79
+ raise ValueError("TOGETHER_API_KEY not set!")
80
  return Together(api_key=api_key)
81
 
82
+
83
+ def extract_caption(line):
84
+ """Extract caption from a line of text."""
85
  if "tr1gger" in line:
86
  # If caption doesn't start with tr1gger but contains it, extract just that part
87
  if not line.startswith("tr1gger"):
 
89
  return line
90
  return ""
91
 
92
+
93
  def caption_single_image(client, img_str):
94
  """Process and caption a single image."""
95
  messages = [
96
+ {"role": "system", "content": get_system_prompt()},
97
  {
98
  "role": "user",
99
  "content": [
 
102
  ]
103
  }
104
  ]
105
+
106
  # Request caption for the image using Llama 4 Maverick
107
  response = client.chat.completions.create(
108
  model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
109
  messages=messages
110
  )
111
+
 
112
  full_response = response.choices[0].message.content.strip()
 
 
113
  caption = ""
114
  for line in full_response.splitlines():
115
+ caption = extract_caption(line)
116
  if caption:
117
  break
118
+
 
119
  if not caption:
120
  error_msg = "Failed to extract a valid caption (containing 'tr1gger') from the response"
121
  error_msg += f"\n\nActual response:\n{full_response}"
122
  raise CaptioningError(error_msg)
123
+
124
  return caption
125
 
126
+
127
+ def caption_image_batch(client, image_strings, category):
128
  """Process and caption multiple images in a single batch request."""
129
  # Create a content array with all images
130
+ content = [{"type": "text",
131
+ "text": f"Here is the batch of images for {category}. "
132
+ f"Caption each image on a separate line."}]
133
+
134
  for i, img_str in enumerate(image_strings):
135
  content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_str}"}})
136
+ content.append({"type": "text", "text": f"Image {i + 1}"})
137
+
 
138
  messages = [
139
+ {"role": "system", "content": get_system_prompt()},
140
  {"role": "user", "content": content}
141
  ]
 
142
  response = client.chat.completions.create(
143
  model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
144
  messages=messages
145
  )
 
146
  return process_batch_response(response, image_strings)
147
 
148
+
149
  def process_batch_response(response, image_strings):
150
  """Process the API response from a batch request and extract captions."""
 
151
  full_response = response.choices[0].message.content.strip()
152
  lines = full_response.splitlines()
153
+
154
  # Extract captions from the response
155
  image_count = len(image_strings)
156
+ captions = [""] * image_count
157
+
158
  # Extract lines that start with or contain "tr1gger"
159
+ caption_lines = [line for line in lines if "tr1gger" in line]
160
+
161
  # Assign captions to images
162
  for i in range(image_count):
163
+ if i < len(caption_lines):
164
+ caption = extract_caption(caption_lines[i])
165
  captions[i] = caption
166
+
167
  validate_batch_captions(captions, image_count, full_response)
168
  return captions
169
 
170
+
171
  def validate_batch_captions(captions, image_count, full_response):
172
  """Validate captions extracted from a batch response."""
173
  # Check if all captions are empty or don't contain the trigger word
174
  valid_captions = [c for c in captions if c and "tr1gger" in c]
175
  if not valid_captions:
176
+ error_msg = "Failed to parse any valid captions from batch response."
177
  error_msg += f"\n\nActual response:\n{full_response}"
178
  raise CaptioningError(error_msg)
179
+
180
  # Check if some captions are missing
181
  if len(valid_captions) < image_count:
182
  missing_count = image_count - len(valid_captions)
 
184
  error_msg = f"Failed to parse captions for {missing_count} of {image_count} images in batch mode"
185
  error_msg += "\n\nMalformed captions:"
186
  for idx, caption in invalid_captions:
187
+ error_msg += f"\nImage {idx + 1}: '{caption}'"
188
  raise CaptioningError(error_msg)
189
 
190
+
191
  def caption_images(images, category=None, batch_mode=False):
192
  """Caption a list of images, either individually or in batch mode."""
 
193
  image_strings = images_to_base64(images)
194
+
 
195
  client = get_together_client()
196
+
 
197
  if batch_mode and category:
198
+ return caption_image_batch(client, image_strings, category)
199
  else:
 
200
  return [caption_single_image(client, img_str) for img_str in image_strings]
201
 
202
+
203
  def extract_captions(file_path):
204
  captions = []
205
  with open(file_path, 'r') as file:
main.py CHANGED
@@ -6,34 +6,38 @@ from pathlib import Path
6
  from PIL import Image
7
  from caption import caption_images
8
 
 
9
  def is_image_file(filename):
10
  """Check if a file is an allowed image type."""
11
  allowed_extensions = ['.png', '.jpg', '.jpeg', '.webp']
12
  return any(filename.lower().endswith(ext) for ext in allowed_extensions)
13
 
 
14
  def is_unsupported_image(filename):
15
  """Check if a file is an image but not of an allowed type."""
16
  unsupported_extensions = ['.bmp', '.gif', '.tiff', '.tif', '.ico', '.svg']
17
  return any(filename.lower().endswith(ext) for ext in unsupported_extensions)
18
 
 
19
  def is_text_file(filename):
20
  """Check if a file is a text file."""
21
  return filename.lower().endswith('.txt')
22
 
 
23
  def validate_input_directory(input_dir):
24
  """Validate that the input directory only contains allowed image formats."""
25
  input_path = Path(input_dir)
26
-
27
  unsupported_files = []
28
  text_files = []
29
-
30
  for file_path in input_path.iterdir():
31
  if file_path.is_file():
32
  if is_unsupported_image(file_path.name):
33
  unsupported_files.append(file_path.name)
34
  elif is_text_file(file_path.name):
35
  text_files.append(file_path.name)
36
-
37
  if unsupported_files:
38
  print("Error: Unsupported image formats detected.")
39
  print("Only .png, .jpg, .jpeg, and .webp files are allowed.")
@@ -41,13 +45,14 @@ def validate_input_directory(input_dir):
41
  for file in unsupported_files:
42
  print(f" - {file}")
43
  sys.exit(1)
44
-
45
  if text_files:
46
  print("Warning: Text files detected in the input directory.")
47
  print("The following text files will be overwritten:")
48
  for file in text_files:
49
  print(f" - {file}")
50
 
 
51
  def collect_images_by_category(input_path):
52
  """Collect all valid images and group them by category."""
53
  images_by_category = {}
@@ -56,28 +61,27 @@ def collect_images_by_category(input_path):
56
  for file_path in input_path.iterdir():
57
  if file_path.is_file() and is_image_file(file_path.name):
58
  try:
59
- # Load the image
60
  image = Image.open(file_path).convert("RGB")
61
-
62
  # Determine the category from the filename
63
  category = file_path.stem.rsplit('_', 1)[0]
64
-
65
  # Add image to the appropriate category
66
  if category not in images_by_category:
67
  images_by_category[category] = []
68
  image_paths_by_category[category] = []
69
-
70
  images_by_category[category].append(image)
71
  image_paths_by_category[category].append(file_path)
72
  except Exception as e:
73
  print(f"Error loading {file_path.name}: {e}")
74
-
75
  return images_by_category, image_paths_by_category
76
 
 
77
  def process_by_category(images_by_category, image_paths_by_category, input_path, output_path):
78
  """Process images in batches by category."""
79
  processed_count = 0
80
-
81
  for category, images in images_by_category.items():
82
  image_paths = image_paths_by_category[category]
83
  try:
@@ -87,35 +91,31 @@ def process_by_category(images_by_category, image_paths_by_category, input_path,
87
  processed_count += len(images)
88
  except Exception as e:
89
  print(f"Error generating captions for category '{category}': {e}")
90
-
91
  return processed_count
92
 
 
93
  def process_all_at_once(images_by_category, image_paths_by_category, input_path, output_path):
94
  """Process all images at once."""
95
  all_images = [img for imgs in images_by_category.values() for img in imgs]
96
  all_image_paths = [path for paths in image_paths_by_category.values() for path in paths]
97
  processed_count = 0
98
-
99
  try:
100
  captions = caption_images(all_images, batch_mode=False)
101
  write_captions(all_image_paths, captions, input_path, output_path)
102
  processed_count += len(all_images)
103
  except Exception as e:
104
  print(f"Error generating captions: {e}")
105
-
106
  return processed_count
107
 
 
108
  def process_images(input_dir, output_dir, fix_outfit=False, batch_images=False):
109
  """Process all images in the input directory and generate captions."""
110
  input_path = Path(input_dir)
111
  output_path = Path(output_dir) if output_dir else input_path
112
-
113
- # Validate the input directory first
114
  validate_input_directory(input_dir)
115
-
116
- # Create output directory if it doesn't exist
117
  os.makedirs(output_path, exist_ok=True)
118
-
119
  # Collect images by category
120
  images_by_category, image_paths_by_category = collect_images_by_category(input_path)
121
 
@@ -127,7 +127,6 @@ def process_images(input_dir, output_dir, fix_outfit=False, batch_images=False):
127
  print("No valid images found to process.")
128
  return
129
 
130
- # Process images based on batch setting
131
  if batch_images:
132
  processed_count = process_by_category(images_by_category, image_paths_by_category, input_path, output_path)
133
  else:
@@ -135,6 +134,7 @@ def process_images(input_dir, output_dir, fix_outfit=False, batch_images=False):
135
 
136
  print(f"\nProcessing complete. {processed_count} images were captioned.")
137
 
 
138
  def write_captions(image_paths, captions, input_path, output_path):
139
  """Helper function to write captions to files."""
140
  for file_path, caption in zip(image_paths, captions):
@@ -143,37 +143,33 @@ def write_captions(image_paths, captions, input_path, output_path):
143
  caption_filename = file_path.stem + ".txt"
144
  caption_path = input_path / caption_filename
145
 
146
- # Write caption to file
147
  with open(caption_path, 'w', encoding='utf-8') as f:
148
  f.write(caption)
149
 
150
  # If output directory is different from input, copy files
151
  if output_path != input_path:
152
- # Copy image to output directory
153
  shutil.copy2(file_path, output_path / file_path.name)
154
- # Copy caption to output directory
155
  shutil.copy2(caption_path, output_path / caption_filename)
156
-
157
  print(f"Processed {file_path.name} → {caption_filename}")
158
  except Exception as e:
159
  print(f"Error processing {file_path.name}: {e}")
160
 
 
161
  def main():
162
  parser = argparse.ArgumentParser(description='Generate captions for images using GPT-4o.')
163
  parser.add_argument('--input', type=str, required=True, help='Directory containing images')
164
  parser.add_argument('--output', type=str, help='Directory to save images and captions (defaults to input directory)')
165
  parser.add_argument('--fix_outfit', action='store_true', help='Flag to indicate if character has one outfit')
166
  parser.add_argument('--batch_images', action='store_true', help='Flag to indicate if images should be processed in batches')
167
-
168
  args = parser.parse_args()
169
-
170
- # Validate input directory
171
  if not os.path.isdir(args.input):
172
  print(f"Error: Input directory '{args.input}' does not exist.")
173
  return
174
-
175
- # Process images
176
  process_images(args.input, args.output, args.fix_outfit, args.batch_images)
177
 
 
178
  if __name__ == "__main__":
179
- main()
 
6
  from PIL import Image
7
  from caption import caption_images
8
 
9
+
10
  def is_image_file(filename):
11
  """Check if a file is an allowed image type."""
12
  allowed_extensions = ['.png', '.jpg', '.jpeg', '.webp']
13
  return any(filename.lower().endswith(ext) for ext in allowed_extensions)
14
 
15
+
16
  def is_unsupported_image(filename):
17
  """Check if a file is an image but not of an allowed type."""
18
  unsupported_extensions = ['.bmp', '.gif', '.tiff', '.tif', '.ico', '.svg']
19
  return any(filename.lower().endswith(ext) for ext in unsupported_extensions)
20
 
21
+
22
  def is_text_file(filename):
23
  """Check if a file is a text file."""
24
  return filename.lower().endswith('.txt')
25
 
26
+
27
  def validate_input_directory(input_dir):
28
  """Validate that the input directory only contains allowed image formats."""
29
  input_path = Path(input_dir)
30
+
31
  unsupported_files = []
32
  text_files = []
33
+
34
  for file_path in input_path.iterdir():
35
  if file_path.is_file():
36
  if is_unsupported_image(file_path.name):
37
  unsupported_files.append(file_path.name)
38
  elif is_text_file(file_path.name):
39
  text_files.append(file_path.name)
40
+
41
  if unsupported_files:
42
  print("Error: Unsupported image formats detected.")
43
  print("Only .png, .jpg, .jpeg, and .webp files are allowed.")
 
45
  for file in unsupported_files:
46
  print(f" - {file}")
47
  sys.exit(1)
48
+
49
  if text_files:
50
  print("Warning: Text files detected in the input directory.")
51
  print("The following text files will be overwritten:")
52
  for file in text_files:
53
  print(f" - {file}")
54
 
55
+
56
  def collect_images_by_category(input_path):
57
  """Collect all valid images and group them by category."""
58
  images_by_category = {}
 
61
  for file_path in input_path.iterdir():
62
  if file_path.is_file() and is_image_file(file_path.name):
63
  try:
 
64
  image = Image.open(file_path).convert("RGB")
65
+
66
  # Determine the category from the filename
67
  category = file_path.stem.rsplit('_', 1)[0]
68
+
69
  # Add image to the appropriate category
70
  if category not in images_by_category:
71
  images_by_category[category] = []
72
  image_paths_by_category[category] = []
73
+
74
  images_by_category[category].append(image)
75
  image_paths_by_category[category].append(file_path)
76
  except Exception as e:
77
  print(f"Error loading {file_path.name}: {e}")
78
+
79
  return images_by_category, image_paths_by_category
80
 
81
+
82
  def process_by_category(images_by_category, image_paths_by_category, input_path, output_path):
83
  """Process images in batches by category."""
84
  processed_count = 0
 
85
  for category, images in images_by_category.items():
86
  image_paths = image_paths_by_category[category]
87
  try:
 
91
  processed_count += len(images)
92
  except Exception as e:
93
  print(f"Error generating captions for category '{category}': {e}")
 
94
  return processed_count
95
 
96
+
97
  def process_all_at_once(images_by_category, image_paths_by_category, input_path, output_path):
98
  """Process all images at once."""
99
  all_images = [img for imgs in images_by_category.values() for img in imgs]
100
  all_image_paths = [path for paths in image_paths_by_category.values() for path in paths]
101
  processed_count = 0
 
102
  try:
103
  captions = caption_images(all_images, batch_mode=False)
104
  write_captions(all_image_paths, captions, input_path, output_path)
105
  processed_count += len(all_images)
106
  except Exception as e:
107
  print(f"Error generating captions: {e}")
 
108
  return processed_count
109
 
110
+
111
  def process_images(input_dir, output_dir, fix_outfit=False, batch_images=False):
112
  """Process all images in the input directory and generate captions."""
113
  input_path = Path(input_dir)
114
  output_path = Path(output_dir) if output_dir else input_path
115
+
 
116
  validate_input_directory(input_dir)
 
 
117
  os.makedirs(output_path, exist_ok=True)
118
+
119
  # Collect images by category
120
  images_by_category, image_paths_by_category = collect_images_by_category(input_path)
121
 
 
127
  print("No valid images found to process.")
128
  return
129
 
 
130
  if batch_images:
131
  processed_count = process_by_category(images_by_category, image_paths_by_category, input_path, output_path)
132
  else:
 
134
 
135
  print(f"\nProcessing complete. {processed_count} images were captioned.")
136
 
137
+
138
  def write_captions(image_paths, captions, input_path, output_path):
139
  """Helper function to write captions to files."""
140
  for file_path, caption in zip(image_paths, captions):
 
143
  caption_filename = file_path.stem + ".txt"
144
  caption_path = input_path / caption_filename
145
 
 
146
  with open(caption_path, 'w', encoding='utf-8') as f:
147
  f.write(caption)
148
 
149
  # If output directory is different from input, copy files
150
  if output_path != input_path:
 
151
  shutil.copy2(file_path, output_path / file_path.name)
 
152
  shutil.copy2(caption_path, output_path / caption_filename)
 
153
  print(f"Processed {file_path.name} → {caption_filename}")
154
  except Exception as e:
155
  print(f"Error processing {file_path.name}: {e}")
156
 
157
+
158
  def main():
159
  parser = argparse.ArgumentParser(description='Generate captions for images using GPT-4o.')
160
  parser.add_argument('--input', type=str, required=True, help='Directory containing images')
161
  parser.add_argument('--output', type=str, help='Directory to save images and captions (defaults to input directory)')
162
  parser.add_argument('--fix_outfit', action='store_true', help='Flag to indicate if character has one outfit')
163
  parser.add_argument('--batch_images', action='store_true', help='Flag to indicate if images should be processed in batches')
164
+
165
  args = parser.parse_args()
166
+
 
167
  if not os.path.isdir(args.input):
168
  print(f"Error: Input directory '{args.input}' does not exist.")
169
  return
170
+
 
171
  process_images(args.input, args.output, args.fix_outfit, args.batch_images)
172
 
173
+
174
  if __name__ == "__main__":
175
+ main()
prompt.py CHANGED
@@ -1,23 +1,18 @@
1
  import os
2
  import argparse
3
  from pathlib import Path
4
- from caption import get_prompt, get_together_client, extract_captions
 
5
 
6
  def optimize_prompt(user_prompt, captions_dir=None, captions_list=None):
7
- """
8
- Optimize a user prompt to follow the same format as training captions.
9
 
10
  Args:
11
- user_prompt (str): The simple user prompt to optimize (e.g., "woman riding a bike")
12
  captions_dir (str, optional): Directory containing caption .txt files
13
  captions_list (list, optional): List of captions to use instead of loading from files
14
-
15
- Returns:
16
- str: The optimized prompt following the training format
17
  """
18
- # Get captions either from directory or provided list
19
  all_captions = []
20
-
21
  if captions_list:
22
  all_captions = captions_list
23
  elif captions_dir:
@@ -26,19 +21,18 @@ def optimize_prompt(user_prompt, captions_dir=None, captions_list=None):
26
  for file_path in captions_path.glob("*.txt"):
27
  captions = extract_captions(file_path)
28
  all_captions.extend(captions)
29
-
30
  if not all_captions:
31
- raise ValueError("No captions found. Please provide either caption files or a list of captions.")
32
-
33
  # Concatenate all captions with newlines
34
  captions_text = "\n".join(all_captions)
35
-
36
  client = get_together_client()
37
-
38
  messages = [
39
- {"role": "system", "content": get_prompt()},
40
  {
41
- "role": "user",
42
  "content": (
43
  f"These are all of the captions used to train the LoRA:\n\n"
44
  f"{captions_text}\n\n"
@@ -47,36 +41,34 @@ def optimize_prompt(user_prompt, captions_dir=None, captions_list=None):
47
  )
48
  }
49
  ]
50
-
51
  response = client.chat.completions.create(
52
  model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
53
  messages=messages
54
  )
55
-
56
  optimized_prompt = response.choices[0].message.content.strip()
57
  return optimized_prompt
58
 
 
59
  def main():
60
  parser = argparse.ArgumentParser(description='Optimize prompts based on existing captions.')
61
  parser.add_argument('--prompt', type=str, required=True, help='User prompt to optimize')
62
- parser.add_argument('--captions', type=str, help='Directory containing caption .txt files')
63
-
64
  args = parser.parse_args()
65
-
66
- if not args.captions:
67
- print("Error: --captions is required.")
68
- return
69
  if not os.path.isdir(args.captions):
70
  print(f"Error: Captions directory '{args.captions}' does not exist.")
71
  return
72
-
73
  try:
74
  optimized_prompt = optimize_prompt(args.prompt, args.captions)
75
  print("\nOptimized Prompt:")
76
  print(optimized_prompt)
77
-
78
  except Exception as e:
79
  print(f"Error optimizing prompt: {e}")
80
 
 
81
  if __name__ == "__main__":
82
  main()
 
1
  import os
2
  import argparse
3
  from pathlib import Path
4
+ from caption import get_system_prompt, get_together_client, extract_captions
5
+
6
 
7
  def optimize_prompt(user_prompt, captions_dir=None, captions_list=None):
8
+ """Optimize a user prompt to follow the same format as training captions.
 
9
 
10
  Args:
11
+ user_prompt (str): The simple user prompt to optimize
12
  captions_dir (str, optional): Directory containing caption .txt files
13
  captions_list (list, optional): List of captions to use instead of loading from files
 
 
 
14
  """
 
15
  all_captions = []
 
16
  if captions_list:
17
  all_captions = captions_list
18
  elif captions_dir:
 
21
  for file_path in captions_path.glob("*.txt"):
22
  captions = extract_captions(file_path)
23
  all_captions.extend(captions)
24
+
25
  if not all_captions:
26
+ raise ValueError("Please provide either caption files or a list of captions!")
27
+
28
  # Concatenate all captions with newlines
29
  captions_text = "\n".join(all_captions)
30
+
31
  client = get_together_client()
 
32
  messages = [
33
+ {"role": "system", "content": get_system_prompt()},
34
  {
35
+ "role": "user",
36
  "content": (
37
  f"These are all of the captions used to train the LoRA:\n\n"
38
  f"{captions_text}\n\n"
 
41
  )
42
  }
43
  ]
44
+
45
  response = client.chat.completions.create(
46
  model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
47
  messages=messages
48
  )
49
+
50
  optimized_prompt = response.choices[0].message.content.strip()
51
  return optimized_prompt
52
 
53
+
54
  def main():
55
  parser = argparse.ArgumentParser(description='Optimize prompts based on existing captions.')
56
  parser.add_argument('--prompt', type=str, required=True, help='User prompt to optimize')
57
+ parser.add_argument('--captions', type=str, required=True,help='Directory containing caption .txt files')
58
+
59
  args = parser.parse_args()
 
 
 
 
60
  if not os.path.isdir(args.captions):
61
  print(f"Error: Captions directory '{args.captions}' does not exist.")
62
  return
63
+
64
  try:
65
  optimized_prompt = optimize_prompt(args.prompt, args.captions)
66
  print("\nOptimized Prompt:")
67
  print(optimized_prompt)
68
+
69
  except Exception as e:
70
  print(f"Error optimizing prompt: {e}")
71
 
72
+
73
  if __name__ == "__main__":
74
  main()