hugohabicht01 commited on
Commit
dae4d1c
·
1 Parent(s): 335bcd6

automatically download sam weights

Browse files
Files changed (2) hide show
  1. blurnonymize.py +169 -63
  2. utils.py +1 -5
blurnonymize.py CHANGED
@@ -1,4 +1,4 @@
1
- import json
2
  import traceback
3
  from typing import Literal, Optional
4
 
@@ -7,7 +7,6 @@ import matplotlib.patches as patches
7
  import matplotlib.pyplot as plt
8
  import numpy as np
9
  import torch
10
- from pydantic import BaseModel
11
  from sam2.build_sam import build_sam2
12
  from sam2.sam2_image_predictor import SAM2ImagePredictor
13
  from utils import *
@@ -15,6 +14,7 @@ from utils import *
15
 
16
  # --- Utility Functions (kept outside the class) ---
17
 
 
18
  def blur_image(img: np.ndarray):
19
  """Applies Gaussian blur to an image."""
20
  return cv2.GaussianBlur(img, (35, 35), 50)
@@ -26,13 +26,14 @@ def plot_polygon_mask(image: np.ndarray, polygons: list[list[tuple[int, int]]]):
26
  """
27
  plt.imshow(image)
28
  for polygon in polygons:
29
- if not polygon: continue # Skip empty polygons
 
30
  polygon_array = np.array(polygon).reshape(-1, 2)
31
  x, y = zip(*polygon_array)
32
  x = list(x) + [x[0]]
33
  y = list(y) + [y[0]]
34
- plt.plot(x, y, '-r', linewidth=2)
35
- plt.axis('off')
36
  plt.tight_layout()
37
  plt.show()
38
 
@@ -41,12 +42,18 @@ def visualize_boxes(image, findings):
41
  """Visualizes bounding boxes on an image."""
42
  fig, ax = plt.subplots(1)
43
  ax.imshow(image)
44
- colors = ['r', 'g', 'b', 'c', 'm', 'y', 'k']
45
  for i, finding in enumerate(findings):
46
  [x_min, y_min, x_max, y_max] = finding.bounding_box
47
  color = colors[i % len(colors)]
48
- rect = patches.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min, linewidth=2, edgecolor=color,
49
- facecolor='none')
 
 
 
 
 
 
50
  ax.add_patch(rect)
51
  print(f"Finding {i + 1} (Color: {color}):")
52
  if not findings:
@@ -55,8 +62,10 @@ def visualize_boxes(image, findings):
55
  plt.yticks(np.arange(0, image.shape[0], 50))
56
  plt.show()
57
 
 
58
  # --- SAM Visualization Helpers (kept outside the class) ---
59
 
 
60
  def show_mask(mask, ax, random_color=False, borders=True):
61
  """Displays a single mask on a matplotlib axis."""
62
  if random_color:
@@ -69,23 +78,54 @@ def show_mask(mask, ax, random_color=False, borders=True):
69
  if borders:
70
  contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
71
  # contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours] # Optional smoothing
72
- mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
 
 
73
  ax.imshow(mask_image)
74
 
 
75
  def show_points(coords, labels, ax, marker_size=375):
76
  """Displays points (positive/negative) on a matplotlib axis."""
77
  pos_points = coords[labels == 1]
78
  neg_points = coords[labels == 0]
79
- ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
80
- ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  def show_box(box, ax):
83
  """Displays a bounding box on a matplotlib axis."""
84
  x0, y0 = box[0], box[1]
85
  w, h = box[2] - box[0], box[3] - box[1]
86
- ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
87
-
88
- def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):
 
 
 
 
 
 
 
 
 
 
 
89
  """Displays multiple masks resulting from SAM prediction."""
90
  for i, (mask, score) in enumerate(zip(masks, scores)):
91
  plt.figure(figsize=(10, 10))
@@ -98,16 +138,48 @@ def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_l
98
  show_box(box_coords, plt.gca())
99
  if len(scores) > 1:
100
  plt.title(f"Mask {i + 1}, Score: {score:.3f}", fontsize=18)
101
- plt.axis('off')
102
  plt.show()
103
 
104
 
105
  # --- ImageBlurnonymizer Class ---
106
 
 
107
  class ImageBlurnonymizer:
108
- def __init__(self, checkpoint="./sam2.1_hiera_large.pt", model_cfg="configs/sam2.1/sam2.1_hiera_l.yaml"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
110
- self.predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint, device=self.device))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
  @staticmethod
113
  def _smoothen_mask(mask: np.ndarray):
@@ -118,18 +190,18 @@ class ImageBlurnonymizer:
118
  @staticmethod
119
  def _mask_from_bbox(image_shape, bbox: tuple[int, int, int, int]):
120
  """Creates a simple rectangular mask from a bounding box."""
121
- height, width, *_ = image_shape # Allow for 2D or 3D shape tuple
122
  xmin, ymin, xmax, ymax = bbox
123
  mask = np.zeros((height, width), dtype=np.uint8)
124
  mask[ymin:ymax, xmin:xmax] = 1
125
- return mask # No need for np.array() conversion
126
 
127
  @staticmethod
128
  def _apply_blur_mask(image: np.ndarray, mask: np.ndarray):
129
  """Applies a blur to an image based on a mask."""
130
- if mask.ndim == 2: # Ensure mask is 3-channel for broadcasting
131
- mask = np.stack((mask,) * image.shape[2], axis=-1)
132
- blurred = blur_image(image) # Use the utility function
133
  return np.where(mask, blurred, image)
134
 
135
  @staticmethod
@@ -138,19 +210,22 @@ class ImageBlurnonymizer:
138
  try:
139
  converted = (binary_mask * 255).astype(np.uint8)
140
  # Use RETR_TREE to get hierarchy, CHAIN_APPROX_SIMPLE for efficiency
141
- contours, _ = cv2.findContours(converted, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
 
 
142
  polygons = []
143
  for contour in contours:
144
  approx_contour = cv2.approxPolyDP(contour, epsilon, True)
145
  # Ensure points are converted correctly
146
- polygon = [(int(point[0][0]), int(point[0][1])) for point in approx_contour]
 
 
147
  polygons.append(polygon)
148
  return polygons
149
  except Exception as e:
150
  print(f"An error occurred during polygon conversion: {e}")
151
  print(traceback.format_exc())
152
- return None # Return None on error
153
-
154
 
155
  def get_segmentation_mask(self, image: np.ndarray, bbox: tuple[int, int, int, int]):
156
  """
