"""Helpers for encoding images, parsing model findings, and drawing boxes."""

import base64
import io
import json
from io import BytesIO

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from pydantic import BaseModel, field_validator


def encode_image(image: np.ndarray) -> str:
    """Encodes a NumPy array image into a base64 JPEG string.

    Args:
        image: A NumPy array representing the image.

    Returns:
        A base64 encoded string prefixed with 'data:image/jpeg;base64,'.
    """
    pil_image = Image.fromarray(image)
    # JPEG has no alpha channel; saving an RGBA/LA image would raise, so
    # drop the alpha (or palette) before encoding.
    if pil_image.mode in ("RGBA", "LA", "P"):
        pil_image = pil_image.convert("RGB")
    buffer = BytesIO()
    pil_image.save(buffer, format="jpeg")
    encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f"data:image/jpeg;base64,{encoded}"


def decode_image(base64_str: str) -> np.ndarray:
    """Decodes a base64 encoded image string into a NumPy array.

    Assumes the base64 string represents a valid image format
    (e.g., JPEG, PNG).

    Args:
        base64_str: The base64 encoded image string (may include a
            'data:...;base64,' style prefix, which is stripped).

    Returns:
        A NumPy array representing the decoded image.
    """
    # Everything up to the first comma is treated as a data-URL prefix.
    if "," in base64_str:
        base64_str = base64_str.split(",", 1)[1]

    image_data = base64.b64decode(base64_str)
    with Image.open(io.BytesIO(image_data)) as image:
        return np.array(image)


class Finding(BaseModel):
    """Represents a detected finding in an image.

    Holds the label, description, explanation, bounding box coordinates
    (x_min, y_min, x_max, y_max) and severity level.
    """

    label: str
    description: str
    explanation: str
    bounding_box: tuple[int, int, int, int]
    severity: int

    @field_validator("bounding_box")
    @classmethod
    def validate_bounding_box(
        cls, value: tuple[int, int, int, int]
    ) -> tuple[int, int, int, int]:
        """Validates that the bounding box coordinates are logically consistent.

        Raises:
            ValueError: If the box does not have 4 entries, or a minimum
                coordinate is not strictly below its maximum.
        """
        if len(value) != 4:
            raise ValueError("Bounding box must be a tuple of 4 integers")
        if value[0] >= value[2]:
            raise ValueError("Bounding box x_min (index 0) must be less than x_max (index 2)")
        if value[1] >= value[3]:
            raise ValueError("Bounding box y_min (index 1) must be less than y_max (index 3)")
        return value


class BoundingBox(BaseModel):
    """Represents a bounding box with a label and explicit min/max coordinates.

    Assumes that the top left corner is the origin.
    """

    label: str
    x_min: int
    y_min: int
    x_max: int
    y_max: int

    @classmethod
    def from_finding(cls, finding: Finding) -> "BoundingBox":
        """Creates a BoundingBox instance from a Finding instance."""
        return cls.from_array(finding.label, finding.bounding_box)

    @classmethod
    def from_array(cls, label: str, box: list[int]) -> "BoundingBox":
        """Creates a BoundingBox instance from a label and a list of coordinates.

        Coordinates are rounded to the nearest integer: boxes produced by
        `change_box_format` are floats, which the integer fields would
        otherwise reject when they have a fractional part.

        Args:
            label: The label for the box.
            box: Coordinates in (x_min, y_min, x_max, y_max) order.
        """
        x_min, y_min, x_max, y_max = (int(round(coord)) for coord in box[:4])
        return cls(label=label, x_min=x_min, y_min=y_min, x_max=x_max, y_max=y_max)


