# yolov3 / app.py
# piyushgrover's picture
# Update app.py
# 3252151
# raw
# history blame
# 8.28 kB
import gradio as gr
from typing import List
import cv2
import torch
from torchvision import transforms
import numpy as np
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import io
from models import YoloV3Lightning
from utils import load_model_from_checkpoint
import utils
import config as cfg
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from dataset import YOLODataset
from torch.utils.data import Dataset, DataLoader
from grad_cam import YoloGradCAM
# All inference in this app is pinned to the CPU.
device = torch.device('cpu')

# Per-channel (mean, std) statistics.
# NOTE(review): neither tuple is referenced in the visible part of this file.
dataset_mean = (0.4914, 0.4822, 0.4465)
dataset_std = (0.2470, 0.2435, 0.2616)

# Build the network from the project configuration, then restore its weights.
model = YoloV3Lightning.YOLOv3LightningModel(
    num_classes=cfg.NUM_CLASSES,
    anchors=cfg.ANCHORS,
    S=cfg.S,
)
ckpt_file = 'ckpt_light.pth'
checkpoint = load_model_from_checkpoint(device, file_name=ckpt_file)
# strict=False: checkpoint keys that do not line up with the model are
# tolerated instead of raising.
model.load_state_dict(checkpoint['model'], strict=False)
model.eval()