@@ -159,9 +234,17 @@ class ImageBlurnonymizer:
159
  Adds points within the bounding box to guide SAM towards the intended object (e.g., face)
160
  and away from surrounding elements (e.g., hair).
161
  """
 
 
 
 
 
 
 
 
162
  x_min, y_min, x_max, y_max = bbox
163
  x_width = x_max - x_min
164
- y_height = y_max - y_min # Corrected variable name
165
 
166
  # Handle cases where box dimensions are too small for third calculations
167
  x_third = x_width // 3 if x_width >= 3 else 0
@@ -175,18 +258,17 @@ class ImageBlurnonymizer:
175
  points.append([center_point[0], center_point[1] - y_third])
176
  points.append([center_point[0], center_point[1] + y_third])
177
  if x_third > 0:
178
- points.append([center_point[0] + x_third, center_point[1]])
179
- points.append([center_point[0] - x_third, center_point[1]])
180
 
181
  # Ensure points are valid coordinates (e.g., non-negative)
182
  points = [[max(0, p[0]), max(0, p[1])] for p in points]
183
 
184
-
185
  self.predictor.set_image(image)
186
  masks, scores, _ = self.predictor.predict(
187
- box=np.array(bbox), # Predictor might expect numpy array
188
  point_coords=np.array(points),
189
- point_labels=np.ones(len(points)), # Label 1 for inclusion
190
  multimask_output=True,
191
  )
192
 
@@ -197,11 +279,17 @@ class ImageBlurnonymizer:
197
 
198
  return self._smoothen_mask(best_mask), best_score
199
 
200
- def censor_image_blur(self, image: np.ndarray, raw_out: str,
201
- method: Optional[Literal['segmentation', 'bbox']] = 'segmentation', verbose=False):
 
 
 
 
 
202
  """
203
  Censors an image by blurring regions identified in the raw_out string (LLM output).
