# vade-detectron / app.py — author: vadepark
# Last commit: 782c8bc ("Update app.py")
# Install detectron2 from its GitHub repository when it is not importable yet,
# so the `from detectron2 import ...` lines below can succeed.
try:
    import detectron2
except ImportError:
    import importlib
    import subprocess
    import sys

    # Use this interpreter's pip so the package lands in the environment that
    # runs the app; check_call raises if the install fails instead of letting
    # the later detectron2 imports fail with a less obvious error.
    subprocess.check_call(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "git+https://github.com/facebookresearch/detectron2.git",
        ]
    )
    # The import system caches directory listings; refresh them so the package
    # installed during this run is found by the imports that follow.
    importlib.invalidate_caches()
from PIL import Image, ImageDraw
import numpy as np
import requests
import torch
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer, GenericMask, _create_text_labels
from detectron2.data import MetadataCatalog
MODEL_NAME = 'COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml'


def _build_cfg(model_name):
    """Return a detectron2 config for the model-zoo entry *model_name*.

    The config file and pretrained checkpoint both come from the model zoo;
    inference is pinned to the CPU when CUDA is unavailable.
    """
    config = get_cfg()
    # Project-specific config (e.g. TensorMask) would be added here for models
    # that are not part of detectron2's core library.
    config.merge_from_file(model_zoo.get_config_file(model_name))
    # Test-time score threshold for this model.
    config.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
    # Checkpoint URL from the model zoo (a direct https://dl.fbaipublicfiles...
    # URL works as well).
    config.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(model_name)
    if not torch.cuda.is_available():
        config.MODEL.DEVICE = 'cpu'
    return config


cfg = _build_cfg(MODEL_NAME)
example_image_url = "https://i.ibb.co/0QFxwjR/0a2e59fa-7990-43dc-b060-e8413468d113.jpg"
# Download the example picture once at start-up; the Gradio interface offers
# it as the "city1.jpg" example.
r = requests.get(example_image_url, allow_redirects=True, timeout=30)
# Fail loudly rather than saving an HTTP error page under a .jpg name.
r.raise_for_status()
with open("city1.jpg", "wb") as example_file:
    example_file.write(r.content)
predictor = DefaultPredictor(cfg)
def _last_polygon_points(polygons) -> list:
    """Return the last polygon in *polygons* as (x, y) int tuples, or [] if none.

    Each polygon is a flat [x0, y0, x1, y1, ...] array; it is reshaped into
    coordinate pairs. A mask split into several regions yields several
    polygons, and only the last one is reported.
    """
    if len(polygons) == 0:
        # An empty mask produces no polygon at all.
        return []
    flat_points = np.asarray(polygons[-1]).reshape(-1, 2).tolist()
    return [(int(x), int(y)) for x, y in flat_points]


def infer_and_get_json(image: Image.Image) -> dict:
    """Run instance segmentation on *image* and return the detections as JSON-ready data.

    Args:
        image: Input picture as a PIL image (any mode; it is converted to RGB).
            ``None`` yields a result with no instances.

    Returns:
        A dict with:
          - "all_class_options": the ``thing_classes`` of the training dataset.
          - "instances": one dict per detection with "confidence" (float),
            "class_index" (int), "class_name" (str), "bounding_box" (the box's
            four coordinates truncated to ints) and "polygon" (list of (x, y)
            int tuples, empty when the mask has no polygon).
          - "classes_in_image": sorted, de-duplicated names of detected classes.
    """
    class_names = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes
    result = {"all_class_options": class_names, "instances": []}
    if image is None:
        # Gradio passes None when the user submits without an image.
        result["classes_in_image"] = []
        return result

    img_width, img_height = image.size
    # DefaultPredictor expects an HxWx3 BGR array; PIL yields RGB (or RGBA / L),
    # so normalise to RGB first and then reverse the channel axis.
    bgr_image = np.ascontiguousarray(np.asarray(image.convert("RGB"))[:, :, ::-1])
    # Move every field to the CPU so the numpy conversions below also work
    # when inference ran on a GPU.
    predictions = predictor(bgr_image)["instances"].to("cpu")

    scores = predictions.scores.tolist()
    classes = predictions.pred_classes.tolist()
    boxes = predictions.pred_boxes.tensor.tolist()
    masks = predictions.pred_masks.numpy()

    classes_seen = set()
    for score, class_index, box, mask in zip(scores, classes, boxes, masks):
        class_name = class_names[class_index]
        classes_seen.add(class_name)
        polygons = GenericMask(mask, img_height, img_width).polygons
        result["instances"].append(
            {
                "confidence": score,
                "class_index": class_index,
                "class_name": class_name,
                "bounding_box": [int(coord) for coord in box],
                "polygon": _last_polygon_points(polygons),
            }
        )

    # Sorted so the output is deterministic (set iteration order is not).
    result["classes_in_image"] = sorted(classes_seen)
    return result
title = "VADE DETECTRON BABY"
# The model line is built from MODEL_NAME so the page cannot drift from the
# checkpoint that is actually loaded.
description = (
    "demo for Detectron2. To use it, simply upload your image, or click one of "
    "the examples to load them. Read more at the links below."
    f"</br><b>Model: {MODEL_NAME}</b>"
)
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2012.07177'>Simple Copy-Paste is a Strong Data Augmentation Method for Instance Segmentation</a> | <a href='https://github.com/facebookresearch/detectron2/blob/main/MODEL_ZOO.md'>Detectron model ZOO</a></p>"
import gradio as gr

# Web UI: one PIL image in, the detection JSON out, with the downloaded
# example picture offered as a one-click sample.
demo = gr.Interface(
    fn=infer_and_get_json,
    inputs=[gr.inputs.Image(type="pil", label="Input")],
    outputs=gr.outputs.JSON(label="Output"),
    title=title,
    description=description,
    article=article,
    examples=[["city1.jpg"]],
)
demo.launch()