# Multiply every anchor by the grid size of the scale it belongs to.
_grid_sizes = torch.tensor(cfg.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
scaled_anchors = (torch.tensor(cfg.ANCHORS) * _grid_sizes).to(model.device)

# Grad-CAM wrapper hooked on the second-to-last entry of ``model.layers``.
cam = YoloGradCAM(
    model=model,
    target_layers=[model.layers[-2]],
    scaled_anchors=scaled_anchors,
    use_cuda=False,
)
# Gallery examples: images/000001.jpg .. images/000025.jpg, each wrapped in a
# single-element list (callers read the path with ``entry[0]``).
sample_images = [[f'images/{number:06d}.jpg'] for number in range(1, 26)]
with gr.Blocks() as app:
    # NOTE(review): the nesting of the layout containers below was
    # reconstructed from a copy of this file whose indentation had been lost;
    # confirm it against the deployed layout.

    # ---- Header: title and the table of class names ----
    with gr.Row():
        gr.Markdown(
            """
# YoloV3 App!
## Model is trained on PASCAL-VOC data to predict following classes -
""")
    with gr.Row():
        gr.HTML(
            """
<table>
<tr>
<th>aeroplane</th>
<th>bicycle</th>
<th>bird</th>
<th>boat</th>
<th>bottle</th>
<th>bus</th>
<th>car</th>
<th>cat</th>
</tr>
<tr>
<th>chair</th>
<th>cow</th>
<th>diningtable</th>
<th>dog</th>
<th>horse</th>
<th>motorbike</th>
<th>person</th>
<th>pottedplant</th>
</tr>
<tr>
<th>sheep</th>
<th>sofa</th>
<th>train</th>
<th>tvmonitor</th>
</tr>
</table>
<p>
<a href='https://github.com/piygr/yolov3/blob/main/models/YoloV3Lightning.py'>Click to see the model architecture / code </a>
</p>
"""
        )

    # ---- Input area: example gallery, upload widget and controls ----
    with gr.Row(visible=True) as top_pred_cls_col:
        with gr.Column():
            example_images = gr.Gallery(allow_preview=False, label='Select image ', info='',
                                        value=[img[0] for img in sample_images], columns=5, rows=2)
        with gr.Column():
            top_pred_image = gr.Image(label='Upload Image or Select from the gallery')
            with gr.Row():
                top_class_btn = gr.Button("Submit", variant='primary')
                tc_clear_btn = gr.ClearButton()
            with gr.Row():
                if_show_grad_cam = gr.Checkbox(value=True, label='Show Class Activation Map (What the model sees)?')

    # ---- Output area: annotated prediction and CAM overlay ----
    # The name ``top_class_output`` is rebound by each container, exactly as in
    # the original; nothing in this block reads it afterwards.
    with gr.Row(visible=True) as top_class_output:
        with gr.Column(visible=True) as top_class_output:
            top_class_output_img = gr.Image(interactive=False, label='Prediction Output')
        with gr.Column(visible=True) as top_class_output:
            grad_cam_out = gr.Image(interactive=False, visible=True, label='CAM Outcome')

    def show_cam_output(input):
        """Show or hide the CAM output image to follow the checkbox value."""
        return {
            grad_cam_out: gr.update(visible=input)
        }

    if_show_grad_cam.change(
        show_cam_output,
        if_show_grad_cam,
        grad_cam_out
    )

    def clear_data():
        """Reset the input image and both output images.

        The CAM image is cleared together with the prediction so that no
        result of a previous run stays on screen after pressing Clear.
        """
        return {
            top_pred_image: None,
            top_class_output_img: None,
            grad_cam_out: None
        }

    tc_clear_btn.click(clear_data, None, [top_pred_image, top_class_output_img, grad_cam_out])

    def on_select(evt: gr.SelectData):
        """Copy the gallery entry that was clicked into the input image."""
        return {
            top_pred_image: sample_images[evt.index][0]
        }

    example_images.select(on_select, None, top_pred_image)

    def plot_image(image, boxes, output_path='output.png'):
        """Draw the predicted bounding boxes on ``image`` and save the figure.

        Args:
            image: Image to draw on. It is converted with ``np.array`` and must
                unpack into ``(height, width, channels)``.
            boxes: Iterable of 6-element boxes laid out as
                ``[class_pred, confidence, x, y, width, height]`` where x/y is
                the box midpoint. Coordinates are multiplied by the image
                width/height here, i.e. they are treated as fractions of the
                image size.
            output_path: File the rendered figure is written to. Defaults to
                the path the original implementation always used.

        Returns:
            ``output_path``, so callers do not have to repeat the file name.

        Raises:
            ValueError: If a box does not have exactly 6 elements.
        """
        cmap = plt.get_cmap("tab20b")
        class_labels = cfg.PASCAL_CLASSES
        colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]
        im = np.array(image)
        height, width, _ = im.shape

        fig, ax = plt.subplots(1)
        try:
            ax.imshow(im)
            for box in boxes:
                # Explicit check instead of ``assert`` so it survives ``python -O``.
                if len(box) != 6:
                    raise ValueError("box should contain class pred, confidence, x, y, width, height")
                class_idx = int(box[0])
                x_mid, y_mid, box_w, box_h = box[2:]
                # Convert midpoint format to the upper-left corner Rectangle expects.
                upper_left_x = x_mid - box_w / 2
                upper_left_y = y_mid - box_h / 2
                rect = patches.Rectangle(
                    (upper_left_x * width, upper_left_y * height),
                    box_w * width,
                    box_h * height,
                    linewidth=2,
                    edgecolor=colors[class_idx],
                    facecolor="none",
                )
                ax.add_patch(rect)
                ax.text(
                    upper_left_x * width,
                    upper_left_y * height,
                    s=class_labels[class_idx],
                    color="white",
                    verticalalignment="top",
                    bbox={"color": colors[class_idx], "pad": 0},
                )
            fig.savefig(output_path)
        finally:
            # pyplot keeps every figure alive until it is closed; without this
            # each request would leave one more figure in memory.
            plt.close(fig)
        return output_path

    def predict(image: np.ndarray, iou_thresh: float = 0.5, thresh: float = 0.6, show_cam: bool = False,
                transparency: float = 0.5) -> list:
        """Run detection on ``image`` and optionally compute a CAM overlay.

        Args:
            image: Input image array as delivered by the image widget.
            iou_thresh: IoU threshold forwarded to non-max suppression.
            thresh: Confidence threshold forwarded to non-max suppression.
            show_cam: When False the CAM is not computed at all.
            transparency: Forwarded to ``show_cam_on_image`` as ``image_weight``.

        Returns:
            A two-element list ``[plotted_img, cam_image]``: the path of the
            saved figure with the boxes drawn, and the CAM overlay array (or
            ``None`` when ``show_cam`` is False).
        """
        with torch.no_grad():
            transformed_image = cfg.grad_cam_transforms(image=image)["image"].unsqueeze(0)
            output = model(transformed_image)
            # One box list per image in the batch (the batch holds one image here).
            bboxes = [[] for _ in range(transformed_image.shape[0])]
            # Pair each prediction scale with its anchors instead of assuming
            # a fixed number of scales.
            for scale_output, anchor in zip(output, scaled_anchors):
                boxes_scale_i = utils.cells_to_bboxes(
                    scale_output, anchor, S=scale_output.shape[2], is_preds=True
                )
                for idx, box in enumerate(boxes_scale_i):
                    bboxes[idx] += box

        nms_boxes = utils.non_max_suppression(
            bboxes[0], iou_threshold=iou_thresh, threshold=thresh, box_format="midpoint",
        )
        plotted_img = plot_image(image, nms_boxes)
        if not show_cam:
            return [plotted_img, None]

        # Computed outside ``no_grad``, as in the original.
        grayscale_cam = cam(transformed_image)[0, :, :]
        # Resize the input to the CAM's own resolution so image and mask always
        # line up, rather than relying on a hard-coded 416x416.
        cam_height, cam_width = grayscale_cam.shape
        img = cv2.resize(image, (cam_width, cam_height))
        img = np.float32(img) / 255
        cam_image = show_cam_on_image(img, grayscale_cam, use_rgb=True, image_weight=transparency)
        return [plotted_img, cam_image]

    def top_class_img_upload(input_img, if_cam):
        """Handle a Submit click: run ``predict`` and fill both output images.

        When no image has been provided, both outputs are cleared so the
        handler always returns a value for each declared output.
        """
        if input_img is None:
            return {
                top_class_output_img: None,
                grad_cam_out: None
            }
        imgs = predict(input_img, show_cam=if_cam)
        return {
            top_class_output_img: imgs[0],
            grad_cam_out: imgs[1]
        }

    top_class_btn.click(
        top_class_img_upload,
        [top_pred_image, if_show_grad_cam],
        [top_class_output_img, grad_cam_out]
    )
# The bare string below is a no-op expression statement used as a banner.
'''
Launch the app
'''
# Start serving the `app` Blocks object; this runs at import time because the
# call is not behind a `__main__` guard.
app.launch()