204
  """
 
205
  json_output = parse_json_response(raw_out)
206
  # Ensure json_output is a list before passing to parse_into_models
207
  if isinstance(json_output, dict):
@@ -209,75 +297,93 @@ class ImageBlurnonymizer:
209
  elif isinstance(json_output, list):
210
  findings_list = json_output
211
  else:
212
- # Handle unexpected type or raise an error
213
- print(f"Warning: Unexpected output type from parse_json_response: {type(json_output)}")
214
- findings_list = []
 
 
215
 
216
- parsed = parse_into_models(findings_list)
217
  # Filter findings based on severity
218
  filtered = [entry for entry in parsed if entry.severity > 0]
219
 
220
  if verbose:
221
- visualize_boxes(image, filtered) # Use external visualization
222
 
223
  masks = []
224
  for finding in filtered:
225
- bbox = finding.bounding_box # Assuming finding has a 'bounding_box' attribute
226
- if method == 'segmentation':
227
- mask, _ = self.get_segmentation_mask(image, bbox) # Use instance method
 
 
228
  if verbose:
229
  polygons = self._binary_mask_to_polygon(mask)
230
- if polygons: # Check if polygon conversion was successful
231
- plot_polygon_mask(image, polygons) # Use external visualization
232
- elif method == 'bbox':
233
- mask = self._mask_from_bbox(image.shape, bbox) # Use static method
234
  else:
235
- print(f"Warning: Unknown method '{method}'. Defaulting to no mask for this finding.")
236
- continue # Skip if method is invalid
 
 
237
 
238
  masks.append(mask)
239
 
240
-
241
- if masks: # Check if any masks were generated
242
  # Combine masks: logical OR ensures any pixel in any mask is included
243
  combined_mask = np.zeros_like(masks[0], dtype=np.uint8)
244
  for mask in masks:
245
  # Ensure masks are boolean or uint8 for logical_or
246
- combined_mask = np.logical_or(combined_mask, mask.astype(bool)).astype(np.uint8)
247
-
248
- return self._apply_blur_mask(image, combined_mask) # Use static method
249
- return image # Return original image if no masks
250
-
251
- def censor_image_blur_easy(self, image: np.ndarray, boxes: list[BoundingBox],
252
- method: Optional[Literal['segmentation', 'bbox']] = 'segmentation', verbose=False):
 
 
 
 
 
 
 
253
  """
254
  Censors an image by blurring regions defined by a list of BoundingBox objects.
255
  """
 
256
  masks = []
257
  for box in boxes:
258
- bbox_tuple = box.to_tuple() # Convert BoundingBox object to tuple
259
- if method == 'segmentation':
260
  mask, _ = self.get_segmentation_mask(image, bbox_tuple)
261
  if verbose:
262
  polygons = self._binary_mask_to_polygon(mask)
263
  if polygons:
264
  plot_polygon_mask(image, polygons)
265
- elif method == 'bbox':
266
  mask = self._mask_from_bbox(image.shape, bbox_tuple)
267
  else:
268
- print(f"Warning: Unknown method '{method}'. Defaulting to no mask for this box.")
269
- continue
 
 
270
 
271
  masks.append(mask)
272
 
273
  if masks:
274
  combined_mask = np.zeros_like(masks[0], dtype=np.uint8)
275
  for mask in masks:
276
- combined_mask = np.logical_or(combined_mask, mask.astype(bool)).astype(np.uint8)
 
 
277
 
278
  return self._apply_blur_mask(image, combined_mask)
279
  return image
280
 
 
281
  # Example Usage (Optional - keep outside class):
282
  # if __name__ == '__main__':
283
  # # Load an image
 
1
+ import os
2
  import traceback
3
  from typing import Literal, Optional
4
 
 
7
  import matplotlib.pyplot as plt
8
  import numpy as np
9
  import torch
 
10
  from sam2.build_sam import build_sam2
11
  from sam2.sam2_image_predictor import SAM2ImagePredictor
12
  from utils import *
 
14
 
15
  # --- Utility Functions (kept outside the class) ---
16
 
17
+
18
  def blur_image(img: np.ndarray):
19
  """Applies Gaussian blur to an image."""
20
  return cv2.GaussianBlur(img, (35, 35), 50)
 
26
  """
27
  plt.imshow(image)
28
  for polygon in polygons:
29
+ if not polygon:
30
+ continue # Skip empty polygons
31
  polygon_array = np.array(polygon).reshape(-1, 2)
32
  x, y = zip(*polygon_array)
33
  x = list(x) + [x[0]]
34
  y = list(y) + [y[0]]
35
+ plt.plot(x, y, "-r", linewidth=2)
36
+ plt.axis("off")
37
  plt.tight_layout()
38
  plt.show()
39
 
 
42
  """Visualizes bounding boxes on an image."""
43
  fig, ax = plt.subplots(1)
