Spaces: Running on Zero
hugohabicht01 committed
Commit · dae4d1c
Parent(s): 335bcd6
automatically download sam weights
Browse files:
- blurnonymize.py +169 -63
- utils.py +1 -5
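
At a glance: ImageBlurnonymizer now downloads the SAM 2.1 checkpoint and config on first use instead of assuming they already exist on disk. A minimal usage sketch of the resulting flow (the image path and the BoundingBox keyword fields are hypothetical illustrations, not part of this commit):

import cv2
from blurnonymize import ImageBlurnonymizer
from utils import BoundingBox  # assumed: utils exposes the BoundingBox model

# First construction runs init_sam(), which fetches ./sam2.1_hiera_small.pt
# and ./sam2.1_hiera_s.yaml if they are missing, then builds the predictor.
anonymizer = ImageBlurnonymizer()

image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)
boxes = [BoundingBox(x_min=10, y_min=20, x_max=110, y_max=140)]  # hypothetical field names
censored = anonymizer.censor_image_blur_easy(image, boxes, method="segmentation")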
blurnonymize.py
CHANGED
@@ -1,4 +1,4 @@
-import
+import os
 import traceback
 from typing import Literal, Optional
 
@@ -7,7 +7,6 @@ import matplotlib.patches as patches
 import matplotlib.pyplot as plt
 import numpy as np
 import torch
-from pydantic import BaseModel
 from sam2.build_sam import build_sam2
 from sam2.sam2_image_predictor import SAM2ImagePredictor
 from utils import *
@@ -15,6 +14,7 @@ from utils import *
 
 # --- Utility Functions (kept outside the class) ---
 
+
 def blur_image(img: np.ndarray):
     """Applies Gaussian blur to an image."""
     return cv2.GaussianBlur(img, (35, 35), 50)
@@ -26,13 +26,14 @@ def plot_polygon_mask(image: np.ndarray, polygons: list[list[tuple[int, int]]]):
     """
     plt.imshow(image)
     for polygon in polygons:
-        if not polygon:
+        if not polygon:
+            continue  # Skip empty polygons
         polygon_array = np.array(polygon).reshape(-1, 2)
         x, y = zip(*polygon_array)
         x = list(x) + [x[0]]
         y = list(y) + [y[0]]
-        plt.plot(x, y,
-    plt.axis(
+        plt.plot(x, y, "-r", linewidth=2)
+    plt.axis("off")
     plt.tight_layout()
     plt.show()
 
@@ -41,12 +42,18 @@ def visualize_boxes(image, findings):
     """Visualizes bounding boxes on an image."""
     fig, ax = plt.subplots(1)
     ax.imshow(image)
-    colors = [
+    colors = ["r", "g", "b", "c", "m", "y", "k"]
     for i, finding in enumerate(findings):
         [x_min, y_min, x_max, y_max] = finding.bounding_box
         color = colors[i % len(colors)]
-        rect = patches.Rectangle(
-
+        rect = patches.Rectangle(
+            (x_min, y_min),
+            x_max - x_min,
+            y_max - y_min,
+            linewidth=2,
+            edgecolor=color,
+            facecolor="none",
+        )
         ax.add_patch(rect)
         print(f"Finding {i + 1} (Color: {color}):")
     if not findings:
@@ -55,8 +62,10 @@ def visualize_boxes(image, findings):
     plt.yticks(np.arange(0, image.shape[0], 50))
     plt.show()
 
+
 # --- SAM Visualization Helpers (kept outside the class) ---
 
+
 def show_mask(mask, ax, random_color=False, borders=True):
     """Displays a single mask on a matplotlib axis."""
     if random_color:
@@ -69,23 +78,54 @@ def show_mask(mask, ax, random_color=False, borders=True):
     if borders:
         contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
         # contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]  # Optional smoothing
-        mask_image = cv2.drawContours(
+        mask_image = cv2.drawContours(
+            mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2
+        )
     ax.imshow(mask_image)
 
+
 def show_points(coords, labels, ax, marker_size=375):
     """Displays points (positive/negative) on a matplotlib axis."""
     pos_points = coords[labels == 1]
     neg_points = coords[labels == 0]
-    ax.scatter(
-
+    ax.scatter(
+        pos_points[:, 0],
+        pos_points[:, 1],
+        color="green",
+        marker="*",
+        s=marker_size,
+        edgecolor="white",
+        linewidth=1.25,
+    )
+    ax.scatter(
+        neg_points[:, 0],
+        neg_points[:, 1],
+        color="red",
+        marker="*",
+        s=marker_size,
+        edgecolor="white",
+        linewidth=1.25,
+    )
+
 
 def show_box(box, ax):
     """Displays a bounding box on a matplotlib axis."""
     x0, y0 = box[0], box[1]
     w, h = box[2] - box[0], box[3] - box[1]
-    ax.add_patch(
-
-def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):
+    ax.add_patch(
+        plt.Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2)
+    )
+
+
+def show_masks(
+    image,
+    masks,
+    scores,
+    point_coords=None,
+    box_coords=None,
+    input_labels=None,
+    borders=True,
+):
     """Displays multiple masks resulting from SAM prediction."""
     for i, (mask, score) in enumerate(zip(masks, scores)):
         plt.figure(figsize=(10, 10))
@@ -98,16 +138,48 @@ def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):
             show_box(box_coords, plt.gca())
         if len(scores) > 1:
             plt.title(f"Mask {i + 1}, Score: {score:.3f}", fontsize=18)
-        plt.axis(
+        plt.axis("off")
         plt.show()
 
 
 # --- ImageBlurnonymizer Class ---
 
+
 class ImageBlurnonymizer:
-    def __init__(self
+    def __init__(self):
+        self.predictor = None
+        self.device = None
+        self.model_cfg = None
+
+        self.checkpoint_name = "./sam2.1_hiera_small.pt"
+        self.model_cfg_name = "./sam2.1_hiera_s.yaml"
+        self.init_sam()
+
+    def init_sam(self, force=False):
+        # only initialize SAM if it hasn't been initialized yet
+        if self.predictor is not None and not force:
+            return
+
+        self.download_weights()
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-
+        sam = build_sam2(self.model_cfg_name, self.checkpoint_name, device=self.device)
+        self.predictor = SAM2ImagePredictor(sam)
+
+    def download_weights(self):
+        # Download the SAM 2.1 checkpoint and config if they are not present locally;
+        # the file names are held in self.checkpoint_name and self.model_cfg_name.
+        checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt"
+        cfg_url = "https://raw.githubusercontent.com/facebookresearch/sam2/refs/heads/main/sam2/configs/sam2.1/sam2.1_hiera_s.yaml"
+
+        if not os.path.exists(self.checkpoint_name):
+            print(
+                f"Downloading checkpoint from {checkpoint_url} to {self.checkpoint_name}"
+            )
+            torch.hub.download_url_to_file(checkpoint_url, self.checkpoint_name)
+
+        if not os.path.exists(self.model_cfg_name):
+            print(f"Downloading config from {cfg_url} to {self.model_cfg_name}")
+            torch.hub.download_url_to_file(cfg_url, self.model_cfg_name)
 
     @staticmethod
     def _smoothen_mask(mask: np.ndarray):
@@ -118,18 +190,18 @@ class ImageBlurnonymizer:
     @staticmethod
     def _mask_from_bbox(image_shape, bbox: tuple[int, int, int, int]):
         """Creates a simple rectangular mask from a bounding box."""
-        height, width, *_ = image_shape
+        height, width, *_ = image_shape  # Allow for 2D or 3D shape tuple
         xmin, ymin, xmax, ymax = bbox
         mask = np.zeros((height, width), dtype=np.uint8)
         mask[ymin:ymax, xmin:xmax] = 1
-        return mask
+        return mask  # No need for np.array() conversion
 
     @staticmethod
     def _apply_blur_mask(image: np.ndarray, mask: np.ndarray):
         """Applies a blur to an image based on a mask."""
-        if mask.ndim == 2:
-
-        blurred = blur_image(image)
+        if mask.ndim == 2:  # Ensure mask is 3-channel for broadcasting
+            mask = np.stack((mask,) * image.shape[2], axis=-1)
+        blurred = blur_image(image)  # Use the utility function
         return np.where(mask, blurred, image)
 
     @staticmethod
@@ -138,19 +210,22 @@ class ImageBlurnonymizer:
        try:
            converted = (binary_mask * 255).astype(np.uint8)
            # Use RETR_EXTERNAL for outer contours, CHAIN_APPROX_SIMPLE for efficiency
-            contours, _ = cv2.findContours(
+            contours, _ = cv2.findContours(
+                converted, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
+            )
            polygons = []
            for contour in contours:
                approx_contour = cv2.approxPolyDP(contour, epsilon, True)
                # Ensure points are converted correctly
-                polygon = [
+                polygon = [
+                    (int(point[0][0]), int(point[0][1])) for point in approx_contour
+                ]
                polygons.append(polygon)
            return polygons
        except Exception as e:
            print(f"An error occurred during polygon conversion: {e}")
            print(traceback.format_exc())
-            return None
-
+            return None  # Return None on error
 
    def get_segmentation_mask(self, image: np.ndarray, bbox: tuple[int, int, int, int]):
        """
@@ -159,9 +234,17 @@ class ImageBlurnonymizer:
        Adds points within the bounding box to guide SAM towards the intended object (e.g., face)
        and away from surrounding elements (e.g., hair).
        """
+
+        if self.predictor is None:
+            raise Exception("[-] sam has not been initialized")
+
+        if torch.cuda.is_available() and self.device == "cpu":
+            # the instance was wrongly initialized to run on cpu, but a gpu is available
+            self.init_sam(force=True)
+
        x_min, y_min, x_max, y_max = bbox
        x_width = x_max - x_min
-        y_height = y_max - y_min
+        y_height = y_max - y_min  # Corrected variable name
 
        # Handle cases where box dimensions are too small for third calculations
        x_third = x_width // 3 if x_width >= 3 else 0
@@ -175,18 +258,17 @@ class ImageBlurnonymizer:
        points.append([center_point[0], center_point[1] - y_third])
        points.append([center_point[0], center_point[1] + y_third])
        if x_third > 0:
-
-
+            points.append([center_point[0] + x_third, center_point[1]])
+            points.append([center_point[0] - x_third, center_point[1]])
 
        # Ensure points are valid coordinates (e.g., non-negative)
        points = [[max(0, p[0]), max(0, p[1])] for p in points]
 
-
        self.predictor.set_image(image)
        masks, scores, _ = self.predictor.predict(
-            box=np.array(bbox),
+            box=np.array(bbox),  # Predictor might expect numpy array
            point_coords=np.array(points),
-            point_labels=np.ones(len(points)),
+            point_labels=np.ones(len(points)),  # Label 1 for inclusion
            multimask_output=True,
        )
 
@@ -197,11 +279,17 @@ class ImageBlurnonymizer:
 
        return self._smoothen_mask(best_mask), best_score
 
-    def censor_image_blur(
-
+    def censor_image_blur(
+        self,
+        image: np.ndarray,
+        raw_out: str,
+        method: Optional[Literal["segmentation", "bbox"]] = "segmentation",
+        verbose=False,
+    ):
        """
        Censors an image by blurring regions identified in the raw_out string (LLM output).
        """
+        self.init_sam()
        json_output = parse_json_response(raw_out)
        # Ensure json_output is a list before passing to parse_into_models
        if isinstance(json_output, dict):
@@ -209,75 +297,93 @@ class ImageBlurnonymizer:
        elif isinstance(json_output, list):
            findings_list = json_output
        else:
-
-
-
+            # Handle unexpected type or raise an error
+            print(
+                f"Warning: Unexpected output type from parse_json_response: {type(json_output)}"
+            )
+            findings_list = []
 
-        parsed = parse_into_models(findings_list)
+        parsed = parse_into_models(findings_list)  # type: ignore
        # Filter findings based on severity
        filtered = [entry for entry in parsed if entry.severity > 0]
 
        if verbose:
-            visualize_boxes(image, filtered)
+            visualize_boxes(image, filtered)  # Use external visualization
 
        masks = []
        for finding in filtered:
-            bbox =
-
-
+            bbox = (
+                finding.bounding_box
+            )  # Assuming finding has a 'bounding_box' attribute
+            if method == "segmentation":
+                mask, _ = self.get_segmentation_mask(image, bbox)  # Use instance method
                if verbose:
                    polygons = self._binary_mask_to_polygon(mask)
-                if polygons:
-
-            elif method ==
-
+                    if polygons:  # Check if polygon conversion was successful
+                        plot_polygon_mask(image, polygons)  # Use external visualization
+            elif method == "bbox":
+                mask = self._mask_from_bbox(image.shape, bbox)  # Use static method
            else:
-
-
+                print(
+                    f"Warning: Unknown method '{method}'. Defaulting to no mask for this finding."
+                )
+                continue  # Skip if method is invalid
 
            masks.append(mask)
 
-
-        if masks:  # Check if any masks were generated
+        if masks:  # Check if any masks were generated
            # Combine masks: logical OR ensures any pixel in any mask is included
            combined_mask = np.zeros_like(masks[0], dtype=np.uint8)
            for mask in masks:
                # Ensure masks are boolean or uint8 for logical_or
-                combined_mask = np.logical_or(combined_mask, mask.astype(bool)).astype(
-
-
-
-
-
-
+                combined_mask = np.logical_or(combined_mask, mask.astype(bool)).astype(
+                    np.uint8
+                )
+
+            return self._apply_blur_mask(image, combined_mask)  # Use static method
+        return image  # Return original image if no masks
+
+    def censor_image_blur_easy(
+        self,
+        image: np.ndarray,
+        boxes: list[BoundingBox],
+        method: Optional[Literal["segmentation", "bbox"]] = "segmentation",
+        verbose=False,
+    ):
        """
        Censors an image by blurring regions defined by a list of BoundingBox objects.
        """
+        self.init_sam()
        masks = []
        for box in boxes:
-            bbox_tuple = box.to_tuple()
-            if method ==
+            bbox_tuple = box.to_tuple()  # Convert BoundingBox object to tuple
+            if method == "segmentation":
                mask, _ = self.get_segmentation_mask(image, bbox_tuple)
                if verbose:
                    polygons = self._binary_mask_to_polygon(mask)
                    if polygons:
                        plot_polygon_mask(image, polygons)
-            elif method ==
+            elif method == "bbox":
                mask = self._mask_from_bbox(image.shape, bbox_tuple)
            else:
-
-
+                print(
+                    f"Warning: Unknown method '{method}'. Defaulting to no mask for this box."
+                )
+                continue
 
            masks.append(mask)
 
        if masks:
            combined_mask = np.zeros_like(masks[0], dtype=np.uint8)
            for mask in masks:
-                combined_mask = np.logical_or(combined_mask, mask.astype(bool)).astype(
+                combined_mask = np.logical_or(combined_mask, mask.astype(bool)).astype(
+                    np.uint8
+                )
 
            return self._apply_blur_mask(image, combined_mask)
        return image
 
+
 # Example Usage (Optional - keep outside class):
 # if __name__ == '__main__':
 #     # Load an image
utils.py
CHANGED
@@ -138,11 +138,7 @@ def parse_into_models(findings: list[dict]) -> list[Finding]:
     Returns:
         A list of validated Finding model instances.
     """
-    parsed = []
-    for box in findings:
-        model_finding = Finding.model_validate(box)
-        parsed.append(model_finding)
-    return parsed
+    return [Finding.model_validate(box) for box in findings]
 
 
 def parse_all_safe(out: str) -> list[Finding] | None:
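
The utils.py change is a pure refactor: the accumulate-and-return loop collapses into a list comprehension with identical behavior, including raising pydantic's ValidationError on the first invalid dict. A quick equivalence check, assuming a minimal stand-in for the real Finding model:

from pydantic import BaseModel

class Finding(BaseModel):  # minimal stand-in for the model defined in utils.py
    severity: int
    bounding_box: tuple[int, int, int, int]

findings = [{"severity": 1, "bounding_box": (0, 0, 10, 10)}]

parsed = []  # old style
for box in findings:
    parsed.append(Finding.model_validate(box))

assert parsed == [Finding.model_validate(box) for box in findings]  # new style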