# imgprivllm / utils.py
# hugohabicht01
# make it possible to edit system prompt
# b2fb4ce
from typing import Union
from pydantic import BaseModel, field_validator
import numpy as np
import json
import matplotlib.patches as patches
import matplotlib.pyplot as plt
from PIL import Image
import base64
from io import BytesIO
import io
def encode_image(image: np.ndarray) -> str:
    """Encode a NumPy image array as a base64 JPEG data URI.

    JPEG cannot store an alpha channel, so arrays that Pillow interprets as
    ``RGBA`` or ``LA`` (for instance a decoded transparent PNG) have their
    alpha channel dropped before saving instead of making the save fail.

    Args:
        image: A NumPy array accepted by ``PIL.Image.fromarray``.

    Returns:
        A base64 encoded string prefixed with 'data:image/jpeg;base64,'.
    """
    pil_image = Image.fromarray(image)
    # Pillow refuses to write alpha modes as JPEG; convert them explicitly.
    # Other modes are left untouched so previously working inputs behave the same.
    alpha_free_mode = {"RGBA": "RGB", "LA": "L"}.get(pil_image.mode)
    if alpha_free_mode is not None:
        pil_image = pil_image.convert(alpha_free_mode)
    buffer = BytesIO()
    pil_image.save(buffer, format='jpeg')
    encoded = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return f"data:image/jpeg;base64,{encoded}"
def decode_image(base64_str: str) -> np.ndarray:
    """Turn a base64 encoded image string into a NumPy array.

    The payload must be an image format that PIL can open (e.g. JPEG, PNG).

    Args:
        base64_str: The base64 encoded image. If the string contains a comma,
            everything up to and including the first comma (such as a
            ``data:image/...;base64,`` prefix) is discarded.

    Returns:
        The decoded image converted with ``np.array``.
    """
    # Keep only what follows the first comma, when there is one.
    _, separator, remainder = base64_str.partition(',')
    payload = remainder if separator else base64_str
    raw_bytes = base64.b64decode(payload)
    return np.array(Image.open(io.BytesIO(raw_bytes)))
class Finding(BaseModel):
    """A single detection in an image.

    Holds the label, a description, an explanation, the bounding box as
    ``(x_min, y_min, x_max, y_max)`` and a severity level.
    """

    label: str
    description: str
    explanation: str
    bounding_box: tuple[int, int, int, int]
    severity: int

    @field_validator("bounding_box")
    @classmethod
    def validate_bounding_box(cls, value: tuple[int, int, int, int]):
        """Check that the box has four entries and min < max on both axes.

        Returns the value unchanged; raises ``ValueError`` otherwise.
        """
        if len(value) != 4:
            raise ValueError("Bounding box must be a tuple of 4 integers")
        x_min, y_min, x_max, y_max = value
        if x_max <= x_min:
            raise ValueError("Bounding box x_min (index 0) must be less than x_max (index 2)")
        if y_max <= y_min:
            raise ValueError("Bounding box y_min (index 1) must be less than y_max (index 3)")
        return value
class PartialFinding(BaseModel):
    """A relaxed finding: only ``label`` and ``bounding_box`` are required.

    ``description``, ``explanation`` and ``severity`` default to ``None``.
    """

    label: str
    bounding_box: tuple[int, int, int, int]
    description: str | None = None
    explanation: str | None = None
    severity: int | None = None

    @field_validator("bounding_box")
    @classmethod
    def validate_bounding_box(cls, value: tuple[int, int, int, int]):
        """Check that the box has four entries and min < max on both axes.

        Returns the value unchanged; raises ``ValueError`` otherwise.
        """
        if len(value) != 4:
            raise ValueError("Bounding box must be a tuple of 4 integers")
        left, top, right, bottom = value
        if right <= left:
            raise ValueError("Bounding box x_min (index 0) must be less than x_max (index 2)")
        if bottom <= top:
            raise ValueError("Bounding box y_min (index 1) must be less than y_max (index 3)")
        return value
class BoundingBox(BaseModel):
    """A labelled box described by explicit min/max coordinates.

    Assumes that the top left corner is the origin.
    """

    label: str
    x_min: int
    y_min: int
    x_max: int
    y_max: int

    @staticmethod
    def from_finding(finding: Union[Finding, PartialFinding]) -> 'BoundingBox':
        """Build a BoundingBox from a finding's label and bounding_box."""
        coords = finding.bounding_box
        return BoundingBox(
            label=finding.label,
            x_min=coords[0],
            y_min=coords[1],
            x_max=coords[2],
            y_max=coords[3],
        )

    @staticmethod
    def from_array(label: str, box: list[int]) -> 'BoundingBox':
        """Build a BoundingBox from a label and the first four entries of ``box``.

        ``box`` is read as ``[x_min, y_min, x_max, y_max]``.
        """
        return BoundingBox(
            label=label,
            x_min=box[0],
            y_min=box[1],
            x_max=box[2],
            y_max=box[3],
        )

    def to_tuple(self) -> tuple[int, int, int, int]:
        """Return the coordinates as ``(x_min, y_min, x_max, y_max)``."""
        return self.x_min, self.y_min, self.x_max, self.y_max
def parse_json_response(out: str) -> list[dict]:
    """Extract and parse the JSON payload embedded in a model response.

    Two layouts are recognised:

    * ``<output> ... </output>`` tags. The payload between the first opening
      tag and the last closing tag is used; it may itself be wrapped in a
      markdown ```json fence.
    * No ``<output>`` tag: the content of the last ```json fenced block.

    A missing closing tag or closing fence (e.g. a truncated response) is
    tolerated: the payload then runs to the end of the string.

    Args:
        out: The input string potentially containing JSON.

    Returns:
        The parsed JSON object (list or dictionary).

    Raises:
        ValueError: If no JSON section is found, or if the extracted text is
            not valid JSON (``json.JSONDecodeError`` is a ``ValueError``).
    """
    open_tag = "<output>"
    close_tag = "</output>"
    fence_json = "```json"
    fence = "```"

    tag_pos = out.find(open_tag)
    if tag_pos == -1:
        # No <output> tag: fall back to the last ```json fenced block.
        fence_pos = out.rfind(fence_json)
        if fence_pos == -1:
            raise ValueError("No JSON found in response")
        body_start = fence_pos + len(fence_json)
        body_end = out.rfind(fence, body_start)
        if body_end == -1:
            # Unterminated fence: take everything after the opener.
            body_end = len(out)
        return json.loads(out[body_start:body_end])

    body_start = tag_pos + len(open_tag)
    body_end = out.rfind(close_tag, body_start)
    if body_end == -1:
        # Without this guard the -1 would be used as a slice end and silently
        # drop the last character of the payload.
        body_end = len(out)
    payload = out[body_start:body_end].strip()

    if payload.startswith(fence_json):
        first_bracket = payload.find("[")
        last_bracket = payload.rfind("]")
        if first_bracket != -1 and last_bracket > first_bracket:
            # Usual case: a JSON list inside the fence.
            payload = payload[first_bracket:last_bracket + 1]
        else:
            # No list brackets (e.g. a JSON object): just peel off the fence.
            payload = payload[len(fence_json):].strip().removesuffix(fence)
    return json.loads(payload)
def parse_into_models(findings: list[dict], strict=True) -> Union[list[Finding], list[PartialFinding]]:
    """Validate raw finding dictionaries into pydantic models.

    Args:
        findings: A list of dictionaries, each representing a finding.
        strict: When truthy, every entry is validated as a ``Finding``;
            otherwise as a ``PartialFinding``.

    Returns:
        A list of validated model instances, in input order.
    """
    model_cls = Finding if strict else PartialFinding
    return [model_cls.model_validate(raw) for raw in findings]
def parse_all_safe(out: str) -> list[Finding] | None:
    """Parse a response string into Finding models without ever raising.

    Runs `parse_json_response` followed by `parse_into_models`; any exception
    raised by either step results in ``None``.

    Args:
        out: The input string.

    Returns:
        A list of Finding models if parsing is successful, otherwise None.
    """
    try:
        raw_findings = parse_json_response(out)
        models = parse_into_models(raw_findings)
    except Exception:
        # Deliberate best-effort: callers only distinguish success from failure.
        return None
    return models