44
  ax.imshow(image)
45
+ colors = ["r", "g", "b", "c", "m", "y", "k"]
46
  for i, finding in enumerate(findings):
47
  [x_min, y_min, x_max, y_max] = finding.bounding_box
48
  color = colors[i % len(colors)]
49
+ rect = patches.Rectangle(
50
+ (x_min, y_min),
51
+ x_max - x_min,
52
+ y_max - y_min,
53
+ linewidth=2,
54
+ edgecolor=color,
55
+ facecolor="none",
56
+ )
57
  ax.add_patch(rect)
58
  print(f"Finding {i + 1} (Color: {color}):")
59
  if not findings:
 
62
  plt.yticks(np.arange(0, image.shape[0], 50))
63
  plt.show()
64
 
65
+
66
  # --- SAM Visualization Helpers (kept outside the class) ---
67
 
68
+
69
  def show_mask(mask, ax, random_color=False, borders=True):
70
  """Displays a single mask on a matplotlib axis."""
71
  if random_color:
 
78
  if borders:
79
  contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
80
  # contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours] # Optional smoothing
81
+ mask_image = cv2.drawContours(
82
+ mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2
83
+ )
84
  ax.imshow(mask_image)
85
 
86
+
87
  def show_points(coords, labels, ax, marker_size=375):
88
  """Displays points (positive/negative) on a matplotlib axis."""
89
  pos_points = coords[labels == 1]
90
  neg_points = coords[labels == 0]
91
+ ax.scatter(
92
+ pos_points[:, 0],
93
+ pos_points[:, 1],
94
+ color="green",
95
+ marker="*",
96
+ s=marker_size,
97
+ edgecolor="white",
98
+ linewidth=1.25,
99
+ )
100
+ ax.scatter(
101
+ neg_points[:, 0],
102
+ neg_points[:, 1],
103
+ color="red",
104
+ marker="*",
105
+ s=marker_size,
106
+ edgecolor="white",
107
+ linewidth=1.25,
108
+ )
109
+
110
 
111
  def show_box(box, ax):
112
  """Displays a bounding box on a matplotlib axis."""
113
  x0, y0 = box[0], box[1]
114
  w, h = box[2] - box[0], box[3] - box[1]
115
+ ax.add_patch(
116
+ plt.Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2)
117
+ )
118
+
119
+
120
+ def show_masks(
121
+ image,
122
+ masks,
123
+ scores,
124
+ point_coords=None,
125
+ box_coords=None,
126
+ input_labels=None,
127
+ borders=True,
128
+ ):
129
  """Displays multiple masks resulting from SAM prediction."""
130
  for i, (mask, score) in enumerate(zip(masks, scores)):
131
  plt.figure(figsize=(10, 10))
 
138
  show_box(box_coords, plt.gca())
139
  if len(scores) > 1:
140
  plt.title(f"Mask {i + 1}, Score: {score:.3f}", fontsize=18)
141
+ plt.axis("off")
142
  plt.show()
143
 
144
 
145
  # --- ImageBlurnonymizer Class ---
146
 
147
+
148
  class ImageBlurnonymizer:
149
+ def __init__(self):
150
+ self.predictor = None
151
+ self.device = None
152
+ self.model_cfg = None
153
+
154
+ self.checkpoint_name = "./sam2.1_hiera_small.pt"
155
+ self.model_cfg_name = "./sam2.1_hiera_s.yaml"
156
+ self.init_sam()
157
+
158
+ def init_sam(self, force=False):
159
+ # only initialize SAM if it hasn't been initialized yet
160
+ if self.predictor is not None and not force:
161
+ return
162
+
163
+ self.download_weights()
164
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
165
+ sam = build_sam2(self.model_cfg_name, self.checkpoint_name, device=self.device)
166
+ self.predictor = SAM2ImagePredictor(sam)
167
+
168
+ def download_weights(self):
169
+ # TODO: check whether these files already exist, if not, download them
170
+ # files names are in self.checkpoint_name and self.model_cfg_name
171
+ checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt"
172
+ cfg_url = "https://raw.githubusercontent.com/facebookresearch/sam2/refs/heads/main/sam2/configs/sam2.1/sam2.1_hiera_s.yaml"
173
+
174
+ if not os.path.exists(self.checkpoint_name):
175
+ print(
176
+ f"Downloading checkpoint from {checkpoint_url} to {self.checkpoint_name}"
177
+ )
178
+ torch.hub.download_url_to_file(checkpoint_url, self.checkpoint_name)
179
+
180
+ if not os.path.exists(self.model_cfg_name):
181
+ print(f"Downloading config from {cfg_url} to {self.model_cfg_name}")
182
+ torch.hub.download_url_to_file(cfg_url, self.model_cfg_name)
183
 
