Rishi Desai committed on
Commit
dd4f3b1
·
1 Parent(s): dc6215b

breaking file up

Browse files
Files changed (1) hide show
  1. caption.py +127 -103
caption.py CHANGED
@@ -62,124 +62,148 @@ class CaptioningError(Exception):
62
  """Exception raised for errors in the captioning process."""
63
  pass
64
 
65
def _encode_png_base64(image):
    """Serialize one image to PNG in memory and return it base64-encoded."""
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def _trim_to_trigger(line):
    """Return ``line`` from its first 'tr1gger' onward, or "" if absent."""
    _, marker, rest = line.partition("tr1gger")
    return marker + rest


def _request_completion(client, user_content):
    """Send the system prompt plus ``user_content`` and return the stripped reply."""
    messages = [
        {"role": "system", "content": get_prompt()},
        {"role": "user", "content": user_content},
    ]
    response = client.chat.completions.create(
        model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
        messages=messages,
    )
    # NOTE(review): content is assumed to be possibly None (a reply without
    # text) - confirm against the Together SDK. Treating it as "" lets the
    # callers report it as a CaptioningError instead of an AttributeError.
    return (response.choices[0].message.content or "").strip()


def _caption_batch(client, image_strings, category):
    """Caption every image with a single request; raise if any caption is missing."""
    content = [{"type": "text", "text": f"Here is the batch of images for {category}. Please caption each image on a separate line, starting each caption with 'tr1gger'."}]
    for i, img_str in enumerate(image_strings):
        content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_str}"}})
        content.append({"type": "text", "text": f"Image {i+1}"})

    full_response = _request_completion(client, content)

    # Trigger lines are matched to images purely by order. Surplus lines are
    # dropped and missing ones are padded with "" so indices stay aligned.
    image_count = len(image_strings)
    found = [
        _trim_to_trigger(line)
        for line in full_response.splitlines()
        if "tr1gger" in line
    ]
    captions = found[:image_count] + [""] * max(0, image_count - len(found))

    # Every non-empty entry came from a trigger line, so "empty" == "invalid".
    invalid_captions = [(i, c) for i, c in enumerate(captions) if not c]

    if len(invalid_captions) == image_count:
        raise CaptioningError(
            "Failed to parse any valid captions from batch response. Response contained no lines with 'tr1gger'"
            f"\n\nActual response:\n{full_response}"
        )

    if invalid_captions:
        details = "".join(
            f"\nImage {idx+1}: '{caption}'" for idx, caption in invalid_captions
        )
        raise CaptioningError(
            f"Failed to parse captions for {len(invalid_captions)} of {image_count} images in batch mode"
            "\n\nMalformed captions:" + details
        )

    return captions


def _caption_one(client, img_str):
    """Caption a single image; raise if the reply has no 'tr1gger' line."""
    full_response = _request_completion(
        client,
        [
            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_str}"}},
            {"type": "text", "text": "Describe this image."},
        ],
    )

    # The first line carrying the trigger word is the caption.
    caption = next(
        filter(None, map(_trim_to_trigger, full_response.splitlines())),
        "",
    )
    if not caption:
        raise CaptioningError(
            "Failed to extract a valid caption (containing 'tr1gger') from the response"
            f"\n\nActual response:\n{full_response}"
        )
    return caption


def caption_images(images, category=None, batch_mode=False):
    """Caption images individually or, when requested, in one batch request.

    Args:
        images: Iterable of images exposing ``save(fp, format=...)``.
        category: Label inserted into the batch prompt. Batch mode is only
            used when this is truthy.
        batch_mode: When true (and ``category`` is set), all images are sent
            in a single request instead of one request per image.

    Returns:
        A list of caption strings in input order, each starting with
        'tr1gger'. Empty input yields ``[]`` without any request being sent.

    Raises:
        ValueError: If ``TOGETHER_API_KEY`` is unset or empty.
        CaptioningError: If a caption cannot be extracted from a response.
    """
    # Convert the images to base64 encoded PNG strings
    image_strings = [_encode_png_base64(image) for image in images]

    # Retrieve the API key from the environment
    api_key = os.environ.get("TOGETHER_API_KEY")
    if not api_key:
        raise ValueError("TOGETHER_API_KEY is not set in the environment.")
    client = Together(api_key=api_key)

    # Nothing to caption: avoid an image-less batch request that could only
    # end in a misleading "failed to parse" error.
    if not image_strings:
        return []

    if batch_mode and category:
        return _caption_batch(client, image_strings, category)
    return [_caption_one(client, img_str) for img_str in image_strings]
