|
"""Gradio app: segment objects in an image with Mask R-CNN, then blur the
background around a chosen object."""

import math
import random

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as T
from torchvision.models.detection import maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights
from torchvision.utils import draw_segmentation_masks, draw_bounding_boxes


def random_color_gen(n):
    """Generate n random RGB color tuples, one per detected object."""
    return [tuple(random.randint(0, 255) for _ in range(3)) for _ in range(n)]

def get_gaussian_kernel(kernel_size=15, sigma=20, channels=3):
    """Build a fixed depthwise Conv2d that applies a Gaussian blur."""
    # 2D grid of (x, y) coordinates over the kernel window.
    x_coord = torch.arange(kernel_size)
    x_grid = x_coord.repeat(kernel_size).view(kernel_size, kernel_size)
    y_grid = x_grid.t()
    xy_grid = torch.stack([x_grid, y_grid], dim=-1).float()

    mean = (kernel_size - 1) / 2.
    variance = sigma ** 2.

    # Evaluate the 2D Gaussian density at every grid point.
    gaussian_kernel = (1. / (2. * math.pi * variance)) * \
        torch.exp(
            -torch.sum((xy_grid - mean) ** 2., dim=-1) /
            (2 * variance)
        )

    # Normalize so the kernel sums to 1.
    gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)

    # Reshape to (channels, 1, k, k) so each channel is blurred independently.
    gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
    gaussian_kernel = gaussian_kernel.repeat(channels, 1, 1, 1)

    gaussian_filter = nn.Conv2d(in_channels=channels, out_channels=channels,
                                kernel_size=kernel_size, padding='same',
                                groups=channels, bias=False)
    gaussian_filter.weight.data = gaussian_kernel
    gaussian_filter.weight.requires_grad = False

    return gaussian_filter
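
# Note: torchvision also ships a built-in transform with a similar effect,
# e.g. T.GaussianBlur(kernel_size=15, sigma=20.); the manual kernel above
# just makes the depthwise-convolution construction explicit.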


# Maps each unique detection label (e.g. "person1") to its mask and color,
# shared between segment() and blur_background().
output_dict = {}


def segment(input_image):
    # Keep a uint8 (C, H, W) copy of the input for drawing masks and boxes.
    display_img = torch.tensor(np.asarray(input_image)).unsqueeze(0)
    display_img = display_img.permute(0, 3, 1, 2).squeeze(0)

    # Note: the model is re-created on every call; hoisting it to module
    # level would speed up repeated requests.
    weights = MaskRCNN_ResNet50_FPN_V2_Weights.COCO_V1
    transforms = weights.transforms()
    model = maskrcnn_resnet50_fpn_v2(weights=weights)
    model = model.eval()

    input_tensor = transforms(input_image).unsqueeze(0)

    with torch.no_grad():
        output = model(input_tensor)[0]

    # Keep detections above the score threshold and binarize their soft masks.
    score_threshold = 0.75
    mask_threshold = 0.5
    masks = output['masks'][output['scores'] > score_threshold] > mask_threshold
    boxes = output['boxes'][output['scores'] > score_threshold]
    masks = masks.squeeze(1)

    pred_labels = [weights.meta["categories"][label]
                   for label in output['labels'][output['scores'] > score_threshold]]
    n_pred = len(pred_labels)

    # Disambiguate repeated labels: "person", "person" -> "person1", "person2".
    pred_label_unq = [pred_labels[i] + str(pred_labels[:i].count(pred_labels[i]) + 1)
                      for i in range(n_pred)]

    colors = random_color_gen(n_pred)

    for i in range(n_pred):
        output_dict[pred_label_unq[i]] = {'mask': masks[i].tolist(), 'color': colors[i]}

    masked_img = draw_segmentation_masks(display_img, masks, alpha=0.9, colors=colors)
    bounding_box_img = draw_bounding_boxes(masked_img, boxes, labels=pred_label_unq, colors='white')
    bounding_box_img = T.ToPILImage()(bounding_box_img)

    return bounding_box_img


def blur_background(input_image, label_name):
    # Recover the boolean mask stored by segment() for the chosen label.
    mask = output_dict[label_name]['mask']
    mask = torch.tensor(mask).unsqueeze(0)

    input_tensor = T.ToTensor()(input_image).unsqueeze(0)
    blur = get_gaussian_kernel()
    blurred_tensor = blur(input_tensor)

    # Paste the original (unblurred) pixels of the selected object back on top;
    # the 2D boolean mask indexes the trailing (H, W) dimensions.
    final_img = blurred_tensor.clone()
    final_img[:, :, mask.squeeze(0)] = input_tensor[:, :, mask.squeeze(0)]

    final_img = T.ToPILImage()(final_img.squeeze(0))

    return final_img


with gr.Blocks() as app:
    gr.Markdown("# Blur an object's background with AI")

    gr.Markdown("First, segment the image to get labeled bounding boxes.")
    with gr.Column():
        input_image = gr.Image(type='pil')
        b1 = gr.Button("Segment Image")

    with gr.Row():
        bounding_box_image = gr.Image()

    gr.Markdown("Now pick the label of your desired object from the image above (e.g. person1) and enter it below.")
    with gr.Column():
        label_name = gr.Textbox()
        b2 = gr.Button("Blur Background")
        result = gr.Image()

    b1.click(segment, inputs=input_image, outputs=bounding_box_image)
    b2.click(blur_background, inputs=[input_image, label_name], outputs=result)


app.launch(debug=True)
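
# To try the app, run this script directly (e.g. `python app.py`) and open the
# local URL Gradio prints; debug=True keeps server tracebacks visible in the console.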