import os
import traceback
from typing import Literal, Optional

import cv2
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
from sam2.sam2_image_predictor import SAM2ImagePredictor

from utils import *

# --- Utility Functions (kept outside the class) ---
def blur_image(img: np.ndarray):
    """Applies a strong Gaussian blur (35x35 kernel, sigma 50) to an image."""
    return cv2.GaussianBlur(img, (35, 35), 50)

def plot_polygon_mask(image: np.ndarray, polygons: list[list[tuple[int, int]]]):
    """Plots polygon-based segmentation masks on top of an image."""
    plt.imshow(image)
    for polygon in polygons:
        if not polygon:
            continue  # Skip empty polygons
        polygon_array = np.array(polygon).reshape(-1, 2)
        x, y = zip(*polygon_array)
        x = list(x) + [x[0]]  # Close the polygon outline
        y = list(y) + [y[0]]
        plt.plot(x, y, "-r", linewidth=2)
    plt.axis("off")
    plt.tight_layout()
    plt.show()

def visualize_boxes(image, findings):
    """Visualizes bounding boxes on an image, cycling through a fixed color list."""
    fig, ax = plt.subplots(1)
    ax.imshow(image)
    colors = ["r", "g", "b", "c", "m", "y", "k"]
    for i, finding in enumerate(findings):
        [x_min, y_min, x_max, y_max] = finding.bounding_box
        color = colors[i % len(colors)]
        rect = patches.Rectangle(
            (x_min, y_min),
            x_max - x_min,
            y_max - y_min,
            linewidth=2,
            edgecolor=color,
            facecolor="none",
        )
        ax.add_patch(rect)
        print(f"Finding {i + 1} (Color: {color}):")
    if not findings:
        print("No findings")
    plt.xticks(np.arange(0, image.shape[1], 50))
    plt.yticks(np.arange(0, image.shape[0], 50))
    plt.show()
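
# A hedged usage sketch for visualize_boxes: any object exposing a
# `bounding_box` attribute works, so a stand-in namespace is used here instead
# of the real findings model from utils (values are illustrative):
#   from types import SimpleNamespace
#   fake = SimpleNamespace(bounding_box=[50, 50, 150, 150])
#   visualize_boxes(image, [fake])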

# --- SAM Visualization Helpers (kept outside the class) ---
def show_mask(mask, ax, random_color=False, borders=True):
    """Displays a single mask on a matplotlib axis."""
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask = mask.astype(np.uint8)
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    if borders:
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]  # Optional smoothing
        mask_image = cv2.drawContours(
            mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2
        )
    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=375):
    """Displays points (positive/negative) on a matplotlib axis."""
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(
        pos_points[:, 0],
        pos_points[:, 1],
        color="green",
        marker="*",
        s=marker_size,
        edgecolor="white",
        linewidth=1.25,
    )
    ax.scatter(
        neg_points[:, 0],
        neg_points[:, 1],
        color="red",
        marker="*",
        s=marker_size,
        edgecolor="white",
        linewidth=1.25,
    )

def show_box(box, ax):
    """Displays a bounding box on a matplotlib axis."""
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(
        plt.Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2)
    )

def show_masks(
    image,
    masks,
    scores,
    point_coords=None,
    box_coords=None,
    input_labels=None,
    borders=True,
):
    """Displays multiple masks resulting from SAM prediction."""
    for i, (mask, score) in enumerate(zip(masks, scores)):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        show_mask(mask, plt.gca(), borders=borders)
        if point_coords is not None:
            assert input_labels is not None
            show_points(point_coords, input_labels, plt.gca())
        if box_coords is not None:
            show_box(box_coords, plt.gca())
        if len(scores) > 1:
            plt.title(f"Mask {i + 1}, Score: {score:.3f}", fontsize=18)
        plt.axis("off")
        plt.show()
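
# A hedged usage sketch for the SAM helpers above, mirroring how the class
# below drives the predictor (box values are illustrative):
#   predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2.1-hiera-small")
#   predictor.set_image(image)
#   masks, scores, _ = predictor.predict(box=np.array([50, 50, 200, 200]), multimask_output=True)
#   show_masks(image, masks, scores, box_coords=[50, 50, 200, 200])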

# --- ImageBlurnonymizer Class ---
class ImageBlurnonymizer:
    def __init__(self):
        self.predictor = None
        self.device = None
        self.init_sam()

    def init_sam(self, force=False):
        # Only initialize SAM if it hasn't been initialized yet
        if self.predictor is not None and not force:
            return
        # Prefer the GPU when one is available
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.predictor = SAM2ImagePredictor.from_pretrained(
            "facebook/sam2.1-hiera-small",
            device=self.device,
        )

    @staticmethod
    def _smoothen_mask(mask: np.ndarray):
        """Applies morphological closing to smoothen mask boundaries."""
        kernel = np.ones((20, 20), np.uint8)
        return cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)

    @staticmethod
    def _mask_from_bbox(image_shape, bbox: tuple[int, int, int, int]):
        """Creates a simple rectangular mask from a bounding box."""
        height, width, *_ = image_shape  # Allow for 2D or 3D shape tuples
        xmin, ymin, xmax, ymax = bbox
        mask = np.zeros((height, width), dtype=np.uint8)
        mask[ymin:ymax, xmin:xmax] = 1
        return mask

    @staticmethod
    def _apply_blur_mask(image: np.ndarray, mask: np.ndarray):
        """Applies a blur to an image wherever the mask is set."""
        if mask.ndim == 2:  # Ensure mask is 3-channel for broadcasting
            mask = np.stack((mask,) * image.shape[2], axis=-1)
        blurred = blur_image(image)
        return np.where(mask, blurred, image)
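
    # A minimal sketch of the broadcasting step above, with illustrative shapes:
    # a (H, W) binary mask is stacked to (H, W, 3) so np.where can pick, per
    # pixel and per channel, between the blurred and the original image:
    #   mask3 = np.stack((mask,) * 3, axis=-1)       # (H, W) -> (H, W, 3)
    #   out = np.where(mask3, blur_image(img), img)  # blurred where mask == 1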

    @staticmethod
    def _binary_mask_to_polygon(binary_mask: np.ndarray, epsilon=2.0):
        """Converts a binary segmentation mask to polygon contours."""
        try:
            converted = (binary_mask * 255).astype(np.uint8)
            # RETR_EXTERNAL keeps only outer contours; CHAIN_APPROX_SIMPLE compresses them
            contours, _ = cv2.findContours(
                converted, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
            )
            polygons = []
            for contour in contours:
                approx_contour = cv2.approxPolyDP(contour, epsilon, True)
                # cv2 returns points with shape (N, 1, 2); flatten to (x, y) tuples
                polygon = [
                    (int(point[0][0]), int(point[0][1])) for point in approx_contour
                ]
                polygons.append(polygon)
            return polygons
        except Exception as e:
            print(f"An error occurred during polygon conversion: {e}")
            print(traceback.format_exc())
            return None  # Signal failure to the caller
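
    # A hedged usage sketch: a small square of ones yields a single polygon
    # (exact vertices depend on OpenCV's contour tracing):
    #   m = np.zeros((10, 10), dtype=np.uint8)
    #   m[2:5, 2:5] = 1
    #   ImageBlurnonymizer._binary_mask_to_polygon(m)  # e.g. [[(2, 2), (2, 4), (4, 4), (4, 2)]]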

    def get_segmentation_mask(self, image: np.ndarray, bbox: tuple[int, int, int, int]):
        """
        Generates a segmentation mask for a region defined by a bounding box using SAM.

        Adds points within the bounding box to guide SAM towards the intended object
        (e.g., a face) and away from surrounding elements (e.g., hair).
        """
        if self.predictor is None:
            raise RuntimeError("SAM has not been initialized")
        x_min, y_min, x_max, y_max = bbox
        x_width = x_max - x_min
        y_height = y_max - y_min
        # Guard against boxes too small for the one-third offsets below
        x_third = x_width // 3 if x_width >= 3 else 0
        y_third = y_height // 3 if y_height >= 3 else 0
        center_point = [(x_min + x_max) // 2, (y_min + y_max) // 2]
        # Build a cross-shaped set of positive prompt points around the center
        points = [center_point]
        if y_third > 0:
            points.append([center_point[0], center_point[1] - y_third])
            points.append([center_point[0], center_point[1] + y_third])
        if x_third > 0:
            points.append([center_point[0] + x_third, center_point[1]])
            points.append([center_point[0] - x_third, center_point[1]])
        # Clamp coordinates to be non-negative
        points = [[max(0, p[0]), max(0, p[1])] for p in points]
        with torch.inference_mode(), torch.autocast(self.device, dtype=torch.bfloat16):
            self.predictor.set_image(image)
            masks, scores, _ = self.predictor.predict(
                box=np.array(bbox),
                point_coords=np.array(points),
                point_labels=np.ones(len(points)),  # Label 1 marks inclusion points
                multimask_output=True,
            )
        # Sort masks by score and keep the best one
        sorted_ind = np.argsort(scores)[::-1]
        best_mask = masks[sorted_ind[0]]
        best_score = scores[sorted_ind[0]]
        return self._smoothen_mask(best_mask), best_score
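
    # A hedged usage sketch (assumes `img` is an RGB numpy array; the bbox
    # values are illustrative):
    #   anonymizer = ImageBlurnonymizer()
    #   mask, score = anonymizer.get_segmentation_mask(img, (100, 100, 200, 200))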

    def censor_image_blur(
        self,
        image: np.ndarray,
        raw_out: str,
        method: Optional[Literal["segmentation", "bbox"]] = "segmentation",
        verbose=False,
    ):
        """
        Censors an image by blurring regions identified in raw_out (the raw LLM output).
        """
        self.init_sam()
        json_output = parse_json_response(raw_out)
        # Normalize to a list before passing to parse_into_models
        if isinstance(json_output, dict):
            findings_list = [json_output]
        elif isinstance(json_output, list):
            findings_list = json_output
        else:
            print(
                f"Warning: Unexpected output type from parse_json_response: {type(json_output)}"
            )
            findings_list = []
        parsed = parse_into_models(findings_list)  # type: ignore
        # Keep only findings with a non-zero severity
        filtered = [entry for entry in parsed if entry.severity > 0]
        if verbose:
            visualize_boxes(image, filtered)
        masks = []
        for finding in filtered:
            bbox = finding.bounding_box
            if method == "segmentation":
                mask, _ = self.get_segmentation_mask(image, bbox)
                if verbose:
                    polygons = self._binary_mask_to_polygon(mask)
                    if polygons:  # Only plot if polygon conversion succeeded
                        plot_polygon_mask(image, polygons)
            elif method == "bbox":
                mask = self._mask_from_bbox(image.shape, bbox)
            else:
                print(f"Warning: Unknown method '{method}'. Skipping this finding.")
                continue
            masks.append(mask)
        if masks:
            # Combine masks with a logical OR so every flagged pixel gets blurred
            combined_mask = np.zeros_like(masks[0], dtype=np.uint8)
            for mask in masks:
                combined_mask = np.logical_or(combined_mask, mask.astype(bool)).astype(
                    np.uint8
                )
            return self._apply_blur_mask(image, combined_mask)
        return image  # No findings above threshold: return the original image
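
    # A hedged example of the raw_out payload this method expects. The field
    # names are inferred from the attribute access above (`severity`,
    # `bounding_box`); the authoritative schema lives in utils.parse_into_models:
    #   raw_out = '[{"bounding_box": [100, 100, 200, 200], "severity": 2}]'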

    def censor_image_blur_easy(
        self,
        image: np.ndarray,
        boxes: list[BoundingBox],
        method: Optional[Literal["segmentation", "bbox"]] = "segmentation",
        verbose=False,
    ):
        """
        Censors an image by blurring regions defined by a list of BoundingBox objects.
        """
        self.init_sam()
        masks = []
        for box in boxes:
            bbox_tuple = box.to_tuple()  # Convert BoundingBox object to a plain tuple
            if method == "segmentation":
                mask, _ = self.get_segmentation_mask(image, bbox_tuple)
                if verbose:
                    polygons = self._binary_mask_to_polygon(mask)
                    if polygons:
                        plot_polygon_mask(image, polygons)
            elif method == "bbox":
                mask = self._mask_from_bbox(image.shape, bbox_tuple)
            else:
                print(f"Warning: Unknown method '{method}'. Skipping this box.")
                continue
            masks.append(mask)
        if masks:
            combined_mask = np.zeros_like(masks[0], dtype=np.uint8)
            for mask in masks:
                combined_mask = np.logical_or(combined_mask, mask.astype(bool)).astype(
                    np.uint8
                )
            return self._apply_blur_mask(image, combined_mask)
        return image

# Example Usage (optional; kept outside the class):
# if __name__ == "__main__":
#     # Load an image and convert to RGB for matplotlib
#     img = cv2.imread("path/to/your/image.jpg")
#     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
#
#     # Create an instance of the blurnonymizer
#     blurnonymizer = ImageBlurnonymizer()
#
#     # Define bounding boxes or get raw LLM output
#     example_boxes = [BoundingBox(xmin=100, ymin=100, xmax=200, ymax=200)]  # Assuming BoundingBox class exists
#     llm_output = "..."  # Your raw LLM output string
#
#     # Censor the image
#     censored_img_easy = blurnonymizer.censor_image_blur_easy(img, example_boxes, method="segmentation", verbose=True)
#     censored_img_llm = blurnonymizer.censor_image_blur(img, llm_output, method="segmentation", verbose=True)
#
#     # Display or save the result
#     plt.imshow(censored_img_easy)
#     plt.show()