182
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  def extract_captions(file_path):
184
  captions = []
185
  with open(file_path, 'r') as file:
 
62
  """Exception raised for errors in the captioning process."""
63
  pass
64
 
65
def images_to_base64(images):
    """Encode each image as a base64 PNG string.

    Args:
        images: Iterable of objects exposing ``save(fp, format=...)``
            (PIL images, per the original description of this helper).

    Returns:
        A list of UTF-8 decoded base64 strings, one per image, in input order.
    """

    def _as_base64_png(picture):
        # Render into memory; the PNG bytes never touch the filesystem.
        png_buffer = io.BytesIO()
        picture.save(png_buffer, format="PNG")
        raw_png = png_buffer.getvalue()
        return base64.b64encode(raw_png).decode("utf-8")

    return [_as_base64_png(picture) for picture in images]
74
+
75
def get_together_client():
    """Build a Together client from the ``TOGETHER_API_KEY`` environment variable.

    Returns:
        A ``Together`` instance constructed with the key read from the
        environment.

    Raises:
        ValueError: If ``TOGETHER_API_KEY`` is unset or empty.
    """
    key = os.environ.get("TOGETHER_API_KEY")
    if key:
        return Together(api_key=key)
    raise ValueError("TOGETHER_API_KEY is not set in the environment.")
81
 
82
def extract_trigger_caption(line):
    """Return the portion of ``line`` that starts at the first 'tr1gger'.

    Anything preceding the trigger word (for example a numbering prefix) is
    discarded.

    Args:
        line: A single line of response text.

    Returns:
        The substring beginning at 'tr1gger', or "" when the line does not
        contain it.
    """
    # partition() yields an empty separator and tail when the word is absent,
    # so the concatenation collapses to "" in that case.
    _, marker, remainder = line.partition("tr1gger")
    return marker + remainder
90
 
91
def caption_single_image(client, img_str):
    """Request a caption for one base64-encoded PNG and return its trigger line.

    Args:
        client: Chat client; only ``client.chat.completions.create`` is used.
        img_str: Base64-encoded PNG data, embedded in a ``data:`` URL.

    Returns:
        The first response line containing 'tr1gger', trimmed so that it
        starts at the trigger word.

    Raises:
        CaptioningError: If the response has no text or no line containing
            'tr1gger'.
    """
    messages = [
        {"role": "system", "content": get_prompt()},
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_str}"}},
                {"type": "text", "text": "Caption this image."},
            ],
        },
    ]

    # Request caption for the image using Llama 4 Maverick
    response = client.chat.completions.create(
        model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
        messages=messages,
    )

    # NOTE(review): content is assumed to be possibly None (a reply without
    # text) - confirm against the Together SDK. Falling back to "" reports
    # that case as a CaptioningError rather than an AttributeError on strip().
    full_response = (response.choices[0].message.content or "").strip()

    # The first line that yields a non-empty extraction is the caption.
    caption = next(
        filter(None, map(extract_trigger_caption, full_response.splitlines())),
        "",
    )

    if not caption:
        raise CaptioningError(
            "Failed to extract a valid caption (containing 'tr1gger') from the response"
            f"\n\nActual response:\n{full_response}"
        )

    return caption
127
+
128
def caption_batch_images(client, image_strings, category):
    """Caption several images with a single chat request.

    The user message starts with an instruction naming ``category`` and is
    followed, for each image, by the image itself and an "Image N" label.

    Args:
        client: Chat client; only ``client.chat.completions.create`` is used.
        image_strings: Base64-encoded PNG strings, one per image.
        category: Label interpolated into the instruction text.

    Returns:
        Whatever ``process_batch_response`` returns for the reply.
    """
    instruction = f"Here is the batch of images for {category}. Please caption each image on a separate line, starting each caption with 'tr1gger'."
    content = [{"type": "text", "text": instruction}]

    # Each image is followed by its 1-based label.
    for position, encoded in enumerate(image_strings):
        content.extend(
            [
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded}"}},
                {"type": "text", "text": f"Image {position+1}"},
            ]
        )

    system_message = {"role": "system", "content": get_prompt()}
    user_message = {"role": "user", "content": content}
    messages = [system_message, user_message]

    reply = client.chat.completions.create(
        model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
        messages=messages,
    )
    return process_batch_response(reply, image_strings)