def parse_json_response(out: str) -> list[dict]:
    """Extracts and parses JSON content from a string.

    Handles responses wrapped in a markdown code block (```json), with or
    without surrounding text, as well as bare JSON.

    Args:
        out: The input string potentially containing JSON.

    Returns:
        The parsed JSON object (list or dictionary).

    Raises:
        ValueError: If no JSON content is found or it cannot be parsed
            (json.JSONDecodeError is a ValueError subclass).
    """
    # NOTE(review): both markers are empty strings in this file, so no
    # wrapper is stripped. Presumably they were meant to hold opening and
    # closing tag markers -- confirm and fill in if so.
    start_prefix = ""
    end_postfix = ""

    payload = out
    if start_prefix and end_postfix:
        start_index = payload.find(start_prefix)
        end_index = payload.rfind(end_postfix)
        if start_index != -1 and end_index > start_index:
            payload = payload[start_index + len(start_prefix):end_index]

    json_fence = "```json"
    fence_start = payload.rfind(json_fence)
    if fence_start != -1:
        body_start = fence_start + len(json_fence)
        fence_end = payload.rfind("```")
        # rfind lands on the opening fence itself when the closing fence is
        # missing (e.g. truncated output); then take the rest of the text.
        if fence_end > fence_start:
            payload = payload[body_start:fence_end]
        else:
            payload = payload[body_start:]

    payload = payload.strip()
    if not payload:
        raise ValueError("No JSON found in response")
    return json.loads(payload)


def parse_into_models(findings: list[dict]) -> list[Finding]:
    """Parses and validates a list of dictionaries into a list of Finding models.

    Args:
        findings: A list of dictionaries, each representing a finding.

    Returns:
        A list of validated Finding model instances.
    """
    return [Finding.model_validate(item) for item in findings]


def parse_all_safe(out: str) -> list[Finding] | None:
    """Safely parses a string potentially containing JSON findings.

    Combines `parse_json_response` and `parse_into_models`, returning None
    on any parsing error.

    Args:
        out: The input string.

    Returns:
        A list of Finding models if parsing is successful, otherwise None.
    """
    try:
        return parse_into_models(parse_json_response(out))
    except Exception:
        # Deliberate best-effort: any failure is reported as "no result".
        return None


def clamp(
    num: int | float, min_num: int | float = 0, max_num: int | float = 255
) -> int | float:
    """Clamps a number within a specified range [min_num, max_num]."""
    return max(min_num, min(num, max_num))


def enlarge_boxes(
    image_shape: tuple[int, ...], findings: list[Finding], factor: float = 1.1
) -> list[Finding]:
    """Enlarges the bounding boxes of findings by a given factor.

    Boxes grow symmetrically around their centre and are clamped to the
    image boundaries.

    Args:
        image_shape: The image dimensions as (height, width); extra trailing
            entries such as a channel count are ignored.
        findings: A list of Finding objects.
        factor: The factor by which to enlarge the boxes (e.g., 1.1 for 10%
            larger).

    Returns:
        A new list of Finding objects with adjusted bounding boxes. A finding
        whose enlarged box would be invalid is returned unchanged.
    """
    img_height, img_width = image_shape[:2]
    adjusted = []
    for finding in findings:
        x_min_orig, y_min_orig, x_max_orig, y_max_orig = finding.bounding_box

        # Half of the total growth is applied to each side.
        x_adjust = ((x_max_orig - x_min_orig) * (factor - 1)) / 2
        y_adjust = ((y_max_orig - y_min_orig) * (factor - 1)) / 2

        adjusted_bbox = (
            int(round(clamp(x_min_orig - x_adjust, 0, img_width))),
            int(round(clamp(y_min_orig - y_adjust, 0, img_height))),
            int(round(clamp(x_max_orig + x_adjust, 0, img_width))),
            int(round(clamp(y_max_orig + y_adjust, 0, img_height))),
        )

        try:
            Finding.validate_bounding_box(adjusted_bbox)
        except ValueError:
            # Clamping/rounding collapsed the box (min >= max): keep the
            # original finding.
            adjusted.append(finding)
        else:
            adjusted.append(finding.model_copy(update={"bounding_box": adjusted_bbox}))
    return adjusted