def clamp(num: int | float, min_num: int | float = 0, max_num: int | float = 255) -> int | float:
"""Clamps a number within a specified range [min_num, max_num]."""
return max(min_num, min(num, max_num))
def enlarge_boxes(image_shape: tuple[int, ...], findings: list[Finding], factor: float = 1.1) -> list[Finding]:
    """Enlarge the bounding boxes of findings by a factor, clamped to the image.

    Each box grows symmetrically around its centre: its width and height are
    multiplied by ``factor`` and the result is clamped to
    ``[0, width] x [0, height]`` and rounded to integers.

    Args:
        image_shape: The image dimensions as (height, width). Longer tuples
            such as an array's (height, width, channels) shape are accepted;
            only the first two entries are used.
        findings: A list of Finding objects.
        factor: The factor by which to enlarge the boxes (e.g., 1.1 for 10% larger).

    Returns:
        A new list with one entry per input finding. Entries are copies with
        the adjusted box, or the original object when the adjusted box would
        fail bounding-box validation.
    """
    # Slice instead of a two-way unpack so ``image.shape`` of a colour image
    # can be passed directly.
    img_height, img_width = image_shape[:2]
    adjusted = []
    for box in findings:
        x_min_orig, y_min_orig, x_max_orig, y_max_orig = box.bounding_box
        x_width = x_max_orig - x_min_orig
        y_width = y_max_orig - y_min_orig

        # Half of the total growth goes to each side of the box.
        x_adjust = (x_width * (factor - 1)) / 2
        y_adjust = (y_width * (factor - 1)) / 2

        x_min = clamp(x_min_orig - x_adjust, 0, img_width)
        y_min = clamp(y_min_orig - y_adjust, 0, img_height)
        x_max = clamp(x_max_orig + x_adjust, 0, img_width)
        y_max = clamp(y_max_orig + y_adjust, 0, img_height)

        # The model stores integer coordinates.
        adjusted_bbox = (int(round(x_min)), int(round(y_min)), int(round(x_max)), int(round(y_max)))

        try:
            Finding.validate_bounding_box(adjusted_bbox)
        except ValueError:
            # Clamping/rounding made the box degenerate (min >= max): keep the original.
            adjusted.append(box)
        else:
            adjusted.append(box.model_copy(update={'bounding_box': adjusted_bbox}))
    return adjusted
def change_box_format(shape: tuple[int, ...], box: tuple[int, int, int, int]) -> tuple[int, int, int, int]:
    """Scale a bounding box from a 1000x1000 grid to image pixel coordinates.

    This is only for gemini based models, as they return coordinates
    normalized between 0-1000. Qwen based models don't need this.

    Args:
        shape: The shape of the target image, either (height, width) or
            (height, width, channels). Only the first two entries are used,
            so grayscale image shapes work as well.
        box: The bounding box tuple (x_min, y_min, x_max, y_max) in
            1000x1000 grid coordinates.

    Returns:
        The bounding box (x_min, y_min, x_max, y_max) in pixel coordinates of
        the image, each value truncated to an integer.
    """
    y_height, x_width = shape[:2]
    # x values scale with the image width, y values with the image height.
    x_min = int((box[0] / 1000.0) * x_width)
    y_min = int((box[1] / 1000.0) * y_height)
    x_max = int((box[2] / 1000.0) * x_width)
    y_max = int((box[3] / 1000.0) * y_height)
    return (x_min, y_min, x_max, y_max)
def normalize_findings_boxes(shape: tuple[int, int, int], findings: list[Finding]) -> list[Finding]:
    """Rescale the bounding box of every finding via ``change_box_format``.

    This is only for gemini based models, as they return coordinates
    normalized between 0-1000. Qwen based models don't need this.

    The findings are updated in place, one after another.

    Args:
        shape: The shape of the target image (height, width, channels).
        findings: A list of Finding objects whose bounding boxes need rescaling.

    Returns:
        The same list object that was passed in, with updated bounding boxes.
    """
    for item in findings:
        # change_box_format is handed a tuple, whatever sequence type is stored.
        grid_box = tuple(item.bounding_box)
        item.bounding_box = change_box_format(shape, grid_box)  # type: ignore
    return findings