150
+
151
def process_batch_response(response, image_strings):
    """Extract one caption per image from a batch chat response.

    Lines of the reply that contain 'tr1gger' are taken, in order, as the
    captions for the images; each is trimmed to start at the trigger word.

    Args:
        response: Chat completion object; ``choices[0].message.content`` is
            read.
        image_strings: The images that were sent; only the count is used.

    Returns:
        A list with ``len(image_strings)`` captions.

    Raises:
        CaptioningError: Via ``validate_batch_captions`` when captions are
            missing.
    """
    # NOTE(review): content is assumed to be possibly None (a reply without
    # text) - confirm against the Together SDK. Falling back to "" lets
    # validation report the problem instead of strip() raising AttributeError.
    full_response = (response.choices[0].message.content or "").strip()
    image_count = len(image_strings)

    found = [
        extract_trigger_caption(line)
        for line in full_response.splitlines()
        if "tr1gger" in line
    ]

    # Surplus trigger lines are ignored; missing ones are padded with "" so
    # the result always has one slot per image.
    captions = found[:image_count] + [""] * max(0, image_count - len(found))

    validate_batch_captions(captions, image_count, full_response)
    return captions
172
 
173
def validate_batch_captions(captions, image_count, full_response):
    """Raise if a batch produced fewer valid captions than images.

    A caption counts as valid when it is non-empty and contains 'tr1gger'.

    Args:
        captions: Captions extracted from the batch reply, in image order.
        image_count: Number of images that were sent.
        full_response: Raw reply text, quoted in the error when no caption
            at all is valid.

    Raises:
        CaptioningError: If no caption is valid, or if fewer captions are
            valid than ``image_count``.
    """
    rejected = []
    accepted_total = 0
    for position, text in enumerate(captions):
        if text and "tr1gger" in text:
            accepted_total += 1
        else:
            rejected.append((position, text))

    # Total failure: show the raw reply, since there is nothing to itemize.
    if accepted_total == 0:
        raise CaptioningError(
            "Failed to parse any valid captions from batch response. Response contained no lines with 'tr1gger'"
            f"\n\nActual response:\n{full_response}"
        )

    if accepted_total >= image_count:
        return

    # Partial failure: list each offending slot with its 1-based image number.
    shortfall = image_count - accepted_total
    itemized = "".join(
        f"\nImage {position+1}: '{text}'" for position, text in rejected
    )
    raise CaptioningError(
        f"Failed to parse captions for {shortfall} of {image_count} images in batch mode"
        "\n\nMalformed captions:" + itemized
    )
191
+
192
def caption_images(images, category=None, batch_mode=False):
    """Caption a list of images, either individually or in batch mode.

    Args:
        images: Images to caption; converted with ``images_to_base64``.
        category: Label used in the batch prompt. Batch mode is only used
            when this is truthy.
        batch_mode: When true (and ``category`` is set), all images are sent
            in a single request instead of one request per image.

    Returns:
        A list of caption strings in input order. Empty input yields ``[]``
        without any request being sent.

    Raises:
        ValueError: From ``get_together_client`` when ``TOGETHER_API_KEY`` is
            not set.
        CaptioningError: When a caption cannot be extracted from a response.
    """
    # Convert PIL images to base64 encoded strings
    image_strings = images_to_base64(images)

    # Initialize the API client (kept ahead of the empty check so a missing
    # API key is still reported for every call).
    client = get_together_client()

    # Nothing to caption: without this guard, batch mode would send an
    # image-less request and then fail with a misleading parse error.
    if not image_strings:
        return []

    if batch_mode and category:
        return caption_batch_images(client, image_strings, category)

    # Process each image individually
    return [caption_single_image(client, img_str) for img_str in image_strings]
206
+
207
  def extract_captions(file_path):
208
  captions = []
209
  with open(file_path, 'r') as file: