"""Gradio demo app: YOLOv3 object detection on PASCAL-VOC with Grad-CAM output.

Loads a YOLOv3 Lightning checkpoint on CPU, builds a Blocks UI with a sample
gallery and an upload widget, and shows the predicted boxes next to an
optional class-activation map.
"""
import gradio as gr
from typing import List, Union
import cv2
import torch
import numpy as np
from pytorch_grad_cam.utils.image import show_cam_on_image
from models import YoloV3Lightning
from utils import load_model_from_checkpoint
import utils
import config as cfg
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from grad_cam import YoloGradCAM

# ---------------------------------------------------------------------------
# Model setup (runs once at import time)
# ---------------------------------------------------------------------------
device = torch.device('cpu')
dataset_mean, dataset_std = (0.4914, 0.4822, 0.4465), \
                            (0.2470, 0.2435, 0.2616)

model = YoloV3Lightning.YOLOv3LightningModel(num_classes=cfg.NUM_CLASSES,
                                             anchors=cfg.ANCHORS,
                                             S=cfg.S)
ckpt_file = 'ckpt_light2.pth'
checkpoint = load_model_from_checkpoint(device, file_name=ckpt_file)
model.load_state_dict(checkpoint['model'], strict=False)
model.eval()

# Anchors multiplied by the per-scale grid size from cfg.S.
scaled_anchors = (
    torch.tensor(cfg.ANCHORS)
    * torch.tensor(cfg.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
).to(model.device)

cam = YoloGradCAM(model=model,
                  target_layers=[model.layers[-2]],
                  scaled_anchors=scaled_anchors,
                  use_cuda=False)

# One single-element list per sample: images/000001.jpg ... images/000021.jpg
sample_images = [[f'images/{i:06d}.jpg'] for i in range(1, 22)]


def plot_image(image, boxes, output_path='output.png'):
    """Draw predicted bounding boxes on ``image`` and save the figure.

    Args:
        image: Anything ``np.array`` can turn into an image array whose first
            two dimensions are height and width.
        boxes: Iterable of 6-element boxes
            ``[class_pred, confidence, x, y, width, height]`` in midpoint
            format. Coordinates are multiplied by the image width/height
            here, so they are presumably normalised to [0, 1].
        output_path: File the rendered figure is written to.

    Returns:
        ``output_path``, for convenience.

    Raises:
        ValueError: If a box does not have exactly 6 elements.
    """
    cmap = plt.get_cmap("tab20b")
    class_labels = cfg.PASCAL_CLASSES
    colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]
    im = np.array(image)
    height, width = im.shape[:2]

    fig, ax = plt.subplots(1)
    try:
        ax.imshow(im)

        for box in boxes:
            if len(box) != 6:
                raise ValueError(
                    "box should contain class pred, confidence, x, y, width, height"
                )
            class_idx = int(box[0])
            # box[2], box[3] are the x / y midpoints; box[4], box[5] the size.
            x_mid, y_mid, box_w, box_h = box[2:]
            upper_left_x = x_mid - box_w / 2
            upper_left_y = y_mid - box_h / 2

            rect = patches.Rectangle(
                (upper_left_x * width, upper_left_y * height),
                box_w * width,
                box_h * height,
                linewidth=2,
                edgecolor=colors[class_idx],
                facecolor="none",
            )
            ax.add_patch(rect)
            # Draw on this figure's axes explicitly rather than through
            # pyplot's global "current axes".
            ax.text(
                upper_left_x * width,
                upper_left_y * height,
                s=class_labels[class_idx],
                color="white",
                verticalalignment="top",
                bbox={"color": colors[class_idx], "pad": 0},
            )

        fig.savefig(output_path)
    finally:
        # Always release the figure; otherwise every request leaves one
        # more open figure behind in this long-running process.
        plt.close(fig)
    return output_path


def predict(image: np.ndarray,
            iou_thresh: float = 0.5,
            thresh: float = 0.6,
            show_cam: bool = False,
            transparency: float = 0.5) -> List[Union[str, np.ndarray, None]]:
    """Run detection on ``image`` and optionally compute a Grad-CAM overlay.

    Args:
        image: Input image array as delivered by the Gradio image widget.
        iou_thresh: IoU threshold passed to non-max suppression.
        thresh: Confidence threshold passed to non-max suppression.
        show_cam: When true, also compute the class-activation overlay.
        transparency: ``image_weight`` handed to ``show_cam_on_image``.

    Returns:
        ``[plotted_image_path, cam_image]`` where ``cam_image`` is ``None``
        when ``show_cam`` is false.
    """
    transformed_image = cfg.grad_cam_transforms(image=image)["image"].unsqueeze(0)

    # Only the plain forward pass runs without gradients; the Grad-CAM call
    # further down needs autograd enabled.
    with torch.no_grad():
        output = model(transformed_image)

        # Batch size is 1, so collect the boxes of image 0 across all scales.
        image_boxes = []
        for scale_output, anchor in zip(output, scaled_anchors):
            grid_size = scale_output.shape[2]
            boxes_scale_i = utils.cells_to_bboxes(
                scale_output, anchor, S=grid_size, is_preds=True
            )
            image_boxes.extend(boxes_scale_i[0])

    nms_boxes = utils.non_max_suppression(
        image_boxes, iou_threshold=iou_thresh, threshold=thresh, box_format="midpoint",
    )
    # NOTE: all requests share the default 'output.png' file.
    plotted_img = plot_image(image, nms_boxes)

    if not show_cam:
        return [plotted_img, None]

    grayscale_cam = cam(transformed_image)
    img = np.array(transformed_image[0], np.float16).transpose(1, 2, 0)
    cam_image = show_cam_on_image(img, grayscale_cam.transpose(1, 2, 0),
                                  use_rgb=True, image_weight=transparency)
    return [plotted_img, cam_image]


# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
with gr.Blocks() as app:
    with gr.Row():
        gr.Markdown(
            """
            # YoloV3 App!
            ## Model is trained on PASCAL-VOC data to predict following classes -
            """)
    with gr.Row():
        gr.HTML(
            """
aeroplane bicycle bird boat bottle bus car cat
chair cow diningtable dog horse motorbike person pottedplant
sheep sofa train tvmonitor

Click to see the model architecture / code

"""
        )

    with gr.Row(visible=True) as pred_cls_col:
        with gr.Column():
            example_images = gr.Gallery(allow_preview=False,
                                        label='Select image ',
                                        value=[img[0] for img in sample_images],
                                        columns=6,
                                        rows=2)
        with gr.Column():
            with gr.Row():
                pred_image = gr.Image(label='Upload Image or Select from the gallery')
            with gr.Row():
                if_show_grad_cam = gr.Checkbox(
                    value=True,
                    label='Show Class Activation Map (What the model sees)?')
            with gr.Row():
                submit_btn = gr.Button("Submit", variant='primary')
                clear_btn = gr.ClearButton()

    with gr.Row(visible=True) as output_bk:
        with gr.Column(visible=True):
            output_img = gr.Image(interactive=False, label='Prediction Output')
        with gr.Column(visible=True):
            grad_cam_out = gr.Image(interactive=False, visible=True, label='CAM Outcome')

    def show_cam_output(show):
        """Show or hide the CAM image to mirror the checkbox state."""
        return {grad_cam_out: gr.update(visible=show)}

    if_show_grad_cam.change(show_cam_output, if_show_grad_cam, grad_cam_out)

    def clear_data():
        """Reset the input image and both output images."""
        return {pred_image: None, output_img: None, grad_cam_out: None}

    # Every component keyed in the returned dict must be listed as an output.
    clear_btn.click(clear_data, None, [pred_image, output_img, grad_cam_out])

    def on_select(evt: gr.SelectData):
        """Copy the clicked gallery sample into the input image widget."""
        return {pred_image: sample_images[evt.index][0]}

    example_images.select(on_select, None, pred_image)

    def img_upload(input_img, if_cam):
        """Run prediction for the submitted image and fill both outputs."""
        if input_img is None:
            # Nothing to predict on: still hand back a value for each output.
            return {output_img: None, grad_cam_out: None}
        plotted, cam_img = predict(input_img, show_cam=if_cam)
        return {output_img: plotted, grad_cam_out: cam_img}

    submit_btn.click(
        img_upload,
        [pred_image, if_show_grad_cam],
        [output_img, grad_cam_out]
    )

# Launch the app
app.launch()