184
  @staticmethod
185
  def _smoothen_mask(mask: np.ndarray):
 
190
  @staticmethod
191
  def _mask_from_bbox(image_shape, bbox: tuple[int, int, int, int]):
192
  """Creates a simple rectangular mask from a bounding box."""
193
+ height, width, *_ = image_shape # Allow for 2D or 3D shape tuple
194
  xmin, ymin, xmax, ymax = bbox
195
  mask = np.zeros((height, width), dtype=np.uint8)
196
  mask[ymin:ymax, xmin:xmax] = 1
197
+ return mask # No need for np.array() conversion
198
 
199
  @staticmethod
200
  def _apply_blur_mask(image: np.ndarray, mask: np.ndarray):
201
  """Applies a blur to an image based on a mask."""
202
+ if mask.ndim == 2: # Ensure mask is 3-channel for broadcasting
203
+ mask = np.stack((mask,) * image.shape[2], axis=-1)
204
+ blurred = blur_image(image) # Use the utility function
205
  return np.where(mask, blurred, image)
206
 
207
  @staticmethod
 
210
  try:
211
  converted = (binary_mask * 255).astype(np.uint8)
212
  # Use RETR_TREE to get hierarchy, CHAIN_APPROX_SIMPLE for efficiency
213
+ contours, _ = cv2.findContours(
214
+ converted, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
215
+ )
216
  polygons = []
217
  for contour in contours:
218
  approx_contour = cv2.approxPolyDP(contour, epsilon, True)
219
  # Ensure points are converted correctly
220
+ polygon = [
221
+ (int(point[0][0]), int(point[0][1])) for point in approx_contour
222
+ ]
223
  polygons.append(polygon)
224
  return polygons
225
  except Exception as e:
226
  print(f"An error occurred during polygon conversion: {e}")
227
  print(traceback.format_exc())
228
+ return None # Return None on error
 
229
 
230
  def get_segmentation_mask(self, image: np.ndarray, bbox: tuple[int, int, int, int]):
231
  """
 
234
  Adds points within the bounding box to guide SAM towards the intended object (e.g., face)
235
  and away from surrounding elements (e.g., hair).
236
  """
237
+
238
+ if self.predictor is None:
239
+ raise Exception("[-] sam has not been initialized")
240
+
241
+ if torch.cuda.is_available() and self.device == "cpu":
242
+ # class instance was wrongly initialized to run on cpu, but gpu is avaiable
243
+ self.init_sam(force=True)
244
+
245
  x_min, y_min, x_max, y_max = bbox
246
  x_width = x_max - x_min
247
+ y_height = y_max - y_min # Corrected variable name
248
 
249
  # Handle cases where box dimensions are too small for third calculations
250
  x_third = x_width // 3 if x_width >= 3 else 0
 
258
  points.append([center_point[0], center_point[1] - y_third])
259
  points.append([center_point[0], center_point[1] + y_third])
260
  if x_third > 0:
261
+ points.append([center_point[0] + x_third, center_point[1]])
262
+ points.append([center_point[0] - x_third, center_point[1]])
263
 
264
  # Ensure points are valid coordinates (e.g., non-negative)
265
  points = [[max(0, p[0]), max(0, p[1])] for p in points]
266
 
 
267
  self.predictor.set_image(image)
268
  masks, scores, _ = self.predictor.predict(
269
+ box=np.array(bbox), # Predictor might expect numpy array
270
  point_coords=np.array(points),
271
+ point_labels=np.ones(len(points)), # Label 1 for inclusion
272
  multimask_output=True,
273
  )
274
 
 
279
 
280
  return self._smoothen_mask(best_mask), best_score
281
 
282
+ def censor_image_blur(
283
+ self,
284
+ image: np.ndarray,
285
+ raw_out: str,
286
+ method: Optional[Literal["segmentation", "bbox"]] = "segmentation",
287
+ verbose=False,
288
+ ):
289
  """
290
  Censors an image by blurring regions identified in the raw_out string (LLM output).
291
  """
292
+ self.init_sam()
293
  json_output = parse_json_response(raw_out)
