hugohabicht01 committed on
Commit
335bcd6
·
1 Parent(s): c8cd915
Files changed (4)
  1. app.py +297 -0
  2. blurnonymize.py +300 -0
  3. requirements.txt +12 -0
  4. utils.py +351 -0
app.py ADDED
@@ -0,0 +1,297 @@
1
+ import gradio as gr
2
+ import spaces
3
+ from unsloth import FastVisionModel
4
+ import torch
5
+ from PIL import Image
6
+ import numpy as np
7
+ import traceback
8
+ from typing import Any, Optional
9
+
10
+ # Import user-provided modules
11
+ import utils
12
+ from utils import Finding, BoundingBox # Explicitly import needed classes
13
+ import blurnonymize
14
+
15
+ # --- Constants ---
16
+ MODEL_NAME = "cborg/qwen2.5VL-3b-privacydetector"
17
+ MAX_NEW_TOKENS = 2048
18
+ TEMPERATURE = 1.0
19
+ MIN_P = 0.1
20
+ SYSTEM_PROMPT = """You are a helpful assistant for privacy analysis of images. Please always answer in English. Please obey the user's instructions and follow the provided format."""
21
+ DEFAULT_PROMPT = """
22
+ You are an expert at pixel perfect image analysis and in privacy.
23
+ First write down your thoughts within a <think> block.
24
+ Please go through all objects in the image and consider whether they are private data or not.
25
+ End this with a </think> block.
26
+
27
+ After going through everything, output your findings in an <output></output> block as a json list with the following keys:
28
+ {"label": <|object_ref_start|>str<|object_ref_end|>, "description": str, "explanation": str, "bounding_box": <|box_start|>[x_min, y_min, x_max, y_max]<|box_end|>, "severity": int}
29
+
30
+ Some things to remember:
31
+
32
+ - private data is only data that's linked to a human person, common examples being a person's face, name, address, or license plate
33
+ - whenever something can be used to identify a unique human person, it is private data
34
+ - report sensitive data as well, such as a nude person
35
+ - Severity is a number between 0 and 10, with 0 being not private data and 10 being extremely sensitive private data.
36
+ - don't report items which don't contain private data in the final output; you may mention them in your thoughts
37
+ - animals and animal faces are not personal data, so a giraffe or a dog is not private data
38
+ - you can use whatever format you want within the <think> </think> blocks
39
+ - only output valid JSON in between the <output> </output> blocks, adhering to the schema provided
40
+ - output the bounding box always as an array of form [x_min, y_min, x_max, y_max]
41
+ - private data has a severity greater than 0, so a human face would have severity 6
42
+ - go through the image step by step and report the private data; it's better to be a bit too sensitive than to miss anything
43
+ - put the bounding boxes around the human's face and not the entire person when reporting people as personal data
44
+ - Think step by step, take your time.
45
+
46
+ Here is the image to analyse, start your analysis directly after:
47
+ """
48
+
49
+
50
+ def build_messages(image, history: Optional[list[dict[str, Any]]] = None, prompt: Optional[str] = None):
51
+ if not prompt:
52
+ prompt = DEFAULT_PROMPT
53
+
54
+ if history:
55
+ return [
56
+ *history,
57
+ {"role": "user", "content": [{"type": "text", "text": prompt}]},
58
+ ]
59
+
60
+ return [
61
+ {
62
+ "role": "system",
63
+ "content": [
64
+ {
65
+ "type": "text",
66
+ "text": SYSTEM_PROMPT,
67
+ }
68
+ ],
69
+ },
70
+ {
71
+ "role": "user",
72
+ "content": [
73
+ {"type": "text", "text": prompt},
74
+ {"type": "image", "image": image},
75
+ ],
76
+ },
77
+ ]
78
+
79
+
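For a fresh request (no history), build_messages returns a two-turn conversation: a system turn carrying SYSTEM_PROMPT and a user turn carrying the analysis prompt plus the image. A minimal sketch of the expected shape, using a hypothetical placeholder image:

from PIL import Image

placeholder = Image.new("RGB", (64, 64))  # hypothetical test image
msgs = build_messages(placeholder)
assert msgs[0]["role"] == "system"
assert msgs[1]["role"] == "user"
assert msgs[1]["content"][1]["type"] == "image"  # the image rides along in the user turn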
80
+ # --- Model Loading ---
81
+ # Load model using unsloth for 4-bit quantization
82
+ try:
83
+ model, tokenizer = FastVisionModel.from_pretrained(
84
+ model_name=MODEL_NAME,
85
+ load_in_4bit=True,
86
+ )
87
+ FastVisionModel.for_inference(model)
88
+ model.to("cuda").eval() # Ensure model is on GPU and in eval mode
89
+ print("Model loaded successfully.")
90
+ except Exception as e:
91
+ print(f"Error loading model: {e}")
92
+ print(traceback.format_exc())
93
+ # Optionally raise or handle the error to prevent app launch if model fails
94
+ raise gr.Error(f"Failed to load model {MODEL_NAME}. Check logs. Error: {e}")
95
+
96
+
97
+ # --- Blurnonymizer Instance ---
98
+ try:
99
+ blurnonymizer_instance = blurnonymize.ImageBlurnonymizer()
100
+ print("Blurnonymizer initialized successfully.")
101
+ except Exception as e:
102
+ print(f"Error initializing Blurnonymizer: {e}")
103
+ print(traceback.format_exc())
104
+ raise gr.Error(f"Failed to initialize Blurnonymizer. Check logs. Error: {e}")
105
+
106
+ # --- Core Processing Function ---
107
+ @spaces.GPU(duration=20)  # run the SAM segmentation on the GPU
108
+ def anonymise_image(input_image_np: np.ndarray, boxes: list[BoundingBox]):
109
+ """Calls the blurnonymizer instance to censor the image."""
110
+ if not blurnonymizer_instance:
111
+ raise gr.Error("Blurnonymizer not initialized.")
112
+ return blurnonymizer_instance.censor_image_blur_easy(
113
+ input_image_np, boxes, method="segmentation", verbose=False # Set verbose as needed
114
+ )
115
+
116
+
117
+ def run_model_inference(input_image_pil: Image.Image, prompt_text: str):
118
+ """
119
+ Runs model inference on the input image and prompt.
120
+ """
121
+
122
+ # 1. Run Model Inference
123
+ print("Running model inference...")
124
+ messages = build_messages(
125
+ input_image_pil,
126
+ prompt=prompt_text)
127
+ input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
128
+
129
+ # Prepare inputs for the model
130
+ inputs = tokenizer(
131
+ input_image_pil,
132
+ input_text,
133
+ return_tensors="pt",
134
+ ).to("cuda")
135
+
136
+ out_tokens = model.generate(
137
+ **inputs,
138
+ max_new_tokens=MAX_NEW_TOKENS,
139
+ use_cache=True,
140
+ temperature=TEMPERATURE,
141
+ min_p=MIN_P,
142
+ )
143
+ generated_ids_trimmed = [
144
+ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, out_tokens)
145
+ ]
146
+ raw_model_output = tokenizer.batch_decode(
147
+ generated_ids_trimmed,
148
+ skip_special_tokens=True,
149
+ clean_up_tokenization_spaces=True,
150
+ )[0]
151
+
152
+ input_height = inputs['image_grid_thw'][0][1]*14
153
+ input_width = inputs['image_grid_thw'][0][2]*14
154
+
155
+ if input_height != input_image_pil.height:
156
+ print("[!] tokenized image height differs from actual height:")
157
+ print(f"Actual: {input_image_pil.height}, processed: {input_height}")
158
+
159
+ if input_width != input_image_pil.width:
160
+ print("[!] tokenized image width differs from actual width:")
161
+ print(f"Actual: {input_image_pil.width}, processed: {input_width}")
162
+
163
+ print("[+] Model inference completed.")
164
+ print("[*] Raw output:")
165
+ print(raw_model_output)
166
+
167
+ return raw_model_output, input_height, input_width
168
+
169
+
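The prompt asks the model for a <think> block followed by an <output> JSON list, which is what utils.parse_json_response expects downstream. A sketch of that contract, using a hypothetical response string:

sample_output = (
    "<think>One identifiable face is visible.</think>"
    '<output>[{"label": "face", "description": "person in foreground", '
    '"explanation": "identifiable face", "bounding_box": [120, 40, 180, 110], "severity": 6}]</output>'
)
findings = utils.parse_into_models(utils.parse_json_response(sample_output))
print(findings[0].bounding_box)  # (120, 40, 180, 110)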
170
+ @spaces.GPU(duration=90)  # Request GPU for this function, allow up to 90 seconds
171
+ def analyze_image(input_image_pil: Image.Image, prompt_text: str):
172
+ """
173
+ Analyzes the input image using the VLM, visualizes findings, and anonymizes.
174
+ """
175
+ if input_image_pil is None:
176
+ raise gr.Error("Please upload an image.")
177
+ if not prompt_text:
178
+ raise gr.Error("Please provide a prompt.")
179
+
180
+ original_image_np = np.array(input_image_pil)
181
+
182
+ # 1. Run Model Inference
183
+ try:
184
+ raw_model_output, image_height, image_width = run_model_inference(input_image_pil, prompt_text)
185
+ except Exception as e:
186
+ print(f"Error during model inference: {e}")
187
+ print(traceback.format_exc())
188
+ raise gr.Error(f"Model inference failed: {e}")
189
+
190
+ # 2. Parse Findings
191
+ try:
192
+ print("Parsing findings...")
193
+ # Use the provided utility functions
194
+ parsed_findings = utils.parse_into_models(
195
+ utils.parse_json_response(raw_model_output)
196
+ )
197
+ print(f"[+] Parsed {len(parsed_findings)} findings.")
198
+ if not parsed_findings:
199
+ print("[*] No findings were parsed from the model output.")
200
+
201
+ except Exception as e:
202
+ print(f"Error parsing model output: {e}")
203
+ print(traceback.format_exc())
204
+ # Don't raise error here, allow visualization/anonymization steps to proceed if possible
205
+ # or return early with only original image if parsing is critical
206
+ gr.Warning(
207
+ f"Could not parse findings from model output: {e}. Visualization and anonymization might be incomplete."
208
+ )
209
+ # Fallback: visualize/anonymize based on empty findings list if needed
210
+ parsed_findings = [] # Ensure it's an empty list for downstream steps
211
+
212
+ # Initialize boxes_for_viz before the try block
213
+ boxes_for_viz = []
214
+ try:
215
+ # 3. Visualize Findings
216
+ print("Visualizing findings...")
217
+ if parsed_findings:
218
+ # Convert Findings to BoundingBox for visualization function
219
+ boxes_for_viz = [BoundingBox.from_finding(f) for f in parsed_findings]
220
+ # Ensure image is in the correct format (np array) for visualize_boxes_annotated
221
+ visualized_image_np = utils.visualize_boxes_annotated(
222
+ original_image_np, boxes_for_viz
223
+ )
224
+ print("Visualization generated.")
225
+ else:
226
+ print("No findings to visualize, using original image.")
227
+ visualized_image_np = (
228
+ original_image_np.copy()
229
+ ) # Show original if no findings
230
+
231
+ except Exception as e:
232
+ print(f"Error during visualization: {e}")
233
+ print(traceback.format_exc())
234
+ gr.Warning(f"Failed to visualize findings: {e}")
235
+ visualized_image_np = original_image_np.copy() # Fallback to original
236
+
237
+ try:
238
+ # 4. Anonymize Image
239
+ print("Anonymizing image...")
240
+ # Blur the regions for the detected boxes (reuses the boxes prepared in the visualization step)
241
+ # Ensure image is numpy array
242
+ # Check if boxes_for_viz is populated before calling anonymise_image
243
+ if boxes_for_viz:
244
+ anonymized_image_np = anonymise_image(original_image_np, boxes_for_viz)
245
+ print("Anonymization generated.")
246
+ else:
247
+ print("No boxes found for anonymization, using original image.")
248
+ anonymized_image_np = original_image_np.copy()
249
+
250
+ except Exception as e:
251
+ print(f"Error during anonymization: {e}")
252
+ print(traceback.format_exc())
253
+ gr.Warning(f"Failed to anonymize image: {e}")
254
+ anonymized_image_np = original_image_np.copy() # Fallback to original
255
+
256
+ # Convert numpy arrays back to PIL Images for Gradio output if needed, or let Gradio handle numpy
257
+ # Gradio's gr.Image output can handle numpy arrays directly
258
+
259
+ # Return the three images
260
+ return raw_model_output, visualized_image_np, anonymized_image_np
261
+
262
+
263
+ # --- Gradio Interface ---
264
+ with gr.Blocks() as demo:
265
+ gr.Markdown("# Private Data Detection & Anonymization UI")
266
+ gr.Markdown(f"Using model: `{MODEL_NAME}` on ZeroGPU.")
267
+
268
+ with gr.Row():
269
+ with gr.Column(scale=1):
270
+ input_image = gr.Image(type="pil", label="Upload Image")
271
+ prompt_textbox = gr.Textbox(
272
+ label="Analysis Prompt", value=DEFAULT_PROMPT, lines=4
273
+ )
274
+ analyze_button = gr.Button("Analyze Image")
275
+ with gr.Column(scale=2):
276
+ with gr.Column():
277
+ raw_output = gr.Textbox(
278
+ label="Raw Model Output", interactive=False
279
+ )
280
+ output_visualized = gr.Image(
281
+ label="Detected Privacy Findings", type="numpy", interactive=False
282
+ )
283
+ output_anonymized = gr.Image(
284
+ label="Anonymized", type="numpy", interactive=False
285
+ )
286
+
287
+ analyze_button.click(
288
+ fn=analyze_image,
289
+ inputs=[input_image, prompt_textbox],
290
+ outputs=[raw_output, output_visualized, output_anonymized],
291
+ )
292
+
293
+ # --- Launch App ---
294
+ if __name__ == "__main__":
295
+ demo.queue().launch(
296
+ debug=True
297
+ ) # Enable queue for handling multiple requests, debug mode for logs
blurnonymize.py ADDED
@@ -0,0 +1,300 @@
1
+ import json
2
+ import traceback
3
+ from typing import Literal, Optional
4
+
5
+ import cv2
6
+ import matplotlib.patches as patches
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import torch
10
+ from pydantic import BaseModel
11
+ from sam2.build_sam import build_sam2
12
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
13
+ from utils import BoundingBox, parse_into_models, parse_json_response
14
+
15
+
16
+ # --- Utility Functions (kept outside the class) ---
17
+
18
+ def blur_image(img: np.ndarray):
19
+ """Applies Gaussian blur to an image."""
20
+ return cv2.GaussianBlur(img, (35, 35), 50)
21
+
22
+
23
+ def plot_polygon_mask(image: np.ndarray, polygons: list[list[tuple[int, int]]]):
24
+ """
25
+ Plots polygon-based segmentation masks on top of an image.
26
+ """
27
+ plt.imshow(image)
28
+ for polygon in polygons:
29
+ if not polygon: continue # Skip empty polygons
30
+ polygon_array = np.array(polygon).reshape(-1, 2)
31
+ x, y = zip(*polygon_array)
32
+ x = list(x) + [x[0]]
33
+ y = list(y) + [y[0]]
34
+ plt.plot(x, y, '-r', linewidth=2)
35
+ plt.axis('off')
36
+ plt.tight_layout()
37
+ plt.show()
38
+
39
+
40
+ def visualize_boxes(image, findings):
41
+ """Visualizes bounding boxes on an image."""
42
+ fig, ax = plt.subplots(1)
43
+ ax.imshow(image)
44
+ colors = ['r', 'g', 'b', 'c', 'm', 'y', 'k']
45
+ for i, finding in enumerate(findings):
46
+ [x_min, y_min, x_max, y_max] = finding.bounding_box
47
+ color = colors[i % len(colors)]
48
+ rect = patches.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min, linewidth=2, edgecolor=color,
49
+ facecolor='none')
50
+ ax.add_patch(rect)
51
+ print(f"Finding {i + 1} (Color: {color}):")
52
+ if not findings:
53
+ print("No findings")
54
+ plt.xticks(np.arange(0, image.shape[1], 50))
55
+ plt.yticks(np.arange(0, image.shape[0], 50))
56
+ plt.show()
57
+
58
+ # --- SAM Visualization Helpers (kept outside the class) ---
59
+
60
+ def show_mask(mask, ax, random_color=False, borders=True):
61
+ """Displays a single mask on a matplotlib axis."""
62
+ if random_color:
63
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
64
+ else:
65
+ color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
66
+ h, w = mask.shape[-2:]
67
+ mask = mask.astype(np.uint8)
68
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
69
+ if borders:
70
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
71
+ # contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours] # Optional smoothing
72
+ mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
73
+ ax.imshow(mask_image)
74
+
75
+ def show_points(coords, labels, ax, marker_size=375):
76
+ """Displays points (positive/negative) on a matplotlib axis."""
77
+ pos_points = coords[labels == 1]
78
+ neg_points = coords[labels == 0]
79
+ ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
80
+ ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
81
+
82
+ def show_box(box, ax):
83
+ """Displays a bounding box on a matplotlib axis."""
84
+ x0, y0 = box[0], box[1]
85
+ w, h = box[2] - box[0], box[3] - box[1]
86
+ ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
87
+
88
+ def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):
89
+ """Displays multiple masks resulting from SAM prediction."""
90
+ for i, (mask, score) in enumerate(zip(masks, scores)):
91
+ plt.figure(figsize=(10, 10))
92
+ plt.imshow(image)
93
+ show_mask(mask, plt.gca(), borders=borders)
94
+ if point_coords is not None:
95
+ assert input_labels is not None
96
+ show_points(point_coords, input_labels, plt.gca())
97
+ if box_coords is not None:
98
+ show_box(box_coords, plt.gca())
99
+ if len(scores) > 1:
100
+ plt.title(f"Mask {i + 1}, Score: {score:.3f}", fontsize=18)
101
+ plt.axis('off')
102
+ plt.show()
103
+
104
+
105
+ # --- ImageBlurnonymizer Class ---
106
+
107
+ class ImageBlurnonymizer:
108
+ def __init__(self, checkpoint="./sam2.1_hiera_large.pt", model_cfg="configs/sam2.1/sam2.1_hiera_l.yaml"):
109
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
110
+ self.predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint, device=self.device))
111
+
112
+ @staticmethod
113
+ def _smoothen_mask(mask: np.ndarray):
114
+ """Applies morphological closing to smoothen mask boundaries."""
115
+ kernel = np.ones((20, 20), np.uint8)
116
+ return cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
117
+
118
+ @staticmethod
119
+ def _mask_from_bbox(image_shape, bbox: tuple[int, int, int, int]):
120
+ """Creates a simple rectangular mask from a bounding box."""
121
+ height, width, *_ = image_shape # Allow for 2D or 3D shape tuple
122
+ xmin, ymin, xmax, ymax = bbox
123
+ mask = np.zeros((height, width), dtype=np.uint8)
124
+ mask[ymin:ymax, xmin:xmax] = 1
125
+ return mask
126
+
127
+ @staticmethod
128
+ def _apply_blur_mask(image: np.ndarray, mask: np.ndarray):
129
+ """Applies a blur to an image based on a mask."""
130
+ if mask.ndim == 2: # Ensure mask is 3-channel for broadcasting
131
+ mask = np.stack((mask,) * image.shape[2], axis=-1)
132
+ blurred = blur_image(image) # Use the utility function
133
+ return np.where(mask, blurred, image)
134
+
135
+ @staticmethod
136
+ def _binary_mask_to_polygon(binary_mask: np.ndarray, epsilon=2.0):
137
+ """Converts a binary segmentation mask to polygon contours."""
138
+ try:
139
+ converted = (binary_mask * 255).astype(np.uint8)
140
+ # Use RETR_EXTERNAL for outer contours only, CHAIN_APPROX_SIMPLE for efficiency
141
+ contours, _ = cv2.findContours(converted, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
142
+ polygons = []
143
+ for contour in contours:
144
+ approx_contour = cv2.approxPolyDP(contour, epsilon, True)
145
+ # Ensure points are converted correctly
146
+ polygon = [(int(point[0][0]), int(point[0][1])) for point in approx_contour]
147
+ polygons.append(polygon)
148
+ return polygons
149
+ except Exception as e:
150
+ print(f"An error occurred during polygon conversion: {e}")
151
+ print(traceback.format_exc())
152
+ return None # Return None on error
153
+
154
+
155
+ def get_segmentation_mask(self, image: np.ndarray, bbox: tuple[int, int, int, int]):
156
+ """
157
+ Generates a segmentation mask for a region defined by a bounding box using SAM.
158
+
159
+ Adds points within the bounding box to guide SAM towards the intended object (e.g., face)
160
+ and away from surrounding elements (e.g., hair).
161
+ """
162
+ x_min, y_min, x_max, y_max = bbox
163
+ x_width = x_max - x_min
164
+ y_height = y_max - y_min
165
+
166
+ # Handle cases where box dimensions are too small for third calculations
167
+ x_third = x_width // 3 if x_width >= 3 else 0
168
+ y_third = y_height // 3 if y_height >= 3 else 0
169
+
170
+ center_point = [(x_min + x_max) // 2, (y_min + y_max) // 2]
171
+
172
+ # Define points ensuring they stay within the image boundaries
173
+ points = [center_point]
174
+ if y_third > 0:
175
+ points.append([center_point[0], center_point[1] - y_third])
176
+ points.append([center_point[0], center_point[1] + y_third])
177
+ if x_third > 0:
178
+ points.append([center_point[0] + x_third, center_point[1]])
179
+ points.append([center_point[0] - x_third, center_point[1]])
180
+
181
+ # Ensure points are valid coordinates (e.g., non-negative)
182
+ points = [[max(0, p[0]), max(0, p[1])] for p in points]
183
+
184
+
185
+ self.predictor.set_image(image)
186
+ masks, scores, _ = self.predictor.predict(
187
+ box=np.array(bbox), # Predictor might expect numpy array
188
+ point_coords=np.array(points),
189
+ point_labels=np.ones(len(points)), # Label 1 for inclusion
190
+ multimask_output=True,
191
+ )
192
+
193
+ # Sort masks by score and select the best one
194
+ sorted_ind = np.argsort(scores)[::-1]
195
+ best_mask = masks[sorted_ind[0]]
196
+ best_score = scores[sorted_ind[0]]
197
+
198
+ return self._smoothen_mask(best_mask), best_score
199
+
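For intuition, the point prompts computed for a hypothetical box (40, 60, 100, 120) are the box centre plus four offsets one third of the box size away, which nudges SAM towards the object inside the box:

x_min, y_min, x_max, y_max = 40, 60, 100, 120   # hypothetical bounding box
cx, cy = (x_min + x_max) // 2, (y_min + y_max) // 2
x_third, y_third = (x_max - x_min) // 3, (y_max - y_min) // 3
points = [[cx, cy],
          [cx, cy - y_third], [cx, cy + y_third],
          [cx + x_third, cy], [cx - x_third, cy]]
print(points)  # [[70, 90], [70, 70], [70, 110], [90, 90], [50, 90]]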
200
+ def censor_image_blur(self, image: np.ndarray, raw_out: str,
201
+ method: Optional[Literal['segmentation', 'bbox']] = 'segmentation', verbose=False):
202
+ """
203
+ Censors an image by blurring regions identified in the raw_out string (LLM output).
204
+ """
205
+ json_output = parse_json_response(raw_out)
206
+ # Ensure json_output is a list before passing to parse_into_models
207
+ if isinstance(json_output, dict):
208
+ findings_list = [json_output]
209
+ elif isinstance(json_output, list):
210
+ findings_list = json_output
211
+ else:
212
+ # Handle unexpected type or raise an error
213
+ print(f"Warning: Unexpected output type from parse_json_response: {type(json_output)}")
214
+ findings_list = []
215
+
216
+ parsed = parse_into_models(findings_list)
217
+ # Filter findings based on severity
218
+ filtered = [entry for entry in parsed if entry.severity > 0]
219
+
220
+ if verbose:
221
+ visualize_boxes(image, filtered) # Use external visualization
222
+
223
+ masks = []
224
+ for finding in filtered:
225
+ bbox = finding.bounding_box # Assuming finding has a 'bounding_box' attribute
226
+ if method == 'segmentation':
227
+ mask, _ = self.get_segmentation_mask(image, bbox) # Use instance method
228
+ if verbose:
229
+ polygons = self._binary_mask_to_polygon(mask)
230
+ if polygons: # Check if polygon conversion was successful
231
+ plot_polygon_mask(image, polygons) # Use external visualization
232
+ elif method == 'bbox':
233
+ mask = self._mask_from_bbox(image.shape, bbox) # Use static method
234
+ else:
235
+ print(f"Warning: Unknown method '{method}'. Defaulting to no mask for this finding.")
236
+ continue # Skip if method is invalid
237
+
238
+ masks.append(mask)
239
+
240
+
241
+ if masks: # Check if any masks were generated
242
+ # Combine masks: logical OR ensures any pixel in any mask is included
243
+ combined_mask = np.zeros_like(masks[0], dtype=np.uint8)
244
+ for mask in masks:
245
+ # Ensure masks are boolean or uint8 for logical_or
246
+ combined_mask = np.logical_or(combined_mask, mask.astype(bool)).astype(np.uint8)
247
+
248
+ return self._apply_blur_mask(image, combined_mask) # Use static method
249
+ return image # Return original image if no masks
250
+
251
+ def censor_image_blur_easy(self, image: np.ndarray, boxes: list[BoundingBox],
252
+ method: Optional[Literal['segmentation', 'bbox']] = 'segmentation', verbose=False):
253
+ """
254
+ Censors an image by blurring regions defined by a list of BoundingBox objects.
255
+ """
256
+ masks = []
257
+ for box in boxes:
258
+ bbox_tuple = box.to_tuple() # Convert BoundingBox object to tuple
259
+ if method == 'segmentation':
260
+ mask, _ = self.get_segmentation_mask(image, bbox_tuple)
261
+ if verbose:
262
+ polygons = self._binary_mask_to_polygon(mask)
263
+ if polygons:
264
+ plot_polygon_mask(image, polygons)
265
+ elif method == 'bbox':
266
+ mask = self._mask_from_bbox(image.shape, bbox_tuple)
267
+ else:
268
+ print(f"Warning: Unknown method '{method}'. Defaulting to no mask for this box.")
269
+ continue
270
+
271
+ masks.append(mask)
272
+
273
+ if masks:
274
+ combined_mask = np.zeros_like(masks[0], dtype=np.uint8)
275
+ for mask in masks:
276
+ combined_mask = np.logical_or(combined_mask, mask.astype(bool)).astype(np.uint8)
277
+
278
+ return self._apply_blur_mask(image, combined_mask)
279
+ return image
280
+
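To exercise the mask-then-blur path without downloading SAM weights, a small self-contained sketch using only the bbox-based helpers on a synthetic image (all values are arbitrary):

import numpy as np

rng = np.random.default_rng(0)
test_img = rng.integers(0, 256, size=(200, 200, 3), dtype=np.uint8)  # synthetic noise image
mask = ImageBlurnonymizer._mask_from_bbox(test_img.shape, (50, 50, 150, 150))
censored = ImageBlurnonymizer._apply_blur_mask(test_img, mask)       # blurred only inside the box
assert censored.shape == test_img.shape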
281
+ # Example Usage (Optional - keep outside class):
282
+ # if __name__ == '__main__':
283
+ # # Load an image
284
+ # # img = cv2.imread('path/to/your/image.jpg')
285
+ # # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # Convert to RGB for matplotlib
286
+
287
+ # # Create an instance of the blurnonymizer
288
+ # # blurnonymizer = ImageBlurnonymizer()
289
+
290
+ # # Define bounding boxes or get raw LLM output
291
+ # # example_boxes = [BoundingBox(xmin=100, ymin=100, xmax=200, ymax=200)] # Assuming BoundingBox class exists
292
+ # # llm_output = '...' # Your raw LLM output string
293
+
294
+ # # Censor the image
295
+ # # censored_img_easy = blurnonymizer.censor_image_blur_easy(img, example_boxes, method='segmentation', verbose=True)
296
+ # # censored_img_llm = blurnonymizer.censor_image_blur(img, llm_output, method='segmentation', verbose=True)
297
+
298
+ # # Display or save the result
299
+ # # plt.imshow(censored_img_easy)
300
+ # # plt.show()
requirements.txt ADDED
@@ -0,0 +1,12 @@
1
+ gradio
2
+ unsloth
3
+ transformers
4
+ torch
5
+ pydantic
6
+ numpy
7
+ pandas
8
+ Pillow
9
+ opencv-python
10
+ spaces
11
+ matplotlib
12
+ sam2
utils.py ADDED
@@ -0,0 +1,351 @@
1
+ from pydantic import BaseModel, field_validator
2
+ import numpy as np
3
+ import json
4
+ import matplotlib.patches as patches
5
+ import matplotlib.pyplot as plt
6
+ from PIL import Image
7
+ import base64
8
+ from io import BytesIO
9
+ import io
10
+
11
+ def encode_image(image: np.ndarray) -> str:
12
+ """Encodes a NumPy array image into a base64 JPEG string.
13
+
14
+ Args:
15
+ image: A NumPy array representing the image.
16
+
17
+ Returns:
18
+ A base64 encoded string prefixed with 'data:image/jpeg;base64,'.
19
+ """
20
+ pil_image = Image.fromarray(image)
21
+ buffer = BytesIO()
22
+ pil_image.save(buffer, format='jpeg')
23
+ return f"data:image/jpeg;base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}"
24
+
25
+ def decode_image(base64_str: str) -> np.ndarray:
26
+ """Decodes a base64 encoded image string into a NumPy array.
27
+
28
+ Assumes the base64 string represents a valid image format (e.g., JPEG, PNG).
29
+
30
+ Args:
31
+ base64_str: The base64 encoded image string (may include prefix).
32
+
33
+ Returns:
34
+ A NumPy array representing the decoded image.
35
+ """
36
+ # Remove the prefix if it exists
37
+ if ',' in base64_str:
38
+ base64_str = base64_str.split(',', 1)[1]
39
+
40
+ # Decode the base64 string
41
+ image_data = base64.b64decode(base64_str)
42
+
43
+ # Convert the image data to a PIL Image
44
+ image = Image.open(io.BytesIO(image_data))
45
+
46
+ # Convert the PIL Image to a NumPy array
47
+ numpy_image = np.array(image)
48
+
49
+ return numpy_image
50
+
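A quick round-trip sketch for the two helpers above, assuming they are in scope (JPEG encoding is lossy, so decoded pixel values may differ slightly from the input):

import numpy as np

arr = np.zeros((8, 8, 3), dtype=np.uint8)   # tiny black test image
data_url = encode_image(arr)                # "data:image/jpeg;base64,..."
restored = decode_image(data_url)
print(restored.shape)                       # (8, 8, 3)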
51
+ class Finding(BaseModel):
52
+ """Represents a detected finding in an image, including its label,
53
+ description, explanation, bounding box coordinates, and severity level.
54
+ """
55
+ label: str
56
+ description: str
57
+ explanation: str
58
+ bounding_box: tuple[int, int, int, int]
59
+ severity: int
60
+
61
+ @field_validator("bounding_box")
62
+ @classmethod
63
+ def validate_bounding_box(cls, value: tuple[int, int, int, int]):
64
+ """Validates that the bounding box coordinates are logically consistent."""
65
+ if len(value) != 4:
66
+ raise ValueError("Bounding box must be a tuple of 4 integers")
67
+ if value[0] >= value[2]:
68
+ raise ValueError("Bounding box x_min (index 0) must be less than x_max (index 2)")
69
+ if value[1] >= value[3]:
70
+ raise ValueError("Bounding box y_min (index 1) must be less than y_max (index 3)")
71
+ return value
72
+
73
+ class BoundingBox(BaseModel):
74
+ """Represents a bounding box with a label and explicit min/max coordinates. Assumess that the top left corner is the origin"""
75
+ label: str
76
+ x_min: int
77
+ y_min: int
78
+ x_max: int
79
+ y_max: int
80
+
81
+ @staticmethod
82
+ def from_finding(finding: Finding) -> 'BoundingBox':
83
+ """Creates a BoundingBox instance from a Finding instance."""
84
+ return BoundingBox(label=finding.label, x_min=finding.bounding_box[0], y_min=finding.bounding_box[1], x_max=finding.bounding_box[2], y_max=finding.bounding_box[3])
85
+
86
+ @staticmethod
87
+ def from_array(label: str, box: list[int]) -> 'BoundingBox':
88
+ """Creates a BoundingBox instance from a label and a list of coordinates."""
89
+ return BoundingBox(label=label, x_min=box[0], y_min=box[1], x_max=box[2], y_max=box[3])
+
+ def to_tuple(self) -> tuple[int, int, int, int]:
+ """Returns the coordinates as an (x_min, y_min, x_max, y_max) tuple, as expected by the blur helpers in blurnonymize.py."""
+ return (self.x_min, self.y_min, self.x_max, self.y_max)
90
+
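As a usage sketch (values are made up), a Finding can be converted into a BoundingBox and then into the plain tuple the blurring code consumes, via the to_tuple helper above:

f = Finding(label="license plate", description="car registration plate",
            explanation="readable registration number",
            bounding_box=(10, 20, 110, 60), severity=5)
box = BoundingBox.from_finding(f)
print(box.label, box.to_tuple())  # license plate (10, 20, 110, 60)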
91
+ def parse_json_response(out: str) -> list[dict]:
92
+ """Extracts and parses JSON content from a string.
93
+
94
+ Handles responses potentially wrapped in <output> tags or markdown code blocks (```json).
95
+
96
+ Args:
97
+ out: The input string potentially containing JSON.
98
+
99
+ Returns:
100
+ The parsed JSON object (list or dictionary).
101
+
102
+ Raises:
103
+ ValueError: If no valid JSON content is found.
104
+ """
105
+ start_prefix = "<output>"
106
+ end_postfix = "</output>"
107
+ start_index = out.find(start_prefix)
108
+ end_index = out.rfind(end_postfix)
109
+
110
+ if start_index == -1:
111
+ # try to load by finding ```json ``` markers
112
+ start_index = out.rfind("```json")
113
+ end_index = out.rfind("```")
114
+ if start_index == -1 or end_index == -1:
115
+ raise ValueError("No JSON found in response")
116
+ start_index += len("```json")
117
+ fixed = out[start_index:end_index]
118
+ print(f"fixed: {fixed}")
119
+ return json.loads(fixed)
120
+
121
+ start_index += len(start_prefix)
122
+ fixed = out[start_index:end_index]
123
+ fixed = fixed.strip()
124
+ if fixed.startswith("```json"):
125
+ start_index = fixed.find("[")
126
+ end_index = fixed.rfind("]")
127
+
128
+ fixed = fixed[start_index:end_index + 1]
129
+ return json.loads(fixed)
130
+
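When no <output> tags are present, the function falls back to a markdown-fenced block; a hypothetical example of that path:

fenced = 'Here you go:\n```json\n[{"label": "face", "description": "d", "explanation": "e", "bounding_box": [1, 2, 3, 4], "severity": 6}]\n```'
print(parse_json_response(fenced))  # parsed list with one finding dict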
131
+
132
+ def parse_into_models(findings: list[dict]) -> list[Finding]:
133
+ """Parses and validates a list of dictionaries into a list of Finding models.
134
+
135
+ Args:
136
+ findings: A list of dictionaries, each representing a finding.
137
+
138
+ Returns:
139
+ A list of validated Finding model instances.
140
+ """
141
+ parsed = []
142
+ for box in findings:
143
+ model_finding = Finding.model_validate(box)
144
+ parsed.append(model_finding)
145
+ return parsed
146
+
147
+
148
+ def parse_all_safe(out: str) -> list[Finding] | None:
149
+ """Safely parses a string potentially containing JSON findings into Finding models.
150
+
151
+ Combines `parse_json_response` and `parse_into_models`, returning None on any parsing error.
152
+
153
+ Args:
154
+ out: The input string.
155
+
156
+ Returns:
157
+ A list of Finding models if parsing is successful, otherwise None.
158
+ """
159
+ try:
160
+ return parse_into_models(parse_json_response(out))
161
+ except Exception:
162
+ return None
163
+
164
+
165
+ def clamp(num: int | float, min_num: int | float = 0, max_num: int | float = 255) -> int | float:
166
+ """Clamps a number within a specified range [min_num, max_num]."""
167
+ return max(min_num, min(num, max_num))
168
+
169
+ def enlarge_boxes(image_shape: tuple[int, int], findings: list[Finding], factor: float = 1.1) -> list[Finding]:
170
+ """Enlarges the bounding boxes of findings by a given factor, clamping to image boundaries.
171
+
172
+ Args:
173
+ image_shape: A tuple (height, width) representing the image dimensions.
174
+ findings: A list of Finding objects.
175
+ factor: The factor by which to enlarge the boxes (e.g., 1.1 for 10% larger).
176
+
177
+ Returns:
178
+ A new list of Finding objects with adjusted bounding boxes.
179
+ """
180
+ adjusted = []
181
+ img_height, img_width = image_shape
182
+ for box in findings:
183
+ x_min_orig, y_min_orig, x_max_orig, y_max_orig = box.bounding_box
184
+ x_width = x_max_orig - x_min_orig
185
+ y_width = y_max_orig - y_min_orig
186
+
187
+ # Calculate the amount to adjust on each side
188
+ x_adjust = (x_width * (factor - 1)) / 2
189
+ y_adjust = (y_width * (factor - 1)) / 2
190
+
191
+ # Calculate new coordinates and clamp them
192
+ x_min = clamp(x_min_orig - x_adjust, 0, img_width)
193
+ y_min = clamp(y_min_orig - y_adjust, 0, img_height)
194
+ x_max = clamp(x_max_orig + x_adjust, 0, img_width)
195
+ y_max = clamp(y_max_orig + y_adjust, 0, img_height)
196
+
197
+ # Ensure coordinates remain valid integers if they were originally
198
+ adjusted_bbox = (int(round(x_min)), int(round(y_min)), int(round(x_max)), int(round(y_max)))
199
+
200
+ # Validate adjusted box before creating new Finding
201
+ try:
202
+ Finding.validate_bounding_box(adjusted_bbox)
203
+ adjusted.append(box.model_copy(update={'bounding_box': adjusted_bbox}))
204
+ except ValueError:
205
+ # If enlarging makes the box invalid (e.g., min >= max), keep the original
206
+ adjusted.append(box) # Or handle the error differently if needed
207
+
208
+ return adjusted
209
+
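A small worked example of the enlargement math (numbers chosen arbitrarily): growing a 100x100 box by factor 1.1 adds 5 px on every side, clamped to the 400x400 image:

f = Finding(label="face", description="person", explanation="identifiable face",
            bounding_box=(100, 100, 200, 200), severity=6)
grown = enlarge_boxes((400, 400), [f], factor=1.1)
print(grown[0].bounding_box)  # (95, 95, 205, 205)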
210
+ def change_box_format(shape: tuple[int, int, int], box: tuple[int, int, int, int]) -> tuple[float, float, float, float]:
211
+ """Normalizes bounding box coordinates from a 1000x1000 grid to the image dimensions.
212
+ This is only for Gemini-based models, as they return coordinates normalized between 0 and 1000;
213
+ Qwen-based models don't need this.
214
+ Assumes the input box coordinates are relative to a 1000x1000 grid.
215
+
216
+ Args:
217
+ shape: The shape of the target image (height, width, channels).
218
+ box: The bounding box tuple (x_min, y_min, x_max, y_max) in 1000x1000 coordinates.
219
+
220
+ Returns:
221
+ A tuple of normalized bounding box coordinates (x_min, y_min, x_max, y_max)
222
+ relative to the image dimensions.
223
+ """
224
+ y_height, x_width, _ = shape
225
+ # Normalize coordinates from 1000x1000 grid to image dimensions
226
+ x_min = (box[0] / 1000.0) * x_width
227
+ y_min = (box[1] / 1000.0) * y_height
228
+ x_max = (box[2] / 1000.0) * x_width
229
+ y_max = (box[3] / 1000.0) * y_height
230
+
231
+ return (x_min, y_min, x_max, y_max)
232
+
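For example, on a 640x480 image (shape (480, 640, 3)), a Gemini-style box (250, 250, 500, 500) on the 1000x1000 grid maps to pixel coordinates like so:

print(change_box_format((480, 640, 3), (250, 250, 500, 500)))
# (160.0, 120.0, 320.0, 240.0)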
233
+ def normalize_findings_boxes(shape: tuple[int, int, int], findings: list[Finding]) -> list[Finding]:
234
+ """Normalizes the bounding boxes of all findings in a list.
235
+ This is only for Gemini-based models, as they return coordinates normalized between 0 and 1000;
236
+ Qwen-based models don't need this.
237
+
238
+ Modifies the findings list in-place.
239
+
240
+ Args:
241
+ shape: The shape of the target image (height, width, channels).
242
+ findings: A list of Finding objects whose bounding boxes need normalization.
243
+
244
+ Returns:
245
+ The list of Finding objects with normalized bounding boxes (modified in-place).
246
+ """
247
+ for finding in findings:
248
+ # Ensure the bounding box is a tuple before passing
249
+ current_box = tuple(finding.bounding_box)
250
+ finding.bounding_box = change_box_format(shape, current_box)
251
+ return findings
252
+
270
+ def visualize_boxes(image, findings):
271
+ # Create a figure and axis
272
+ fig, ax = plt.subplots(1)
273
+ ax.imshow(image)
274
+
275
+ # Define a list of colors for the boxes
276
+ colors = ['r', 'g', 'b', 'c', 'm', 'y', 'k']
277
+
278
+ for i, finding in enumerate(findings):
279
+ [x_min, y_min, x_max, y_max] = finding.bounding_box
280
+
281
+ # Select a color for the current box
282
+ color = colors[i % len(colors)]
283
+
284
+ rect = patches.Rectangle((x_min, y_min),
285
+ x_max - x_min,
286
+ y_max - y_min,
287
+ linewidth=2, edgecolor=color, facecolor='none')
288
+
289
+ ax.add_patch(rect)
290
+
291
+ # Print the whole finding and the color of its box
292
+ print(f"Finding {i+1} (Color: {color}):")
293
+ if (len(findings) == 0):
294
+ print("No findings")
295
+ # Set x-axis ticks every 2 units
296
+ #plt.xticks(np.arange(0, image.shape[1], 50)) # Start, Stop, Step
297
+ #plt.yticks(np.arange(0, image.shape[0], 50)) # Start, Stop, Step
298
+
299
+ plt.show()
300
+
301
+ def visualize_boxes_annotated(image: np.ndarray | Image.Image, boxes: list[BoundingBox]) -> np.ndarray:
302
+ """Draws bounding boxes with labels on an image and returns the annotated image as a NumPy array.
303
+
304
+ Args:
305
+ image: The input image (NumPy array or PIL Image).
306
+ boxes: A list of BoundingBox objects with coordinates relative to the image.
307
+
308
+ Returns:
309
+ A NumPy array representing the image with annotated bounding boxes.
310
+ """
311
+ if not isinstance(image, np.ndarray):
312
+ image = np.array(image)
313
+ # Create a figure and axis with high DPI
314
+ fig = plt.figure(dpi=300)
315
+ ax = plt.subplot(111)
316
+ ax.imshow(image)
317
+ ax.set_axis_off()
318
+ plt.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0)
319
+
320
+ # Define a list of colors for the boxes
321
+ colors = ['r', 'g', 'b', 'c', 'm', 'y', 'k']
322
+
323
+ for i, box in enumerate(boxes):
324
+ x_min = box.x_min
325
+ y_min = box.y_min
326
+ x_max = box.x_max
327
+ y_max = box.y_max
328
+ label = box.label
329
+
330
+ # Select a color for the current box
331
+ color = colors[i % len(colors)]
332
+
333
+ rect = patches.Rectangle((x_min, y_min),
334
+ x_max - x_min,
335
+ y_max - y_min,
336
+ linewidth=1, edgecolor=color, facecolor='none')
337
+
338
+ ax.add_patch(rect)
339
+
340
+ # Add label text above the box
341
+ ax.text(x_min, y_min-5, label, color=color, fontsize=10,
342
+ bbox=dict(facecolor='white', alpha=0.7, edgecolor='none'))
343
+
344
+ # Instead of displaying, save to numpy array
345
+ fig.canvas.draw()
346
+ data = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
347
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (4,))
348
+ # Convert RGBA to RGB
349
+ data = data[:, :, :3]
350
+ plt.close()
351
+ return data
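A short usage sketch for the annotated-visualization helper (the figure size, not the input size, determines the output resolution, since the array is read back from the matplotlib canvas):

from PIL import Image

img = Image.new("RGB", (320, 240), "white")                       # hypothetical blank image
boxes = [BoundingBox(label="face", x_min=50, y_min=40, x_max=150, y_max=140)]
annotated = visualize_boxes_annotated(img, boxes)
print(annotated.shape)                                            # (height, width, 3) uint8 array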