def change_box_format(
    shape: tuple[int, int, int], box: tuple[int, int, int, int]
) -> tuple[float, float, float, float]:
    """Rescales bounding box coordinates from a 1000x1000 grid to image pixels.

    This is only for Gemini based models, as they return coordinates
    normalized between 0-1000. Qwen based models don't need this.

    Args:
        shape: The shape of the target image (height, width, channels).
        box: The bounding box (x_min, y_min, x_max, y_max) in 1000x1000
            coordinates.

    Returns:
        A tuple (x_min, y_min, x_max, y_max) of float coordinates relative
        to the image dimensions.
    """
    y_height, x_width, _ = shape

    x_min = (box[0] / 1000.0) * x_width
    y_min = (box[1] / 1000.0) * y_height
    x_max = (box[2] / 1000.0) * x_width
    y_max = (box[3] / 1000.0) * y_height
    return (x_min, y_min, x_max, y_max)


def normalize_findings_boxes(
    shape: tuple[int, int, int], findings: list[Finding]
) -> list[Finding]:
    """Rescales the bounding boxes of all findings from the 1000x1000 grid.

    This is only for Gemini based models, as they return coordinates
    normalized between 0-1000. Qwen based models don't need this.

    Modifies the findings list in-place. The stored coordinates are floats
    as returned by `change_box_format`.

    Args:
        shape: The shape of the target image (height, width, channels).
        findings: A list of Finding objects whose bounding boxes need
            rescaling.

    Returns:
        The same list of Finding objects (modified in-place).
    """
    for finding in findings:
        finding.bounding_box = change_box_format(shape, tuple(finding.bounding_box))
    return findings


def visualize_boxes(image, findings) -> None:
    """Displays the image with one colored rectangle per finding.

    Args:
        image: The image to show (anything `Axes.imshow` accepts).
        findings: Objects exposing `bounding_box` as
            (x_min, y_min, x_max, y_max).
    """
    _fig, ax = plt.subplots(1)
    ax.imshow(image)

    # Colors are reused cyclically when there are more boxes than colors.
    colors = ['r', 'g', 'b', 'c', 'm', 'y', 'k']

    for i, finding in enumerate(findings):
        x_min, y_min, x_max, y_max = finding.bounding_box
        color = colors[i % len(colors)]
        rect = patches.Rectangle(
            (x_min, y_min),
            x_max - x_min,
            y_max - y_min,
            linewidth=2,
            edgecolor=color,
            facecolor='none',
        )
        ax.add_patch(rect)

        # Report which color marks which finding.
        print(f"Finding {i+1} (Color: {color}):")

    if not findings:
        print("No findings")

    plt.show()


def visualize_boxes_annotated(
    image: np.ndarray | Image.Image, boxes: list[BoundingBox]
) -> np.ndarray:
    """Draws labelled bounding boxes on an image and returns the result.

    Args:
        image: The input image (NumPy array or PIL Image).
        boxes: A list of BoundingBox objects with coordinates relative to
            the image.

    Returns:
        A NumPy RGB array of the rendered figure with annotated boxes.
    """
    if not isinstance(image, np.ndarray):
        image = np.array(image)

    # High DPI so the rendered annotations stay legible.
    fig = plt.figure(dpi=300)
    try:
        ax = fig.add_subplot(111)
        ax.imshow(image)
        ax.set_axis_off()
        fig.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0)

        # Colors are reused cyclically when there are more boxes than colors.
        colors = ['r', 'g', 'b', 'c', 'm', 'y', 'k']

        for i, box in enumerate(boxes):
            color = colors[i % len(colors)]
            rect = patches.Rectangle(
                (box.x_min, box.y_min),
                box.x_max - box.x_min,
                box.y_max - box.y_min,
                linewidth=1,
                edgecolor=color,
                facecolor='none',
            )
            ax.add_patch(rect)

            # Label sits slightly above the top-left corner of the box.
            ax.text(
                box.x_min,
                box.y_min - 5,
                box.label,
                color=color,
                fontsize=10,
                bbox=dict(facecolor='white', alpha=0.7, edgecolor='none'),
            )

        # Render off-screen and read the pixels back instead of displaying.
        fig.canvas.draw()
        rgba = np.asarray(fig.canvas.buffer_rgba())
        # Drop alpha and copy so the result does not alias the canvas buffer.
        return rgba[:, :, :3].copy()
    finally:
        # Close this specific figure even if drawing failed.
        plt.close(fig)