294
  # Ensure json_output is a list before passing to parse_into_models
295
  if isinstance(json_output, dict):
 
297
  elif isinstance(json_output, list):
298
  findings_list = json_output
299
  else:
300
+ # Handle unexpected type or raise an error
301
+ print(
302
+ f"Warning: Unexpected output type from parse_json_response: {type(json_output)}"
303
+ )
304
+ findings_list = []
305
 
306
+ parsed = parse_into_models(findings_list) # type: ignore
307
  # Filter findings based on severity
308
  filtered = [entry for entry in parsed if entry.severity > 0]
309
 
310
  if verbose:
311
+ visualize_boxes(image, filtered) # Use external visualization
312
 
313
  masks = []
314
  for finding in filtered:
315
+ bbox = (
316
+ finding.bounding_box
317
+ ) # Assuming finding has a 'bounding_box' attribute
318
+ if method == "segmentation":
319
+ mask, _ = self.get_segmentation_mask(image, bbox) # Use instance method
320
  if verbose:
321
  polygons = self._binary_mask_to_polygon(mask)
322
+ if polygons: # Check if polygon conversion was successful
323
+ plot_polygon_mask(image, polygons) # Use external visualization
324
+ elif method == "bbox":
325
+ mask = self._mask_from_bbox(image.shape, bbox) # Use static method
326
  else:
327
+ print(
328
+ f"Warning: Unknown method '{method}'. Defaulting to no mask for this finding."
329
+ )
330
+ continue # Skip if method is invalid
331
 
332
  masks.append(mask)
333
 
334
+ if masks: # Check if any masks were generated
 
335
  # Combine masks: logical OR ensures any pixel in any mask is included
336
  combined_mask = np.zeros_like(masks[0], dtype=np.uint8)
337
  for mask in masks:
338
  # Ensure masks are boolean or uint8 for logical_or
339
+ combined_mask = np.logical_or(combined_mask, mask.astype(bool)).astype(
340
+ np.uint8
341
+ )
342
+
343
+ return self._apply_blur_mask(image, combined_mask) # Use static method
344
+ return image # Return original image if no masks
345
+
346
+ def censor_image_blur_easy(
347
+ self,
348
+ image: np.ndarray,
349
+ boxes: list[BoundingBox],
350
+ method: Optional[Literal["segmentation", "bbox"]] = "segmentation",
351
+ verbose=False,
352
+ ):
353
  """
354
  Censors an image by blurring regions defined by a list of BoundingBox objects.
355
  """
356
+ self.init_sam()
357
  masks = []
358
  for box in boxes:
359
+ bbox_tuple = box.to_tuple() # Convert BoundingBox object to tuple
360
+ if method == "segmentation":
361
  mask, _ = self.get_segmentation_mask(image, bbox_tuple)
362
  if verbose:
363
  polygons = self._binary_mask_to_polygon(mask)
364
  if polygons:
365
  plot_polygon_mask(image, polygons)
366
+ elif method == "bbox":
367
  mask = self._mask_from_bbox(image.shape, bbox_tuple)
368
  else:
369
+ print(
370
+ f"Warning: Unknown method '{method}'. Defaulting to no mask for this box."
371
+ )
372
+ continue
373
 
374
  masks.append(mask)
375
 
376
  if masks:
377
  combined_mask = np.zeros_like(masks[0], dtype=np.uint8)
378
  for mask in masks:
379
+ combined_mask = np.logical_or(combined_mask, mask.astype(bool)).astype(
380
+ np.uint8
381
+ )
382
 
383
  return self._apply_blur_mask(image, combined_mask)
384
  return image
385
 
386
+
387
  # Example Usage (Optional - keep outside class):
388
  # if __name__ == '__main__':
389
  # # Load an image
utils.py CHANGED
@@ -138,11 +138,7 @@ def parse_into_models(findings: list[dict]) -> list[Finding]:
138
  Returns:
139
  A list of validated Finding model instances.
140
  """
141
- parsed = []
142
- for box in findings:
143
- model_finding = Finding.model_validate(box)
144
- parsed.append(model_finding)
145
- return parsed
146
 
147
 
148
  def parse_all_safe(out: str) -> list[Finding] | None:
 
138
  Returns:
139
  A list of validated Finding model instances.
140
  """
141
+ return [Finding.model_validate(box) for box in findings]
 
 
 
 
142
 
143
 
144
  def parse_all_safe(out: str) -> list[Finding] | None: