Spaces:
Running
on
Zero
Running
on
Zero
File size: 14,011 Bytes
5a467ab 335bcd6 731975c 335bcd6 19284aa 335bcd6 5ea22b8 335bcd6 19284aa 335bcd6 5a467ab 19284aa dae4d1c 335bcd6 731975c 335bcd6 731975c 335bcd6 731975c 335bcd6 c790e67 335bcd6 c790e67 c45a224 c790e67 335bcd6 b2fb4ce 335bcd6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 |
from typing import Union
from pydantic import BaseModel, field_validator
import numpy as np
import json
import matplotlib.patches as patches
import matplotlib.pyplot as plt
from PIL import Image
import base64
from io import BytesIO
import io
def encode_image(image: np.ndarray) -> str:
    """Serialize a NumPy image array to a base64-encoded JPEG data URI.

    Args:
        image: Image pixels as a NumPy array (any mode Pillow accepts).

    Returns:
        A string of the form 'data:image/jpeg;base64,<payload>'.
    """
    buffer = BytesIO()
    Image.fromarray(image).save(buffer, format='jpeg')
    payload = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return f"data:image/jpeg;base64,{payload}"
def decode_image(base64_str: str) -> np.ndarray:
    """Decode a base64-encoded image string into a NumPy array.

    Accepts either a bare base64 payload or a full data URI
    (e.g. 'data:image/jpeg;base64,...'); any prefix before the first
    comma is discarded. The payload must be a format Pillow can open.

    Args:
        base64_str: The base64 encoded image string (may include prefix).

    Returns:
        The decoded image as a NumPy array.
    """
    # Strip a data-URI prefix, if one is present, keeping everything
    # after the first comma.
    if ',' in base64_str:
        base64_str = base64_str.partition(',')[2]
    raw_bytes = base64.b64decode(base64_str)
    pil_image = Image.open(io.BytesIO(raw_bytes))
    return np.array(pil_image)
class Finding(BaseModel):
    """A single detected finding in an image.

    Carries the finding's label, free-text description and explanation,
    its pixel-space bounding box (x_min, y_min, x_max, y_max), and an
    integer severity level.
    """
    label: str
    description: str
    explanation: str
    bounding_box: tuple[int, int, int, int]
    severity: int

    @field_validator("bounding_box")
    @classmethod
    def validate_bounding_box(cls, value: tuple[int, int, int, int]):
        """Reject boxes whose min coordinate is not strictly below its max."""
        if len(value) != 4:
            raise ValueError("Bounding box must be a tuple of 4 integers")
        x_min, y_min, x_max, y_max = value
        if x_min >= x_max:
            raise ValueError("Bounding box x_min (index 0) must be less than x_max (index 2)")
        if y_min >= y_max:
            raise ValueError("Bounding box y_min (index 1) must be less than y_max (index 3)")
        return value
class PartialFinding(BaseModel):
    """A relaxed variant of Finding.

    Only `label` and `bounding_box` are mandatory; description,
    explanation and severity may be absent (None).
    """
    label: str
    bounding_box: tuple[int, int, int, int]
    description: str | None = None
    explanation: str | None = None
    severity: int | None = None

    @field_validator("bounding_box")
    @classmethod
    def validate_bounding_box(cls, value: tuple[int, int, int, int]):
        """Reject boxes whose min coordinate is not strictly below its max."""
        if len(value) != 4:
            raise ValueError("Bounding box must be a tuple of 4 integers")
        x_min, y_min, x_max, y_max = value
        if x_min >= x_max:
            raise ValueError("Bounding box x_min (index 0) must be less than x_max (index 2)")
        if y_min >= y_max:
            raise ValueError("Bounding box y_min (index 1) must be less than y_max (index 3)")
        return value
class BoundingBox(BaseModel):
    """A labelled bounding box with explicit min/max pixel coordinates.

    Coordinates assume the image origin (0, 0) is the top-left corner.
    """
    label: str
    x_min: int
    y_min: int
    x_max: int
    y_max: int

    @staticmethod
    def from_finding(finding: Union[Finding, PartialFinding]) -> 'BoundingBox':
        """Build a BoundingBox from a (Partial)Finding's label and box tuple."""
        x_min, y_min, x_max, y_max = finding.bounding_box
        return BoundingBox(label=finding.label,
                           x_min=x_min, y_min=y_min,
                           x_max=x_max, y_max=y_max)

    @staticmethod
    def from_array(label: str, box: list[int]) -> 'BoundingBox':
        """Build a BoundingBox from a label and [x_min, y_min, x_max, y_max]."""
        return BoundingBox(label=label,
                           x_min=box[0], y_min=box[1],
                           x_max=box[2], y_max=box[3])

    def to_tuple(self) -> tuple[int, int, int, int]:
        """Return the coordinates as an (x_min, y_min, x_max, y_max) tuple."""
        return (self.x_min, self.y_min, self.x_max, self.y_max)
def parse_json_response(out: str) -> list[dict]:
    """Extracts and parses JSON content from a model response string.

    Handles responses wrapped in <output>...</output> tags or in markdown
    code fences (```json ... ```). When the content inside <output> is
    itself fenced, the JSON array between the outermost brackets is used.

    Args:
        out: The input string potentially containing JSON.

    Returns:
        The parsed JSON object (list or dictionary).

    Raises:
        ValueError: If no valid JSON content is found (json.JSONDecodeError
            is a subclass of ValueError).
    """
    start_prefix = "<output>"
    end_postfix = "</output>"
    start_index = out.find(start_prefix)
    if start_index == -1:
        # Fall back to a fenced ```json block.
        fence_open = out.rfind("```json")
        if fence_open == -1:
            raise ValueError("No JSON found in response")
        content_start = fence_open + len("```json")
        # Search for the closing fence strictly AFTER the opener, so we
        # never match the opener's own backticks.
        fence_close = out.find("```", content_start)
        if fence_close == -1:
            raise ValueError("No JSON found in response")
        return json.loads(out[content_start:fence_close])
    start_index += len(start_prefix)
    end_index = out.rfind(end_postfix)
    if end_index == -1:
        # Tolerate a missing closing tag: take everything to the end
        # (the old slice with -1 silently dropped the last character).
        end_index = len(out)
    fixed = out[start_index:end_index].strip()
    if fixed.startswith("```json"):
        # Fenced content inside the tags: extract the JSON array between
        # the outermost square brackets.
        bracket_open = fixed.find("[")
        bracket_close = fixed.rfind("]")
        fixed = fixed[bracket_open:bracket_close + 1]
    return json.loads(fixed)
def parse_into_models(findings: list[dict], strict=True) -> Union[list[Finding], list[PartialFinding]]:
    """Validates a list of dictionaries into Finding (or PartialFinding) models.

    Args:
        findings: Dictionaries, each describing one finding.
        strict: When True (default) validate against the full Finding model;
            when False, against PartialFinding (only label + box required).

    Returns:
        A list of validated model instances.
    """
    model_cls = Finding if strict else PartialFinding
    return [model_cls.model_validate(item) for item in findings]
def parse_all_safe(out: str) -> list[Finding] | None:
    """Best-effort parse of a response string into Finding models.

    Chains `parse_json_response` and `parse_into_models`; any exception
    raised while parsing or validating yields None instead of propagating.

    Args:
        out: The input string.

    Returns:
        Validated Finding models on success, otherwise None.
    """
    try:
        parsed = parse_json_response(out)
        return parse_into_models(parsed)
    except Exception:
        return None
def clamp(num: int | float, min_num: int | float = 0, max_num: int | float = 255) -> int | float:
"""Clamps a number within a specified range [min_num, max_num]."""
return max(min_num, min(num, max_num))
def enlarge_boxes(image_shape: tuple[int, int], findings: list[Finding], factor: float = 1.1) -> list[Finding]:
    """Grows each finding's bounding box symmetrically by `factor`, clamped to the image.

    Args:
        image_shape: (height, width) of the image the boxes live in.
        findings: Finding objects to adjust.
        factor: Overall scale for each box (1.1 means 10% larger).

    Returns:
        A new list of Finding objects; boxes that would become invalid
        after enlarging (min >= max) are kept unchanged.
    """
    img_height, img_width = image_shape
    result = []
    for finding in findings:
        x0, y0, x1, y1 = finding.bounding_box
        # Half the total growth, applied to each side of the box.
        half_dx = ((x1 - x0) * (factor - 1)) / 2
        half_dy = ((y1 - y0) * (factor - 1)) / 2
        new_box = (
            int(round(clamp(x0 - half_dx, 0, img_width))),
            int(round(clamp(y0 - half_dy, 0, img_height))),
            int(round(clamp(x1 + half_dx, 0, img_width))),
            int(round(clamp(y1 + half_dy, 0, img_height))),
        )
        try:
            Finding.validate_bounding_box(new_box)
        except ValueError:
            # Enlarging degenerated the box (e.g. clamped to min >= max);
            # fall back to the original finding.
            result.append(finding)
        else:
            result.append(finding.model_copy(update={'bounding_box': new_box}))
    return result