def visualize_boxes(image, findings):
    """Display ``image`` with one coloured rectangle per finding.

    For every finding a line naming its index and box colour is printed;
    "No findings" is printed when ``findings`` is empty. The figure is shown
    with ``plt.show()``.

    Args:
        image: Anything ``Axes.imshow`` accepts.
        findings: Objects exposing a four-element ``bounding_box``
            (x_min, y_min, x_max, y_max).
    """
    # Box colours are reused cyclically when there are more findings than colours.
    palette = ['r', 'g', 'b', 'c', 'm', 'y', 'k']

    _, axes = plt.subplots(1)
    axes.imshow(image)

    for index, finding in enumerate(findings):
        left, top, right, bottom = finding.bounding_box
        edge = palette[index % len(palette)]
        outline = patches.Rectangle(
            (left, top),
            right - left,
            bottom - top,
            linewidth=2,
            edgecolor=edge,
            facecolor='none',
        )
        axes.add_patch(outline)
        print(f"Finding {index+1} (Color: {edge}):")

    if len(findings) == 0:
        print("No findings")

    plt.show()
def visualize_boxes_annotated(image: np.ndarray | Image.Image, boxes: list[BoundingBox]) -> np.ndarray:
    """Draw labelled bounding boxes on an image and return the rendering.

    The image is rendered with matplotlib inside a figure that adds a fixed
    margin for the axis ticks, and the rendered canvas is read back as an
    array instead of being displayed.

    Args:
        image: The input image (NumPy array or PIL Image).
        boxes: A list of BoundingBox objects with coordinates relative to the image.

    Returns:
        An RGB ``uint8`` NumPy array of the rendered figure (image, ticks,
        boxes and labels). Because of the margin it is larger than the input.
    """
    if not isinstance(image, np.ndarray):
        image = np.array(image)

    dpi = 300
    margin_in_inches = 120 / dpi  # 120 pixels on every side
    img_height, img_width = image.shape[:2]

    # Figure size = image size plus the margin on both sides; the axes are
    # then placed so they cover exactly the image-sized area in the middle.
    # All this for a tight layout without heaps of margin.
    fig_width = img_width / dpi + 2 * margin_in_inches
    fig_height = img_height / dpi + 2 * margin_in_inches
    fig = plt.figure(figsize=(fig_width, fig_height), dpi=dpi)
    try:
        ax = fig.add_axes([margin_in_inches / fig_width,
                           margin_in_inches / fig_height,
                           img_width / dpi / fig_width,
                           img_height / dpi / fig_height])
        ax.imshow(image)

        # Ticks every 50 units, with small labels.
        ax.set_xticks(np.arange(0, img_width, 50))
        ax.set_yticks(np.arange(0, img_height, 50))
        ax.tick_params(axis='both', which='both', labelsize=4)

        # Box colours are reused cyclically.
        colors = ['r', 'g', 'b', 'c', 'm', 'y', 'k']
        for i, box in enumerate(boxes):
            color = colors[i % len(colors)]
            rect = patches.Rectangle((box.x_min, box.y_min),
                                     box.x_max - box.x_min,
                                     box.y_max - box.y_min,
                                     linewidth=1, edgecolor=color, facecolor='none')
            ax.add_patch(rect)
            # Label text slightly above the top-left corner of the box.
            ax.text(box.x_min, box.y_min - 5, box.label, color=color, fontsize=4,
                    bbox=dict(facecolor='white', alpha=0.7, edgecolor='none'))

        # Render off-screen and read the RGBA canvas back as an array.
        fig.canvas.draw()
        data = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
        data = data.reshape(fig.canvas.get_width_height()[::-1] + (4,))
        # Drop alpha; copy so the result owns its data rather than aliasing
        # the canvas buffer of a figure that is about to be closed.
        return data[:, :, :3].copy()
    finally:
        # Close this specific figure (not merely the "current" one), also on errors.
        plt.close(fig)