from typing import Union
from io import BytesIO
import base64
import json

import numpy as np
import matplotlib.patches as patches
import matplotlib.pyplot as plt
from PIL import Image
from pydantic import BaseModel, field_validator


def encode_image(image: np.ndarray) -> str:
    """Encodes a NumPy array image into a base64 JPEG string.

    Args:
        image: A NumPy array representing the image.

    Returns:
        A base64 encoded string prefixed with 'data:image/jpeg;base64,'.
    """
    pil_image = Image.fromarray(image)
    buffer = BytesIO()
    pil_image.save(buffer, format='jpeg')
    return f"data:image/jpeg;base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}"


def decode_image(base64_str: str) -> np.ndarray:
    """Decodes a base64 encoded image string into a NumPy array.

    Assumes the base64 string represents a valid image format (e.g., JPEG, PNG).

    Args:
        base64_str: The base64 encoded image string (may include prefix).

    Returns:
        A NumPy array representing the decoded image.
    """
    # Remove the data URL prefix if it exists
    if ',' in base64_str:
        base64_str = base64_str.split(',', 1)[1]
    # Decode the base64 string
    image_data = base64.b64decode(base64_str)
    # Convert the image data to a PIL Image
    image = Image.open(BytesIO(image_data))
    # Convert the PIL Image to a NumPy array
    numpy_image = np.array(image)
    return numpy_image
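
# Usage sketch (hypothetical array, not from the original code): round-trip an
# image through the two helpers above.
#   img = np.zeros((64, 64, 3), dtype=np.uint8)
#   data_url = encode_image(img)        # "data:image/jpeg;base64,..."
#   restored = decode_image(data_url)   # back to a NumPy array
#   restored.shape                      # -> (64, 64, 3)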


class Finding(BaseModel):
    """Represents a detected finding in an image, including its label,
    description, explanation, bounding box coordinates, and severity level.
    """
    label: str
    description: str
    explanation: str
    bounding_box: tuple[int, int, int, int]
    severity: int

    @field_validator('bounding_box')
    @classmethod
    def validate_bounding_box(cls, value: tuple[int, int, int, int]):
        """Validates that the bounding box coordinates are logically consistent."""
        if len(value) != 4:
            raise ValueError("Bounding box must be a tuple of 4 integers")
        if value[0] >= value[2]:
            raise ValueError("Bounding box x_min (index 0) must be less than x_max (index 2)")
        if value[1] >= value[3]:
            raise ValueError("Bounding box y_min (index 1) must be less than y_max (index 3)")
        return value


class PartialFinding(BaseModel):
    """A partial version of Finding, where only label and bounding_box are required. Other fields are optional."""
    label: str
    bounding_box: tuple[int, int, int, int]
    description: str | None = None
    explanation: str | None = None
    severity: int | None = None

    @field_validator('bounding_box')
    @classmethod
    def validate_bounding_box(cls, value: tuple[int, int, int, int]):
        """Validates that the bounding box coordinates are logically consistent."""
        if len(value) != 4:
            raise ValueError("Bounding box must be a tuple of 4 integers")
        if value[0] >= value[2]:
            raise ValueError("Bounding box x_min (index 0) must be less than x_max (index 2)")
        if value[1] >= value[3]:
            raise ValueError("Bounding box y_min (index 1) must be less than y_max (index 3)")
        return value


class BoundingBox(BaseModel):
    """Represents a bounding box with a label and explicit min/max coordinates.

    Assumes that the top-left corner of the image is the origin.
    """
    label: str
    x_min: int
    y_min: int
    x_max: int
    y_max: int

    @staticmethod
    def from_finding(finding: Union[Finding, PartialFinding]) -> 'BoundingBox':
        """Creates a BoundingBox instance from a Finding or PartialFinding instance."""
        return BoundingBox(label=finding.label, x_min=finding.bounding_box[0], y_min=finding.bounding_box[1],
                           x_max=finding.bounding_box[2], y_max=finding.bounding_box[3])

    @staticmethod
    def from_array(label: str, box: list[int]) -> 'BoundingBox':
        """Creates a BoundingBox instance from a label and a coordinate list [x_min, y_min, x_max, y_max]."""
        return BoundingBox(label=label, x_min=box[0], y_min=box[1], x_max=box[2], y_max=box[3])

    def to_tuple(self) -> tuple[int, int, int, int]:
        """Converts the BoundingBox instance to a tuple of coordinates (x_min, y_min, x_max, y_max)."""
        return (self.x_min, self.y_min, self.x_max, self.y_max)
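
# Example (made-up values): construct a BoundingBox from a raw coordinate list
# and read the coordinates back as a tuple.
#   bb = BoundingBox.from_array("crack", [10, 20, 110, 220])
#   bb.to_tuple()  # -> (10, 20, 110, 220)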


def parse_json_response(out: str) -> list[dict]:
    """Extracts and parses JSON content from a string.

    Handles responses potentially wrapped in <output> tags or markdown code blocks (```json).

    Args:
        out: The input string potentially containing JSON.

    Returns:
        The parsed JSON object (list or dictionary).

    Raises:
        ValueError: If no valid JSON content is found.
    """
    start_prefix = "<output>"
    end_postfix = "</output>"
    start_index = out.find(start_prefix)
    end_index = out.rfind(end_postfix)
    if start_index == -1:
        # No <output> tags; try to locate a ```json ... ``` fenced block instead
        start_index = out.rfind("```json")
        end_index = out.rfind("```")
        if start_index == -1 or end_index == -1:
            raise ValueError("No JSON found in response")
        start_index += len("```json")
        fixed = out[start_index:end_index]
        print(f"fixed: {fixed}")
        return json.loads(fixed)
    start_index += len(start_prefix)
    if end_index == -1:
        # <output> was opened but never closed; take the rest of the string
        end_index = len(out)
    fixed = out[start_index:end_index]
    fixed = fixed.strip()
    if fixed.startswith("```json"):
        # The JSON array may additionally be wrapped in a markdown fence inside the tags
        start_index = fixed.find("[")
        end_index = fixed.rfind("]")
        fixed = fixed[start_index:end_index + 1]
    return json.loads(fixed)
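
# Example (fabricated model output): both wrappings below parse to the same list.
#   parse_json_response('<output>[{"label": "dent"}]</output>')
#   parse_json_response('```json\n[{"label": "dent"}]\n```')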


def parse_into_models(findings: list[dict], strict=True) -> Union[list[Finding], list[PartialFinding]]:
    """Parses and validates a list of dictionaries into Finding models.

    Args:
        findings: A list of dictionaries, each representing a finding.
        strict: If True, validate against the full Finding schema; if False,
            only label and bounding_box are required (PartialFinding).

    Returns:
        A list of validated Finding (or PartialFinding) model instances.
    """
    if not strict:
        return [PartialFinding.model_validate(box) for box in findings]
    return [Finding.model_validate(box) for box in findings]


def parse_all_safe(out: str) -> list[Finding] | None:
    """Safely parses a string potentially containing JSON findings into Finding models.

    Combines `parse_json_response` and `parse_into_models`, returning None on any parsing error.

    Args:
        out: The input string.

    Returns:
        A list of Finding models if parsing is successful, otherwise None.
    """
    try:
        return parse_into_models(parse_json_response(out))
    except Exception:
        return None
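
# End-to-end sketch (fabricated response text): parse_all_safe returns validated
# Finding models, or None if the JSON or schema is invalid.
#   raw = ('<output>[{"label": "rust", "description": "surface rust", '
#          '"explanation": "orange discoloration", '
#          '"bounding_box": [5, 5, 50, 50], "severity": 2}]</output>')
#   parse_all_safe(raw)  # -> [Finding(label='rust', ...)]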


def clamp(num: int | float, min_num: int | float = 0, max_num: int | float = 255) -> int | float:
    """Clamps a number within a specified range [min_num, max_num]."""
    return max(min_num, min(num, max_num))


def enlarge_boxes(image_shape: tuple[int, int], findings: list[Finding], factor: float = 1.1) -> list[Finding]:
    """Enlarges the bounding boxes of findings by a given factor, clamping to image boundaries.

    Args:
        image_shape: A tuple (height, width) representing the image dimensions.
        findings: A list of Finding objects.
        factor: The factor by which to enlarge the boxes (e.g., 1.1 for 10% larger).

    Returns:
        A new list of Finding objects with adjusted bounding boxes.
    """
    adjusted = []
    img_height, img_width = image_shape
    for box in findings:
        x_min_orig, y_min_orig, x_max_orig, y_max_orig = box.bounding_box
        x_width = x_max_orig - x_min_orig
        y_width = y_max_orig - y_min_orig
        # Calculate the amount to adjust on each side
        x_adjust = (x_width * (factor - 1)) / 2
        y_adjust = (y_width * (factor - 1)) / 2
        # Calculate new coordinates and clamp them to the image boundaries
        x_min = clamp(x_min_orig - x_adjust, 0, img_width)
        y_min = clamp(y_min_orig - y_adjust, 0, img_height)
        x_max = clamp(x_max_orig + x_adjust, 0, img_width)
        y_max = clamp(y_max_orig + y_adjust, 0, img_height)
        # Round back to integer pixel coordinates
        adjusted_bbox = (int(round(x_min)), int(round(y_min)), int(round(x_max)), int(round(y_max)))
        # Validate the adjusted box before creating the new Finding
        try:
            Finding.validate_bounding_box(adjusted_bbox)
            adjusted.append(box.model_copy(update={'bounding_box': adjusted_bbox}))
        except ValueError:
            # If enlarging makes the box invalid (e.g., min >= max), keep the original
            adjusted.append(box)  # Or handle the error differently if needed
    return adjusted
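
# Example (illustrative numbers): with factor=1.1 a 100x100 box grows by 5 px on
# each side, clamped to a 640x480 image.
#   f = Finding(label="x", description="d", explanation="e",
#               bounding_box=(0, 0, 100, 100), severity=1)
#   enlarge_boxes((480, 640), [f])[0].bounding_box  # -> (0, 0, 105, 105)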


def change_box_format(shape: tuple[int, int, int], box: tuple[int, int, int, int]) -> tuple[int, int, int, int]:
    """Rescales bounding box coordinates from a 1000x1000 grid to the image dimensions.

    This is only needed for Gemini-based models, as they return coordinates normalized
    to a 0-1000 range; Qwen-based models do not need this.

    Args:
        shape: The shape of the target image (height, width, channels).
        box: The bounding box tuple (x_min, y_min, x_max, y_max) in 1000x1000 coordinates.

    Returns:
        A tuple of bounding box coordinates (x_min, y_min, x_max, y_max)
        in pixel coordinates relative to the image dimensions.
    """
    y_height, x_width, _ = shape
    # Rescale coordinates from the 1000x1000 grid to image dimensions
    x_min = int((box[0] / 1000.0) * x_width)
    y_min = int((box[1] / 1000.0) * y_height)
    x_max = int((box[2] / 1000.0) * x_width)
    y_max = int((box[3] / 1000.0) * y_height)
    return (x_min, y_min, x_max, y_max)
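
# Example (assumed 480x640 RGB image): a Gemini-style box on the 0-1000 grid maps
# to pixel coordinates as follows.
#   change_box_format((480, 640, 3), (100, 200, 500, 800))  # -> (64, 96, 320, 384)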


def normalize_findings_boxes(shape: tuple[int, int, int], findings: list[Finding]) -> list[Finding]:
    """Rescales the bounding boxes of all findings in a list to pixel coordinates.

    This is only needed for Gemini-based models, as they return coordinates normalized
    to a 0-1000 range; Qwen-based models do not need this.

    Modifies the findings list in-place.

    Args:
        shape: The shape of the target image (height, width, channels).
        findings: A list of Finding objects whose bounding boxes need rescaling.

    Returns:
        The list of Finding objects with rescaled bounding boxes (modified in-place).
    """
    for finding in findings:
        # Ensure the bounding box is a tuple before passing it on
        current_box = tuple(finding.bounding_box)
        finding.bounding_box = change_box_format(shape, current_box)  # type: ignore
    return findings


def visualize_boxes(image, findings):
    """Displays an image with the findings' bounding boxes drawn on top using matplotlib."""
    # Create a figure and axis
    fig, ax = plt.subplots(1)
    ax.imshow(image)
    # Define a list of colors for the boxes
    colors = ['r', 'g', 'b', 'c', 'm', 'y', 'k']
    for i, finding in enumerate(findings):
        [x_min, y_min, x_max, y_max] = finding.bounding_box
        # Select a color for the current box
        color = colors[i % len(colors)]
        rect = patches.Rectangle((x_min, y_min),
                                 x_max - x_min,
                                 y_max - y_min,
                                 linewidth=2, edgecolor=color, facecolor='none')
        ax.add_patch(rect)
        # Print the finding index and the color of its box
        print(f"Finding {i+1} (Color: {color}):")
    if len(findings) == 0:
        print("No findings")
    # Optional: set axis ticks every 50 pixels
    # plt.xticks(np.arange(0, image.shape[1], 50))  # Start, Stop, Step
    # plt.yticks(np.arange(0, image.shape[0], 50))  # Start, Stop, Step
    plt.show()


def visualize_boxes_annotated(image: np.ndarray | Image.Image, boxes: list[BoundingBox]) -> np.ndarray:
    """Draws bounding boxes with labels on an image and returns the annotated image as a NumPy array.

    Args:
        image: The input image (NumPy array or PIL Image).
        boxes: A list of BoundingBox objects with coordinates relative to the image.

    Returns:
        A NumPy array representing the image with annotated bounding boxes.
    """
    if not isinstance(image, np.ndarray):
        image = np.array(image)
    dpi = 300
    margin_in_inches = 120 / dpi  # 120 pixels of margin on each side
    # All of this is for a tight layout without heaps of margin
    fig = plt.figure(figsize=(image.shape[1] / dpi + 2 * margin_in_inches,
                              image.shape[0] / dpi + 2 * margin_in_inches),
                     dpi=dpi)
    ax = fig.add_axes([margin_in_inches / (image.shape[1] / dpi + 2 * margin_in_inches),
                       margin_in_inches / (image.shape[0] / dpi + 2 * margin_in_inches),
                       image.shape[1] / dpi / (image.shape[1] / dpi + 2 * margin_in_inches),
                       image.shape[0] / dpi / (image.shape[0] / dpi + 2 * margin_in_inches)])
    ax.imshow(image)
    # Set x-axis and y-axis ticks every 50 pixels
    ax.set_xticks(np.arange(0, image.shape[1], 50))
    ax.set_yticks(np.arange(0, image.shape[0], 50))
    # Make tick labels smaller
    ax.tick_params(axis='both', which='both', labelsize=4)
    # Define a list of colors for the boxes
    colors = ['r', 'g', 'b', 'c', 'm', 'y', 'k']
    for i, box in enumerate(boxes):
        x_min = box.x_min
        y_min = box.y_min
        x_max = box.x_max
        y_max = box.y_max
        label = box.label
        # Select a color for the current box
        color = colors[i % len(colors)]
        rect = patches.Rectangle((x_min, y_min),
                                 x_max - x_min,
                                 y_max - y_min,
                                 linewidth=1, edgecolor=color, facecolor='none')
        ax.add_patch(rect)
        # Add label text above the box
        ax.text(x_min, y_min - 5, label, color=color, fontsize=4,
                bbox=dict(facecolor='white', alpha=0.7, edgecolor='none'))
    # Instead of displaying, render the figure and grab the pixels as a NumPy array
    fig.canvas.draw()
    data = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (4,))
    # Convert RGBA to RGB
    data = data[:, :, :3]
    plt.close(fig)
    return data
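

# Minimal smoke-test sketch: the image and finding below are synthetic values made
# up for illustration; they are not part of any real inference pipeline.
if __name__ == "__main__":
    dummy_image = np.zeros((200, 300, 3), dtype=np.uint8)
    dummy_finding = Finding(
        label="example",
        description="synthetic box",
        explanation="used only to exercise the helpers above",
        bounding_box=(50, 40, 150, 120),
        severity=1,
    )
    enlarged = enlarge_boxes(dummy_image.shape[:2], [dummy_finding], factor=1.2)
    annotated = visualize_boxes_annotated(
        dummy_image, [BoundingBox.from_finding(f) for f in enlarged]
    )
    print(f"annotated image shape: {annotated.shape}")
    # Round-trip the annotated image through the base64 helpers
    restored = decode_image(encode_image(annotated))
    print(f"round-tripped image shape: {restored.shape}")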