def change_box_format(shape: tuple[int, int, int], box: tuple[int, int, int, int]) -> tuple[int, int, int, int]:
    """Maps a box from a normalized 0-1000 grid onto actual image pixels.

    Needed only for Gemini-based models, which emit coordinates normalized
    to a 1000x1000 grid; Qwen-based models already use pixel coordinates.

    Args:
        shape: Target image shape (height, width, channels).
        box: (x_min, y_min, x_max, y_max) in 0-1000 grid coordinates.

    Returns:
        (x_min, y_min, x_max, y_max) in pixel coordinates of the image.
    """
    height, width, _ = shape
    x0, y0, x1, y1 = box
    # Divide by 1000 first (matching the original expression order) so
    # float truncation via int() is bit-for-bit identical.
    return (
        int((x0 / 1000.0) * width),
        int((y0 / 1000.0) * height),
        int((x1 / 1000.0) * width),
        int((y1 / 1000.0) * height),
    )
def normalize_findings_boxes(shape: tuple[int, int, int], findings: list[Finding]) -> list[Finding]:
    """Rescales every finding's box from the 0-1000 grid to image pixels, in place.

    Needed only for Gemini-based models, which emit coordinates normalized
    to a 1000x1000 grid; Qwen-based models already use pixel coordinates.

    Args:
        shape: Target image shape (height, width, channels).
        findings: Finding objects whose boxes will be rewritten in place.

    Returns:
        The same list, with each finding's bounding_box rescaled.
    """
    for finding in findings:
        # Normalize to a plain tuple before converting.
        finding.bounding_box = change_box_format(shape, tuple(finding.bounding_box))  # type: ignore
    return findings
def visualize_boxes(image, findings):
    """Displays `image` with one colored rectangle per finding and shows the plot.

    Prints a line per finding identifying which color its box uses
    (colors cycle through a fixed 7-color palette), or "No findings"
    when the list is empty.

    Args:
        image: The image to display (anything `ax.imshow` accepts).
        findings: Objects with a `bounding_box` of (x_min, y_min, x_max, y_max).
    """
    fig, ax = plt.subplots(1)
    ax.imshow(image)
    colors = ['r', 'g', 'b', 'c', 'm', 'y', 'k']
    for i, finding in enumerate(findings):
        x_min, y_min, x_max, y_max = finding.bounding_box
        # Cycle through the palette so every box gets a color.
        color = colors[i % len(colors)]
        rect = patches.Rectangle((x_min, y_min),
                                 x_max - x_min,
                                 y_max - y_min,
                                 linewidth=2, edgecolor=color, facecolor='none')
        ax.add_patch(rect)
        print(f"Finding {i+1} (Color: {color}):")
    if not findings:
        print("No findings")
    plt.show()
def visualize_boxes_annotated(image: np.ndarray | Image.Image, boxes: list[BoundingBox]) -> np.ndarray:
    """Renders labelled bounding boxes onto an image and returns the result as RGB pixels.

    Args:
        image: The input image (NumPy array or PIL Image).
        boxes: BoundingBox objects with coordinates relative to the image.

    Returns:
        A NumPy uint8 array (H, W, 3) of the rendered figure.
    """
    if not isinstance(image, np.ndarray):
        image = np.array(image)
    dpi = 300
    margin = 120 / dpi  # 120-pixel margin on each side, expressed in inches
    # Size the figure so the axes cover exactly the image's pixel area,
    # surrounded only by the fixed margin (no default matplotlib padding).
    width_in = image.shape[1] / dpi
    height_in = image.shape[0] / dpi
    fig_w = width_in + 2 * margin
    fig_h = height_in + 2 * margin
    fig = plt.figure(figsize=(fig_w, fig_h), dpi=dpi)
    ax = fig.add_axes([margin / fig_w,
                       margin / fig_h,
                       width_in / fig_w,
                       height_in / fig_h])
    ax.imshow(image)
    # Tick marks every 50 pixels, with small labels so they stay legible.
    ax.set_xticks(np.arange(0, image.shape[1], 50))
    ax.set_yticks(np.arange(0, image.shape[0], 50))
    ax.tick_params(axis='both', which='both', labelsize=4)
    palette = ['r', 'g', 'b', 'c', 'm', 'y', 'k']
    for idx, box in enumerate(boxes):
        # Cycle through the palette so every box gets a color.
        color = palette[idx % len(palette)]
        rect = patches.Rectangle((box.x_min, box.y_min),
                                 box.x_max - box.x_min,
                                 box.y_max - box.y_min,
                                 linewidth=1, edgecolor=color, facecolor='none')
        ax.add_patch(rect)
        # Place the label just above the box's top-left corner.
        ax.text(box.x_min, box.y_min-5, box.label, color=color, fontsize=4,
                bbox=dict(facecolor='white', alpha=0.7, edgecolor='none'))
    # Render off-screen and capture the pixel buffer instead of displaying.
    fig.canvas.draw()
    pixels = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
    pixels = pixels.reshape(fig.canvas.get_width_height()[::-1] + (4,))
    plt.close()
    # Drop the alpha channel: RGBA -> RGB.
    return pixels[:, :, :3]
|