from typing import Union
from pydantic import BaseModel, field_validator
import numpy as np
import json
import matplotlib.patches as patches
import matplotlib.pyplot as plt
from PIL import Image
import base64
from io import BytesIO

def encode_image(image: np.ndarray) -> str:
    """Encodes a NumPy array image into a base64 JPEG string.

    Args:
        image: A NumPy array representing the image.

    Returns:
        A base64 encoded string prefixed with 'data:image/jpeg;base64,'.
    """
    pil_image = Image.fromarray(image)
    buffer = BytesIO()
    pil_image.save(buffer, format='jpeg')
    return f"data:image/jpeg;base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}"

def decode_image(base64_str: str) -> np.ndarray:
    """Decodes a base64 encoded image string into a NumPy array.

    Assumes the base64 string represents a valid image format (e.g., JPEG, PNG).

    Args:
        base64_str: The base64 encoded image string (may include prefix).

    Returns:
        A NumPy array representing the decoded image.
    """
    # Remove the prefix if it exists
    if ',' in base64_str:
        base64_str = base64_str.split(',', 1)[1]

    # Decode the base64 string
    image_data = base64.b64decode(base64_str)

    # Convert the image data to a PIL Image
    image = Image.open(BytesIO(image_data))

    # Convert the PIL Image to a NumPy array
    numpy_image = np.array(image)

    return numpy_image

class Finding(BaseModel):
    """Represents a detected finding in an image, including its label,
    description, explanation, bounding box coordinates, and severity level.
    """
    label: str
    description: str
    explanation: str
    bounding_box: tuple[int, int, int, int]
    severity: int

    @field_validator("bounding_box")
    @classmethod
    def validate_bounding_box(cls, value: tuple[int, int, int, int]):
        """Validates that the bounding box coordinates are logically consistent."""
        if len(value) != 4:
            raise ValueError("Bounding box must be a tuple of 4 integers")
        if value[0] >= value[2]:
            raise ValueError("Bounding box x_min (index 0) must be less than x_max (index 2)")
        if value[1] >= value[3]:
            raise ValueError("Bounding box y_min (index 1) must be less than y_max (index 3)")
        return value
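
# Illustrative only (hypothetical values): a payload that passes validation and one
# that trips the bounding-box check.
#
#   Finding(label="dent", description="Shallow dent on panel",
#           explanation="Visible local deformation",
#           bounding_box=(120, 80, 260, 190), severity=2)   # valid: x_min < x_max, y_min < y_max
#
#   Finding(label="dent", description="Shallow dent on panel",
#           explanation="Visible local deformation",
#           bounding_box=(260, 80, 120, 190), severity=2)   # raises pydantic.ValidationError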

class PartialFinding(BaseModel):
    """A partial version of Finding, where only label and bounding_box are required. Other fields are optional."""
    label: str
    bounding_box: tuple[int, int, int, int]
    description: str | None = None
    explanation: str | None = None
    severity: int | None = None

    @field_validator("bounding_box")
    @classmethod
    def validate_bounding_box(cls, value: tuple[int, int, int, int]):
        """Validates that the bounding box coordinates are logically consistent."""
        if len(value) != 4:
            raise ValueError("Bounding box must be a tuple of 4 integers")
        if value[0] >= value[2]:
            raise ValueError("Bounding box x_min (index 0) must be less than x_max (index 2)")
        if value[1] >= value[3]:
            raise ValueError("Bounding box y_min (index 1) must be less than y_max (index 3)")
        return value

class BoundingBox(BaseModel):
    """Represents a bounding box with a label and explicit min/max coordinates. Assumess that the top left corner is the origin"""
    label: str
    x_min: int
    y_min: int
    x_max: int
    y_max: int

    @staticmethod
    def from_finding(finding: Union[Finding, PartialFinding]) -> 'BoundingBox':
        """Creates a BoundingBox instance from a Finding instance."""
        return BoundingBox(label=finding.label, x_min=finding.bounding_box[0], y_min=finding.bounding_box[1], x_max=finding.bounding_box[2], y_max=finding.bounding_box[3])

    @staticmethod
    def from_array(label: str, box: list[int]) -> 'BoundingBox':
        """Creates a BoundingBox instance from a label and a list of coordinates."""
        return BoundingBox(label=label, x_min=box[0], y_min=box[1], x_max=box[2], y_max=box[3])

    def to_tuple(self) -> tuple[int, int, int, int]:
        """Converts the BoundingBox instance to a tuple of coordinates."""
        return (self.x_min, self.y_min, self.x_max, self.y_max)

def parse_json_response(out: str) -> list[dict]:
    """Extracts and parses JSON content from a string.

    Handles responses potentially wrapped in <output> tags or markdown code blocks (```json).

    Args:
        out: The input string potentially containing JSON.

    Returns:
        The parsed JSON object (list or dictionary).

    Raises:
        ValueError: If no valid JSON content is found.
    """
    start_prefix = "<output>"
    end_postfix = "</output>"
    start_index = out.find(start_prefix)
    end_index = out.rfind(end_postfix)

    if start_index == -1:
        # try to load by finding ```json ``` markers
        start_index = out.rfind("```json")
        end_index = out.rfind("```")
        if start_index == -1 or end_index == -1:
            raise ValueError("No JSON found in response")
        start_index += len("```json")
        fixed = out[start_index:end_index]
        print(f"fixed: {fixed}")
        return json.loads(fixed)

    start_index += len(start_prefix)
    fixed = out[start_index:end_index]
    fixed = fixed.strip()
    if fixed.startswith("```json"):
        start_index = fixed.find("[")
        end_index = fixed.rfind("]")

        fixed = fixed[start_index:end_index + 1]
    return json.loads(fixed)


def parse_into_models(findings: list[dict], strict: bool = True) -> Union[list[Finding], list[PartialFinding]]:
    """Parses and validates a list of dictionaries into Finding or PartialFinding models.

    Args:
        findings: A list of dictionaries, each representing a finding.
        strict: If True (default), validate against the full Finding model; if False,
            validate against PartialFinding, which only requires label and bounding_box.

    Returns:
        A list of validated Finding instances, or PartialFinding instances when strict is False.
    """
    if not strict:
        return [PartialFinding.model_validate(box) for box in findings]
    return [Finding.model_validate(box) for box in findings]


def parse_all_safe(out: str) -> list[Finding] | None:
    """Safely parses a string potentially containing JSON findings into Finding models.

    Combines `parse_json_response` and `parse_into_models`, returning None on any parsing error.

    Args:
        out: The input string.

    Returns:
        A list of Finding models if parsing is successful, otherwise None.
    """
    try:
        return parse_into_models(parse_json_response(out))
    except Exception:
        return None
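
# Illustrative usage (assumed response shapes, not actual model output):
#
#   wrapped = ('<output>[{"label": "crack", "description": "Hairline crack", '
#              '"explanation": "Thin dark line", "bounding_box": [10, 20, 110, 220], '
#              '"severity": 3}]</output>')
#   parse_all_safe(wrapped)                                       # -> [Finding(...)], or None on error
#
#   fenced = '```json\n[{"label": "crack", "bounding_box": [10, 20, 110, 220]}]\n```'
#   parse_into_models(parse_json_response(fenced), strict=False)  # -> [PartialFinding(...)]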


def clamp(num: int | float, min_num: int | float = 0, max_num: int | float = 255) -> int | float:
    """Clamps a number within a specified range [min_num, max_num]."""
    return max(min_num, min(num, max_num))

def enlarge_boxes(image_shape: tuple[int, int], findings: list[Finding], factor: float = 1.1) -> list[Finding]:
    """Enlarges the bounding boxes of findings by a given factor, clamping to image boundaries.

    Args:
        image_shape: A tuple (height, width) representing the image dimensions.
        findings: A list of Finding objects.
        factor: The factor by which to enlarge the boxes (e.g., 1.1 for 10% larger).

    Returns:
        A new list of Finding objects with adjusted bounding boxes.
    """
    adjusted = []
    img_height, img_width = image_shape
    for box in findings:
        x_min_orig, y_min_orig, x_max_orig, y_max_orig = box.bounding_box
        x_width = x_max_orig - x_min_orig
        y_width = y_max_orig - y_min_orig

        # Calculate the amount to adjust on each side
        x_adjust = (x_width * (factor - 1)) / 2
        y_adjust = (y_width * (factor - 1)) / 2

        # Calculate new coordinates and clamp them
        x_min = clamp(x_min_orig - x_adjust, 0, img_width)
        y_min = clamp(y_min_orig - y_adjust, 0, img_height)
        x_max = clamp(x_max_orig + x_adjust, 0, img_width)
        y_max = clamp(y_max_orig + y_adjust, 0, img_height)

        # Ensure coordinates remain valid integers if they were originally
        adjusted_bbox = (int(round(x_min)), int(round(y_min)), int(round(x_max)), int(round(y_max)))

        # Validate adjusted box before creating new Finding
        try:
            Finding.validate_bounding_box(adjusted_bbox)
            adjusted.append(box.model_copy(update={'bounding_box': adjusted_bbox}))
        except ValueError:
            # If enlarging makes the box invalid (e.g., min >= max), keep the original
            adjusted.append(box) # Or handle the error differently if needed

    return adjusted

def change_box_format(shape: tuple[int, int, int], box: tuple[int, int, int, int]) -> tuple[int, int, int, int]:
    """Normalizes bounding box coordinates from a 1000x1000 grid to the image dimensions.
    This is only for gemini based models, as they returns coordinates normalized between 0-1000
    Qwen based models don't need this
    Assumes the input box coordinates are relative to a 1000x1000 grid.

    Args:
        shape: The shape of the target image (height, width, channels).
        box: The bounding box tuple (x_min, y_min, x_max, y_max) in 1000x1000 coordinates.

    Returns:
        A tuple of normalized bounding box coordinates (x_min, y_min, x_max, y_max)
        relative to the image dimensions.
    """
    y_height, x_width, _ = shape
    # Normalize coordinates from 1000x1000 grid to image dimensions
    x_min = int((box[0] / 1000.0) * x_width)
    y_min = int((box[1] / 1000.0) * y_height)
    x_max = int((box[2] / 1000.0) * x_width)
    y_max = int((box[3] / 1000.0) * y_height)

    return (x_min, y_min, x_max, y_max)
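
# Worked example (hypothetical values): a Gemini-style box (100, 200, 500, 600) on a
# 480x640 RGB image (shape (480, 640, 3)) maps to
#   x_min = 100/1000 * 640 = 64,  y_min = 200/1000 * 480 = 96,
#   x_max = 500/1000 * 640 = 320, y_max = 600/1000 * 480 = 288
# so change_box_format((480, 640, 3), (100, 200, 500, 600)) == (64, 96, 320, 288).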

def normalize_findings_boxes(shape: tuple[int, int, int], findings: list[Finding]) -> list[Finding]:
    """Normalizes the bounding boxes of all findings in a list.
    This is only for gemini based models, as they returns coordinates normalized between 0-1000
    Qwen based models don't need this

    Modifies the findings list in-place.

    Args:
        shape: The shape of the target image (height, width, channels).
        findings: A list of Finding objects whose bounding boxes need normalization.

    Returns:
        The list of Finding objects with normalized bounding boxes (modified in-place).
    """
    for finding in findings:
        # Ensure the bounding box is a tuple before passing
        current_box = tuple(finding.bounding_box)
        finding.bounding_box = change_box_format(shape, current_box) # type: ignore
    return findings

def visualize_boxes(image: np.ndarray, findings: list[Finding]) -> None:
    """Displays an image with the bounding boxes of the given findings drawn on top.

    Also prints each finding and the color of its box to stdout.

    Args:
        image: The image to display (NumPy array).
        findings: A list of Finding objects with pixel-coordinate bounding boxes.
    """
    # Create a figure and axis
    fig, ax = plt.subplots(1)
    ax.imshow(image)

    # Define a list of colors for the boxes
    colors = ['r', 'g', 'b', 'c', 'm', 'y', 'k']

    for i, finding in enumerate(findings):
        x_min, y_min, x_max, y_max = finding.bounding_box

        # Select a color for the current box
        color = colors[i % len(colors)]

        rect = patches.Rectangle((x_min, y_min),
                                 x_max - x_min,
                                 y_max - y_min,
                                 linewidth=2, edgecolor=color, facecolor='none')

        ax.add_patch(rect)

        # Print the whole finding and the color of its box
        print(f"Finding {i+1} (Color: {color}): {finding}")

    if len(findings) == 0:
        print("No findings")

    plt.show()

def visualize_boxes_annotated(image: np.ndarray | Image.Image, boxes: list[BoundingBox]) -> np.ndarray:
    """Draws bounding boxes with labels on an image and returns the annotated image as a NumPy array.

    Args:
        image: The input image (NumPy array or PIL Image).
        boxes: A list of BoundingBox objects with coordinates relative to the image.

    Returns:
        A NumPy array representing the image with annotated bounding boxes.
    """
    if not isinstance(image, np.ndarray):
        image = np.array(image)

    dpi = 300
    margin_in_inches = 120 / dpi  # 120 pixels of margin on each side at 300 dpi

    # Size the figure and axes so the image is drawn at its native resolution
    # plus a small fixed margin, avoiding matplotlib's default padding
    fig = plt.figure(figsize=(image.shape[1] / dpi + 2 * margin_in_inches,
                              image.shape[0] / dpi + 2 * margin_in_inches),
                     dpi=dpi)

    ax = fig.add_axes([margin_in_inches / (image.shape[1] / dpi + 2 * margin_in_inches),
                    margin_in_inches / (image.shape[0] / dpi + 2 * margin_in_inches),
                    image.shape[1] / dpi / (image.shape[1] / dpi + 2 * margin_in_inches),
                    image.shape[0] / dpi / (image.shape[0] / dpi + 2 * margin_in_inches)])

    ax.imshow(image)

    # Set x-axis and y-axis ticks every 50 units
    ax.set_xticks(np.arange(0, image.shape[1], 50))
    ax.set_yticks(np.arange(0, image.shape[0], 50))
    # Make tick labels smaller
    ax.tick_params(axis='both', which='both', labelsize=4)


    # Define a list of colors for the boxes
    colors = ['r', 'g', 'b', 'c', 'm', 'y', 'k']

    for i, box in enumerate(boxes):
        x_min = box.x_min
        y_min = box.y_min
        x_max = box.x_max
        y_max = box.y_max
        label = box.label

        # Select a color for the current box
        color = colors[i % len(colors)]

        rect = patches.Rectangle((x_min, y_min),
                                 x_max - x_min,
                                 y_max - y_min,
                                 linewidth=1, edgecolor=color, facecolor='none')

        ax.add_patch(rect)

        # Add label text above the box
        ax.text(x_min, y_min-5, label, color=color, fontsize=4,
                bbox=dict(facecolor='white', alpha=0.7, edgecolor='none'))

    # Instead of displaying, save to numpy array
    fig.canvas.draw()
    data = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (4,))
    # Convert RGBA to RGB
    data = data[:, :, :3]
    plt.close(fig)
    return data
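
# Minimal end-to-end sketch on synthetic data, intended only to illustrate how the
# helpers above fit together (a real pipeline would parse findings from model output).
if __name__ == "__main__":
    # Synthetic 200x300 RGB image with a bright rectangle standing in for a finding
    demo_image = np.zeros((200, 300, 3), dtype=np.uint8)
    demo_image[40:120, 60:180] = (200, 180, 40)

    demo_findings = [
        Finding(
            label="bright patch",
            description="Rectangular bright region",
            explanation="Synthetic example region",
            bounding_box=(60, 40, 180, 120),
            severity=1,
        )
    ]

    # Grow the boxes by 20%, clamped to the image bounds
    enlarged = enlarge_boxes(demo_image.shape[:2], demo_findings, factor=1.2)

    # Convert to BoundingBox objects and render an annotated copy of the image
    boxes = [BoundingBox.from_finding(f) for f in enlarged]
    annotated = visualize_boxes_annotated(demo_image, boxes)

    # Round-trip the annotated image through base64 encoding
    restored = decode_image(encode_image(annotated))
    print(f"annotated shape: {annotated.shape}, restored shape: {restored.shape}")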