Spaces: Running on Zero
hugohabicht01 committed
Commit 335bcd6
Parent(s): c8cd915
init
Files changed:
- app.py +297 -0
- blurnonymize.py +300 -0
- requirements.txt +12 -0
- utils.py +351 -0
app.py
ADDED
@@ -0,0 +1,297 @@
import gradio as gr
import spaces
from unsloth import FastVisionModel
import torch
from PIL import Image
import numpy as np
import traceback
from typing import Any, Optional

# Import user-provided modules
import utils
from utils import Finding, BoundingBox  # Explicitly import needed classes
import blurnonymize

# --- Constants ---
MODEL_NAME = "cborg/qwen2.5VL-3b-privacydetector"
MAX_NEW_TOKENS = 2048
TEMPERATURE = 1.0
MIN_P = 0.1
SYSTEM_PROMPT = """You are a helpful assistant for privacy analysis of images. Please always answer in English. Please obey the users instructions and follow the provided format."""
DEFAULT_PROMPT = """
You are an expert at pixel perfect image analysis and in privacy.
First write down your thoughts within a <think> block.
Please go through all objects in the image and consider whether they are private data or not.
End this with a </think> block.

After going through everything, output your findings in an <output></output> block as a json list with the following keys:
{"label": <|object_ref_start|>str<|object_ref_end|>, "description": str, "explanation": str, "bounding_box": <|box_start|>[x_min, y_min, x_max, y_max]<|box_end|>, "severity": int}

Some things to remember:

- private data is only data thats linked to a human person, common examples being a persons face, name, address, license plate
- whenever something can be used to identify a unique human person, it is private data
- report sensitive data as well, such as a nude person
- Severity is a number between 0 and 10, with 0 being not private data and 10 being extremely sensitive private data.
- don't report items which dont contain private data in the final output, you may mention them in your thoughts
- animals and animal faces are not personal data, so a giraffe or a dog is not private data
- you can use whatever format you want within the <think> </think> blocks
- only output valid JSON in between the <output> </output> blocks, adhering to the schema provided
- output the bounding box always as an array of form [x_min, y_min, x_max, y_max]
- private data have a severity greater than 0, so a human face would have severity 6
- go through the image step by step and report the private data, its better to be a bit too sensitive than to miss anything
- put the bounding boxes around the human's face and not the entire person when reporting people as personal data
- Think step by step, take your time.

Here is the image to analyse, start your analysis directly after:
"""


def build_messages(image, history: Optional[list[dict[str, Any]]] = None, prompt: Optional[str] = None):
    if not prompt:
        prompt = DEFAULT_PROMPT

    if history:
        return [
            *history,
            {"role": "user", "content": [{"type": "text", "text": prompt}]},
        ]

    return [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": SYSTEM_PROMPT,
                }
            ],
        },
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image", "image": image},
            ],
        },
    ]


# --- Model Loading ---
# Load the model with unsloth, using 4-bit quantization
try:
    model, tokenizer = FastVisionModel.from_pretrained(
        model_name=MODEL_NAME,
        load_in_4bit=True,
    )
    FastVisionModel.for_inference(model)
    model.to("cuda").eval()  # Ensure the model is on GPU and in eval mode
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")
    print(traceback.format_exc())
    # Raise to prevent the app from launching if the model fails to load
    raise gr.Error(f"Failed to load model {MODEL_NAME}. Check logs. Error: {e}")


# --- Blurnonymizer Instance ---
try:
    blurnonymizer_instance = blurnonymize.ImageBlurnonymizer()
    print("Blurnonymizer initialized successfully.")
except Exception as e:
    print(f"Error initializing Blurnonymizer: {e}")
    print(traceback.format_exc())
    raise gr.Error(f"Failed to initialize Blurnonymizer. Check logs. Error: {e}")


# --- Core Processing Function ---
@spaces.GPU(duration=20)  # needed so that the SAM segmentation runs on the GPU
def anonymise_image(input_image_np: np.ndarray, boxes: list[BoundingBox]):
    """Calls the blurnonymizer instance to censor the image."""
    if not blurnonymizer_instance:
        raise gr.Error("Blurnonymizer not initialized.")
    return blurnonymizer_instance.censor_image_blur_easy(
        input_image_np, boxes, method="segmentation", verbose=False  # Set verbose as needed
    )


def run_model_inference(input_image_pil: Image.Image, prompt_text: str):
    """
    Runs model inference on the input image and prompt.
    """
    print("Running model inference...")
    messages = build_messages(
        input_image_pil,
        prompt=prompt_text)
    input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

    # Prepare inputs for the model
    inputs = tokenizer(
        input_image_pil,
        input_text,
        return_tensors="pt",
    ).to("cuda")

    out_tokens = model.generate(
        **inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        use_cache=True,
        temperature=TEMPERATURE,
        min_p=MIN_P,
    )
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, out_tokens)
    ]
    raw_model_output = tokenizer.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )[0]

    # The vision tokenizer works on a 14-pixel patch grid; recover the processed resolution
    input_height = inputs['image_grid_thw'][0][1] * 14
    input_width = inputs['image_grid_thw'][0][2] * 14

    if input_height != input_image_pil.height:
        print("[!] tokenized image height differs from actual height:")
        print(f"Actual: {input_image_pil.height}, processed: {input_height}")

    if input_width != input_image_pil.width:
        print("[!] tokenized image width differs from actual width:")
        print(f"Actual: {input_image_pil.width}, processed: {input_width}")

    print("[+] Model inference completed.")
    print("[*] Raw output:")
    print(raw_model_output)

    return raw_model_output, input_height, input_width


@spaces.GPU(duration=90)  # Request a GPU for this function for up to 90 seconds
def analyze_image(input_image_pil: Image.Image, prompt_text: str):
    """
    Analyzes the input image using the VLM, visualizes findings, and anonymizes.
    """
    if input_image_pil is None:
        raise gr.Error("Please upload an image.")
    if not prompt_text:
        raise gr.Error("Please provide a prompt.")

    original_image_np = np.array(input_image_pil)

    # 1. Run Model Inference
    try:
        raw_model_output, image_height, image_width = run_model_inference(input_image_pil, prompt_text)
    except Exception as e:
        print(f"Error during model inference: {e}")
        print(traceback.format_exc())
        raise gr.Error(f"Model inference failed: {e}")

    # 2. Parse Findings
    try:
        print("Parsing findings...")
        # Use the provided utility functions
        parsed_findings = utils.parse_into_models(
            utils.parse_json_response(raw_model_output)
        )
        print(f"[+] Parsed {len(parsed_findings)} findings.")
        if not parsed_findings:
            print("[*] No findings were parsed from the model output.")

    except Exception as e:
        print(f"Error parsing model output: {e}")
        print(traceback.format_exc())
        # Don't raise here; allow the visualization/anonymization steps to proceed if possible
        gr.Warning(
            f"Could not parse findings from model output: {e}. Visualization and anonymization might be incomplete."
        )
        # Fallback: continue with an empty findings list for the downstream steps
        parsed_findings = []

    # Initialize boxes_for_viz before the try block
    boxes_for_viz = []
    try:
        # 3. Visualize Findings
        print("Visualizing findings...")
        if parsed_findings:
            # Convert Findings to BoundingBox objects for the visualization function
            boxes_for_viz = [BoundingBox.from_finding(f) for f in parsed_findings]
            # visualize_boxes_annotated expects a numpy array
            visualized_image_np = utils.visualize_boxes_annotated(
                original_image_np, boxes_for_viz
            )
            print("Visualization generated.")
        else:
            print("No findings to visualize, using original image.")
            visualized_image_np = (
                original_image_np.copy()
            )  # Show original if no findings

    except Exception as e:
        print(f"Error during visualization: {e}")
        print(traceback.format_exc())
        gr.Warning(f"Failed to visualize findings: {e}")
        visualized_image_np = original_image_np.copy()  # Fallback to original

    try:
        # 4. Anonymize Image
        print("Anonymizing image...")
        # Only call anonymise_image when boxes_for_viz is populated
        if boxes_for_viz:
            anonymized_image_np = anonymise_image(original_image_np, boxes_for_viz)
            print("Anonymization generated.")
        else:
            print("No boxes found for anonymization, using original image.")
            anonymized_image_np = original_image_np.copy()

    except Exception as e:
        print(f"Error during anonymization: {e}")
        print(traceback.format_exc())
        gr.Warning(f"Failed to anonymize image: {e}")
        anonymized_image_np = original_image_np.copy()  # Fallback to original

    # Gradio's gr.Image output handles numpy arrays directly, so no conversion back to PIL is needed.

    # Return the raw output and the two result images
    return raw_model_output, visualized_image_np, anonymized_image_np


# --- Gradio Interface ---
with gr.Blocks() as demo:
    gr.Markdown("# Private Data Detection & Anonymization UI")
    gr.Markdown(f"Using model: `{MODEL_NAME}` on ZeroGPU.")

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Upload Image")
            prompt_textbox = gr.Textbox(
                label="Analysis Prompt", value=DEFAULT_PROMPT, lines=4
            )
            analyze_button = gr.Button("Analyze Image")
        with gr.Column(scale=2):
            with gr.Column():
                raw_output = gr.Textbox(
                    label="Raw Model Output", interactive=False
                )
                output_visualized = gr.Image(
                    label="Detected Privacy Findings", type="numpy", interactive=False
                )
                output_anonymized = gr.Image(
                    label="Anonymized", type="numpy", interactive=False
                )

    analyze_button.click(
        fn=analyze_image,
        inputs=[input_image, prompt_textbox],
        outputs=[raw_output, output_visualized, output_anonymized],
    )

# --- Launch App ---
if __name__ == "__main__":
    demo.queue().launch(
        debug=True
    )  # Enable the queue for handling multiple requests, debug mode for logs
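
Note on the expected reply format: the parsing step in analyze_image relies on the model following DEFAULT_PROMPT, i.e. free-form reasoning inside <think>...</think> followed by a JSON list inside <output>...</output> that utils.parse_json_response and utils.parse_into_models can turn into Finding models. A minimal illustrative reply (the values below are made up for illustration, not real model output) would look like:

<think>
One person is visible; their face is identifiable, the rest of the scene is not personal data.
</think>
<output>
[{"label": "face", "description": "Face of the person in the foreground", "explanation": "A face can identify a unique person", "bounding_box": [120, 80, 210, 190], "severity": 6}]
</output>
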
blurnonymize.py
ADDED
@@ -0,0 +1,300 @@
import json
import traceback
from typing import Literal, Optional

import cv2
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
from pydantic import BaseModel
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from utils import *


# --- Utility Functions (kept outside the class) ---

def blur_image(img: np.ndarray):
    """Applies Gaussian blur to an image."""
    return cv2.GaussianBlur(img, (35, 35), 50)


def plot_polygon_mask(image: np.ndarray, polygons: list[list[tuple[int, int]]]):
    """
    Plots polygon-based segmentation masks on top of an image.
    """
    plt.imshow(image)
    for polygon in polygons:
        if not polygon:
            continue  # Skip empty polygons
        polygon_array = np.array(polygon).reshape(-1, 2)
        x, y = zip(*polygon_array)
        x = list(x) + [x[0]]
        y = list(y) + [y[0]]
        plt.plot(x, y, '-r', linewidth=2)
    plt.axis('off')
    plt.tight_layout()
    plt.show()


def visualize_boxes(image, findings):
    """Visualizes bounding boxes on an image."""
    fig, ax = plt.subplots(1)
    ax.imshow(image)
    colors = ['r', 'g', 'b', 'c', 'm', 'y', 'k']
    for i, finding in enumerate(findings):
        [x_min, y_min, x_max, y_max] = finding.bounding_box
        color = colors[i % len(colors)]
        rect = patches.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min, linewidth=2, edgecolor=color,
                                 facecolor='none')
        ax.add_patch(rect)
        print(f"Finding {i + 1} (Color: {color}):")
    if not findings:
        print("No findings")
    plt.xticks(np.arange(0, image.shape[1], 50))
    plt.yticks(np.arange(0, image.shape[0], 50))
    plt.show()


# --- SAM Visualization Helpers (kept outside the class) ---

def show_mask(mask, ax, random_color=False, borders=True):
    """Displays a single mask on a matplotlib axis."""
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask = mask.astype(np.uint8)
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    if borders:
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]  # Optional smoothing
        mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=375):
    """Displays points (positive/negative) on a matplotlib axis."""
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)


def show_box(box, ax):
    """Displays a bounding box on a matplotlib axis."""
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))


def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):
    """Displays multiple masks resulting from SAM prediction."""
    for i, (mask, score) in enumerate(zip(masks, scores)):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        show_mask(mask, plt.gca(), borders=borders)
        if point_coords is not None:
            assert input_labels is not None
            show_points(point_coords, input_labels, plt.gca())
        if box_coords is not None:
            show_box(box_coords, plt.gca())
        if len(scores) > 1:
            plt.title(f"Mask {i + 1}, Score: {score:.3f}", fontsize=18)
        plt.axis('off')
        plt.show()


# --- ImageBlurnonymizer Class ---

class ImageBlurnonymizer:
    def __init__(self, checkpoint="./sam2.1_hiera_large.pt", model_cfg="configs/sam2.1/sam2.1_hiera_l.yaml"):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint, device=self.device))

    @staticmethod
    def _smoothen_mask(mask: np.ndarray):
        """Applies morphological closing to smoothen mask boundaries."""
        kernel = np.ones((20, 20), np.uint8)
        return cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)

    @staticmethod
    def _mask_from_bbox(image_shape, bbox: tuple[int, int, int, int]):
        """Creates a simple rectangular mask from a bounding box."""
        height, width, *_ = image_shape  # Allow for 2D or 3D shape tuple
        xmin, ymin, xmax, ymax = bbox
        mask = np.zeros((height, width), dtype=np.uint8)
        mask[ymin:ymax, xmin:xmax] = 1
        return mask

    @staticmethod
    def _apply_blur_mask(image: np.ndarray, mask: np.ndarray):
        """Applies a blur to an image based on a mask."""
        if mask.ndim == 2:  # Ensure the mask is 3-channel for broadcasting
            mask = np.stack((mask,) * image.shape[2], axis=-1)
        blurred = blur_image(image)  # Use the utility function
        return np.where(mask, blurred, image)

    @staticmethod
    def _binary_mask_to_polygon(binary_mask: np.ndarray, epsilon=2.0):
        """Converts a binary segmentation mask to polygon contours."""
        try:
            converted = (binary_mask * 255).astype(np.uint8)
            # Use RETR_EXTERNAL for outer contours only, CHAIN_APPROX_SIMPLE for efficiency
            contours, _ = cv2.findContours(converted, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            polygons = []
            for contour in contours:
                approx_contour = cv2.approxPolyDP(contour, epsilon, True)
                # Ensure points are converted correctly
                polygon = [(int(point[0][0]), int(point[0][1])) for point in approx_contour]
                polygons.append(polygon)
            return polygons
        except Exception as e:
            print(f"An error occurred during polygon conversion: {e}")
            print(traceback.format_exc())
            return None  # Return None on error

    def get_segmentation_mask(self, image: np.ndarray, bbox: tuple[int, int, int, int]):
        """
        Generates a segmentation mask for a region defined by a bounding box using SAM.

        Adds points within the bounding box to guide SAM towards the intended object (e.g., face)
        and away from surrounding elements (e.g., hair).
        """
        x_min, y_min, x_max, y_max = bbox
        x_width = x_max - x_min
        y_height = y_max - y_min

        # Handle cases where box dimensions are too small for the third-based point spread
        x_third = x_width // 3 if x_width >= 3 else 0
        y_third = y_height // 3 if y_height >= 3 else 0

        center_point = [(x_min + x_max) // 2, (y_min + y_max) // 2]

        # Define points ensuring they stay within the image boundaries
        points = [center_point]
        if y_third > 0:
            points.append([center_point[0], center_point[1] - y_third])
            points.append([center_point[0], center_point[1] + y_third])
        if x_third > 0:
            points.append([center_point[0] + x_third, center_point[1]])
            points.append([center_point[0] - x_third, center_point[1]])

        # Ensure points are valid coordinates (e.g., non-negative)
        points = [[max(0, p[0]), max(0, p[1])] for p in points]

        self.predictor.set_image(image)
        masks, scores, _ = self.predictor.predict(
            box=np.array(bbox),  # Predictor might expect numpy array
            point_coords=np.array(points),
            point_labels=np.ones(len(points)),  # Label 1 for inclusion
            multimask_output=True,
        )

        # Sort masks by score and select the best one
        sorted_ind = np.argsort(scores)[::-1]
        best_mask = masks[sorted_ind[0]]
        best_score = scores[sorted_ind[0]]

        return self._smoothen_mask(best_mask), best_score

    def censor_image_blur(self, image: np.ndarray, raw_out: str,
                          method: Optional[Literal['segmentation', 'bbox']] = 'segmentation', verbose=False):
        """
        Censors an image by blurring regions identified in the raw_out string (LLM output).
        """
        json_output = parse_json_response(raw_out)
        # Ensure json_output is a list before passing it to parse_into_models
        if isinstance(json_output, dict):
            findings_list = [json_output]
        elif isinstance(json_output, list):
            findings_list = json_output
        else:
            # Handle unexpected type or raise an error
            print(f"Warning: Unexpected output type from parse_json_response: {type(json_output)}")
            findings_list = []

        parsed = parse_into_models(findings_list)
        # Filter findings based on severity
        filtered = [entry for entry in parsed if entry.severity > 0]

        if verbose:
            visualize_boxes(image, filtered)  # Use external visualization

        masks = []
        for finding in filtered:
            bbox = finding.bounding_box
            if method == 'segmentation':
                mask, _ = self.get_segmentation_mask(image, bbox)
                if verbose:
                    polygons = self._binary_mask_to_polygon(mask)
                    if polygons:  # Check if polygon conversion was successful
                        plot_polygon_mask(image, polygons)
            elif method == 'bbox':
                mask = self._mask_from_bbox(image.shape, bbox)
            else:
                print(f"Warning: Unknown method '{method}'. Defaulting to no mask for this finding.")
                continue  # Skip if the method is invalid

            masks.append(mask)

        if masks:  # Check if any masks were generated
            # Combine masks: logical OR ensures any pixel in any mask is included
            combined_mask = np.zeros_like(masks[0], dtype=np.uint8)
            for mask in masks:
                # Ensure masks are boolean or uint8 for logical_or
                combined_mask = np.logical_or(combined_mask, mask.astype(bool)).astype(np.uint8)

            return self._apply_blur_mask(image, combined_mask)
        return image  # Return the original image if no masks

    def censor_image_blur_easy(self, image: np.ndarray, boxes: list[BoundingBox],
                               method: Optional[Literal['segmentation', 'bbox']] = 'segmentation', verbose=False):
        """
        Censors an image by blurring regions defined by a list of BoundingBox objects.
        """
        masks = []
        for box in boxes:
            bbox_tuple = box.to_tuple()  # Convert the BoundingBox object to a tuple
            if method == 'segmentation':
                mask, _ = self.get_segmentation_mask(image, bbox_tuple)
                if verbose:
                    polygons = self._binary_mask_to_polygon(mask)
                    if polygons:
                        plot_polygon_mask(image, polygons)
            elif method == 'bbox':
                mask = self._mask_from_bbox(image.shape, bbox_tuple)
            else:
                print(f"Warning: Unknown method '{method}'. Defaulting to no mask for this box.")
                continue

            masks.append(mask)

        if masks:
            combined_mask = np.zeros_like(masks[0], dtype=np.uint8)
            for mask in masks:
                combined_mask = np.logical_or(combined_mask, mask.astype(bool)).astype(np.uint8)

            return self._apply_blur_mask(image, combined_mask)
        return image


# Example Usage (Optional - keep outside class):
# if __name__ == '__main__':
#     # Load an image
#     # img = cv2.imread('path/to/your/image.jpg')
#     # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # Convert to RGB for matplotlib
#
#     # Create an instance of the blurnonymizer
#     # blurnonymizer = ImageBlurnonymizer()
#
#     # Define bounding boxes or get raw LLM output
#     # example_boxes = [BoundingBox(label="face", x_min=100, y_min=100, x_max=200, y_max=200)]
#     # llm_output = '...'  # Your raw LLM output string
#
#     # Censor the image
#     # censored_img_easy = blurnonymizer.censor_image_blur_easy(img, example_boxes, method='segmentation', verbose=True)
#     # censored_img_llm = blurnonymizer.censor_image_blur(img, llm_output, method='segmentation', verbose=True)
#
#     # Display or save the result
#     # plt.imshow(censored_img_easy)
#     # plt.show()
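
As a quick sketch of how the class above can be driven outside the Gradio app: the image path and box coordinates below are placeholders, and the SAM 2 checkpoint and config referenced by the default constructor arguments are assumed to be present locally.

import cv2
from utils import BoundingBox
from blurnonymize import ImageBlurnonymizer

# Placeholder input; any RGB uint8 numpy array works.
img = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)

# Hypothetical region to censor, e.g. a face reported by the detector.
boxes = [BoundingBox(label="face", x_min=100, y_min=100, x_max=200, y_max=200)]

# Assumes ./sam2.1_hiera_large.pt and the sam2.1_hiera_l.yaml config are available.
blurnonymizer = ImageBlurnonymizer()

# method="bbox" blurs the raw rectangles; method="segmentation" runs SAM 2 for tighter masks.
censored = blurnonymizer.censor_image_blur_easy(img, boxes, method="bbox")
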
requirements.txt
ADDED
@@ -0,0 +1,12 @@
gradio
unsloth
transformers
torch
pydantic
numpy
pandas
Pillow
opencv-python
spaces
matplotlib
sam2
utils.py
ADDED
@@ -0,0 +1,351 @@
from pydantic import BaseModel, field_validator
import numpy as np
import json
import matplotlib.patches as patches
import matplotlib.pyplot as plt
from PIL import Image
import base64
from io import BytesIO
import io


def encode_image(image: np.ndarray) -> str:
    """Encodes a NumPy array image into a base64 JPEG string.

    Args:
        image: A NumPy array representing the image.

    Returns:
        A base64 encoded string prefixed with 'data:image/jpeg;base64,'.
    """
    pil_image = Image.fromarray(image)
    buffer = BytesIO()
    pil_image.save(buffer, format='jpeg')
    return f"data:image/jpeg;base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}"


def decode_image(base64_str: str) -> np.ndarray:
    """Decodes a base64 encoded image string into a NumPy array.

    Assumes the base64 string represents a valid image format (e.g., JPEG, PNG).

    Args:
        base64_str: The base64 encoded image string (may include prefix).

    Returns:
        A NumPy array representing the decoded image.
    """
    # Remove the prefix if it exists
    if ',' in base64_str:
        base64_str = base64_str.split(',', 1)[1]

    # Decode the base64 string
    image_data = base64.b64decode(base64_str)

    # Convert the image data to a PIL Image
    image = Image.open(io.BytesIO(image_data))

    # Convert the PIL Image to a NumPy array
    numpy_image = np.array(image)

    return numpy_image


class Finding(BaseModel):
    """Represents a detected finding in an image, including its label,
    description, explanation, bounding box coordinates, and severity level.
    """
    label: str
    description: str
    explanation: str
    bounding_box: tuple[int, int, int, int]
    severity: int

    @field_validator("bounding_box")
    @classmethod
    def validate_bounding_box(cls, value: tuple[int, int, int, int]):
        """Validates that the bounding box coordinates are logically consistent."""
        if len(value) != 4:
            raise ValueError("Bounding box must be a tuple of 4 integers")
        if value[0] >= value[2]:
            raise ValueError("Bounding box x_min (index 0) must be less than x_max (index 2)")
        if value[1] >= value[3]:
            raise ValueError("Bounding box y_min (index 1) must be less than y_max (index 3)")
        return value


class BoundingBox(BaseModel):
    """Represents a bounding box with a label and explicit min/max coordinates. Assumes that the top left corner is the origin."""
    label: str
    x_min: int
    y_min: int
    x_max: int
    y_max: int

    @staticmethod
    def from_finding(finding: Finding) -> 'BoundingBox':
        """Creates a BoundingBox instance from a Finding instance."""
        return BoundingBox(label=finding.label, x_min=finding.bounding_box[0], y_min=finding.bounding_box[1], x_max=finding.bounding_box[2], y_max=finding.bounding_box[3])

    @staticmethod
    def from_array(label: str, box: list[int]) -> 'BoundingBox':
        """Creates a BoundingBox instance from a label and a list of coordinates."""
        return BoundingBox(label=label, x_min=box[0], y_min=box[1], x_max=box[2], y_max=box[3])

    def to_tuple(self) -> tuple[int, int, int, int]:
        """Returns the box as an (x_min, y_min, x_max, y_max) tuple, as expected by blurnonymize.censor_image_blur_easy."""
        return (self.x_min, self.y_min, self.x_max, self.y_max)


def parse_json_response(out: str) -> list[dict]:
    """Extracts and parses JSON content from a string.

    Handles responses potentially wrapped in <output> tags or markdown code blocks (```json).

    Args:
        out: The input string potentially containing JSON.

    Returns:
        The parsed JSON object (list or dictionary).

    Raises:
        ValueError: If no valid JSON content is found.
    """
    start_prefix = "<output>"
    end_postfix = "</output>"
    start_index = out.find(start_prefix)
    end_index = out.rfind(end_postfix)

    if start_index == -1:
        # Try to load by finding ```json ``` markers instead
        start_index = out.rfind("```json")
        end_index = out.rfind("```")
        if start_index == -1 or end_index == -1:
            raise ValueError("No JSON found in response")
        start_index += len("```json")
        fixed = out[start_index:end_index]
        print(f"fixed: {fixed}")
        return json.loads(fixed)

    start_index += len(start_prefix)
    fixed = out[start_index:end_index]
    fixed = fixed.strip()
    if fixed.startswith("```json"):
        start_index = fixed.find("[")
        end_index = fixed.rfind("]")

        fixed = fixed[start_index:end_index + 1]
    return json.loads(fixed)


def parse_into_models(findings: list[dict]) -> list[Finding]:
    """Parses and validates a list of dictionaries into a list of Finding models.

    Args:
        findings: A list of dictionaries, each representing a finding.

    Returns:
        A list of validated Finding model instances.
    """
    parsed = []
    for box in findings:
        model_finding = Finding.model_validate(box)
        parsed.append(model_finding)
    return parsed


def parse_all_safe(out: str) -> list[Finding] | None:
    """Safely parses a string potentially containing JSON findings into Finding models.

    Combines `parse_json_response` and `parse_into_models`, returning None on any parsing error.

    Args:
        out: The input string.

    Returns:
        A list of Finding models if parsing is successful, otherwise None.
    """
    try:
        return parse_into_models(parse_json_response(out))
    except Exception:
        return None


def clamp(num: int | float, min_num: int | float = 0, max_num: int | float = 255) -> int | float:
    """Clamps a number within a specified range [min_num, max_num]."""
    return max(min_num, min(num, max_num))


def enlarge_boxes(image_shape: tuple[int, int], findings: list[Finding], factor: float = 1.1) -> list[Finding]:
    """Enlarges the bounding boxes of findings by a given factor, clamping to image boundaries.

    Args:
        image_shape: A tuple (height, width) representing the image dimensions.
        findings: A list of Finding objects.
        factor: The factor by which to enlarge the boxes (e.g., 1.1 for 10% larger).

    Returns:
        A new list of Finding objects with adjusted bounding boxes.
    """
    adjusted = []
    img_height, img_width = image_shape
    for box in findings:
        x_min_orig, y_min_orig, x_max_orig, y_max_orig = box.bounding_box
        x_width = x_max_orig - x_min_orig
        y_width = y_max_orig - y_min_orig

        # Calculate the amount to adjust on each side
        x_adjust = (x_width * (factor - 1)) / 2
        y_adjust = (y_width * (factor - 1)) / 2

        # Calculate new coordinates and clamp them
        x_min = clamp(x_min_orig - x_adjust, 0, img_width)
        y_min = clamp(y_min_orig - y_adjust, 0, img_height)
        x_max = clamp(x_max_orig + x_adjust, 0, img_width)
        y_max = clamp(y_max_orig + y_adjust, 0, img_height)

        # Ensure coordinates remain valid integers if they were originally
        adjusted_bbox = (int(round(x_min)), int(round(y_min)), int(round(x_max)), int(round(y_max)))

        # Validate the adjusted box before creating a new Finding
        try:
            Finding.validate_bounding_box(adjusted_bbox)
            adjusted.append(box.model_copy(update={'bounding_box': adjusted_bbox}))
        except ValueError:
            # If enlarging makes the box invalid (e.g., min >= max), keep the original
            adjusted.append(box)  # Or handle the error differently if needed

    return adjusted


def change_box_format(shape: tuple[int, int, int], box: tuple[int, int, int, int]) -> tuple[float, float, float, float]:
    """Normalizes bounding box coordinates from a 1000x1000 grid to the image dimensions.

    This is only needed for Gemini-based models, as they return coordinates normalized to a
    0-1000 grid (presumably related to how their image embeddings work); Qwen-based models
    don't need this.

    Args:
        shape: The shape of the target image (height, width, channels).
        box: The bounding box tuple (x_min, y_min, x_max, y_max) in 1000x1000 coordinates.

    Returns:
        A tuple of normalized bounding box coordinates (x_min, y_min, x_max, y_max)
        relative to the image dimensions.
    """
    y_height, x_width, _ = shape
    # Normalize coordinates from the 1000x1000 grid to the image dimensions
    x_min = (box[0] / 1000.0) * x_width
    y_min = (box[1] / 1000.0) * y_height
    x_max = (box[2] / 1000.0) * x_width
    y_max = (box[3] / 1000.0) * y_height

    return (x_min, y_min, x_max, y_max)


def normalize_findings_boxes(shape: tuple[int, int, int], findings: list[Finding]) -> list[Finding]:
    """Normalizes the bounding boxes of all findings in a list.

    This is only needed for Gemini-based models, as they return coordinates normalized to
    a 0-1000 grid; Qwen-based models don't need this.

    Modifies the findings list in-place.

    Args:
        shape: The shape of the target image (height, width, channels).
        findings: A list of Finding objects whose bounding boxes need normalization.

    Returns:
        The list of Finding objects with normalized bounding boxes (modified in-place).
    """
    for finding in findings:
        # Ensure the bounding box is a tuple before passing
        current_box = tuple(finding.bounding_box)
        finding.bounding_box = change_box_format(shape, current_box)
    return findings


def visualize_boxes(image, findings):
    # Create a figure and axis
    fig, ax = plt.subplots(1)
    ax.imshow(image)

    # Define a list of colors for the boxes
    colors = ['r', 'g', 'b', 'c', 'm', 'y', 'k']

    for i, finding in enumerate(findings):
        [x_min, y_min, x_max, y_max] = finding.bounding_box

        # Select a color for the current box
        color = colors[i % len(colors)]

        rect = patches.Rectangle((x_min, y_min),
                                 x_max - x_min,
                                 y_max - y_min,
                                 linewidth=2, edgecolor=color, facecolor='none')

        ax.add_patch(rect)

        # Print the finding index and the color of its box
        print(f"Finding {i+1} (Color: {color}):")
    if len(findings) == 0:
        print("No findings")

    # Optionally set axis ticks every 50 pixels
    # plt.xticks(np.arange(0, image.shape[1], 50))  # Start, Stop, Step
    # plt.yticks(np.arange(0, image.shape[0], 50))  # Start, Stop, Step

    plt.show()


def visualize_boxes_annotated(image: np.ndarray | Image.Image, boxes: list[BoundingBox]) -> np.ndarray:
    """Draws bounding boxes with labels on an image and returns the annotated image as a NumPy array.

    Args:
        image: The input image (NumPy array or PIL Image).
        boxes: A list of BoundingBox objects with coordinates relative to the image.

    Returns:
        A NumPy array representing the image with annotated bounding boxes.
    """
    if not isinstance(image, np.ndarray):
        image = np.array(image)
    # Create a figure and axis with high DPI
    fig = plt.figure(dpi=300)
    ax = plt.subplot(111)
    ax.imshow(image)
    ax.set_axis_off()
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0)

    # Define a list of colors for the boxes
    colors = ['r', 'g', 'b', 'c', 'm', 'y', 'k']

    for i, box in enumerate(boxes):
        x_min = box.x_min
        y_min = box.y_min
        x_max = box.x_max
        y_max = box.y_max
        label = box.label

        # Select a color for the current box
        color = colors[i % len(colors)]

        rect = patches.Rectangle((x_min, y_min),
                                 x_max - x_min,
                                 y_max - y_min,
                                 linewidth=1, edgecolor=color, facecolor='none')

        ax.add_patch(rect)

        # Add label text above the box
        ax.text(x_min, y_min - 5, label, color=color, fontsize=10,
                bbox=dict(facecolor='white', alpha=0.7, edgecolor='none'))

    # Instead of displaying, render the figure into a numpy array
    fig.canvas.draw()
    data = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (4,))
    # Convert RGBA to RGB
    data = data[:, :, :3]
    plt.close()
    return data
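
To show how these helpers chain together, here is a short self-contained sketch; the reply string is fabricated for illustration (not real model output) and the blank image stands in for a real photo:

import numpy as np
from utils import parse_all_safe, BoundingBox, visualize_boxes_annotated

# Fabricated example of a model reply in the expected <output> format.
reply = ('<output>[{"label": "face", "description": "A face", '
         '"explanation": "Identifies a person", '
         '"bounding_box": [120, 80, 210, 190], "severity": 6}]</output>')

findings = parse_all_safe(reply)  # list[Finding], or None if parsing fails
if findings:
    boxes = [BoundingBox.from_finding(f) for f in findings]
    image = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder image
    annotated = visualize_boxes_annotated(image, boxes)  # RGB numpy array with labelled boxes
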