Spaces:
Runtime error
Runtime error
# Install detectron2 from its GitHub repository at startup if it is not importable.
try:
    import detectron2
except ImportError:
    import importlib
    import subprocess
    import sys

    # Use the running interpreter's pip so the package lands in this environment,
    # and raise immediately if the install fails instead of failing later on import.
    subprocess.check_call(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "git+https://github.com/facebookresearch/detectron2.git",
        ]
    )
    # The failed import above may have cached directory listings; invalidate them so
    # the later `from detectron2 import ...` statements see the new package.
    importlib.invalidate_caches()
from PIL import Image, ImageDraw | |
import numpy as np | |
import requests | |
import torch | |
from detectron2 import model_zoo | |
from detectron2.engine import DefaultPredictor | |
from detectron2.config import get_cfg | |
from detectron2.utils.visualizer import Visualizer, GenericMask, _create_text_labels | |
from detectron2.data import MetadataCatalog | |
# Model-zoo config for COCO instance segmentation (Mask R-CNN, X-101-32x8d FPN, 3x).
MODEL_NAME = 'COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml'

cfg = get_cfg()
# Models outside detectron2's core library (e.g. TensorMask) would need their
# project-specific config added here before merging.
cfg.merge_from_file(model_zoo.get_config_file(MODEL_NAME))
# Pretrained weights are resolved from the model zoo; a direct
# https://dl.fbaipublicfiles... URL would work here as well.
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(MODEL_NAME)
# Score threshold used to filter detections at test time.
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
# Keep the configured device when CUDA is available; otherwise run on CPU.
cfg.MODEL.DEVICE = cfg.MODEL.DEVICE if torch.cuda.is_available() else 'cpu'
# Download a sample image that the Gradio UI offers as a clickable example.
example_image_url = "https://i.ibb.co/0QFxwjR/0a2e59fa-7990-43dc-b060-e8413468d113.jpg"
r = requests.get(example_image_url, allow_redirects=True, timeout=30)
# Fail loudly rather than saving an HTML error page under a .jpg name.
r.raise_for_status()
# Context manager guarantees the file is flushed and closed.
with open("city1.jpg", "wb") as example_file:
    example_file.write(r.content)

# Build the predictor once at import time so every request reuses the same model.
predictor = DefaultPredictor(cfg)
def _mask_polygon_points(mask) -> list:
    """Return one polygon of a detectron2 ``GenericMask`` as integer ``(x, y)`` tuples.

    A mask may decompose into several polygons; as before, only the last one is
    reported. Returns ``[]`` when the mask yields no polygon instead of crashing.
    """
    polygons = mask.polygons
    if len(polygons) == 0:
        return []
    # Each polygon is a flat [x0, y0, x1, y1, ...] sequence; pair up coordinates.
    points = np.asarray(polygons[-1]).reshape(-1, 2).tolist()
    return [(int(x), int(y)) for x, y in points]


def infer_and_get_json(image: "Image.Image") -> dict:
    """Run instance segmentation on *image* and describe every detection as JSON.

    Args:
        image: PIL image from the Gradio input component. ``None`` (nothing
            uploaded) yields a result with no instances.

    Returns:
        A dict with keys:
          * ``"all_class_options"``: every class name of the training dataset's metadata;
          * ``"instances"``: one dict per detection with ``confidence`` (float),
            ``class_index`` (int), ``class_name`` (str), ``bounding_box``
            (``[x0, y0, x1, y1]`` truncated to int) and ``polygon`` (list of
            ``(x, y)`` int tuples, empty if the mask has no polygon);
          * ``"classes_in_image"``: distinct detected class names, in order of
            first detection.
    """
    class_names = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes
    if image is None:
        return {"all_class_options": class_names, "instances": [], "classes_in_image": []}

    # Normalize palette / grayscale / RGBA inputs to 3-channel RGB.
    rgb_image = image.convert("RGB")
    img_width, img_height = rgb_image.size
    # DefaultPredictor documents its input as an (H, W, C) array in BGR order,
    # while PIL yields RGB, so reverse the channel axis before inference.
    bgr_array = np.ascontiguousarray(np.asarray(rgb_image)[:, :, ::-1])
    # Move all result tensors to CPU so they can be converted to NumPy / Python
    # values even when the model runs on a GPU.
    predictions = predictor(bgr_array)["instances"].to("cpu")

    masks = [
        GenericMask(m, img_height, img_width)
        for m in np.asarray(predictions.pred_masks)
    ]

    instances = []
    for mask, box, score, class_index in zip(
        masks,
        predictions.pred_boxes.tensor.tolist(),
        predictions.scores.tolist(),
        predictions.pred_classes.tolist(),
    ):
        instances.append(
            {
                "confidence": score,
                "class_index": class_index,
                "class_name": class_names[class_index],
                "bounding_box": [int(coord) for coord in box],
                "polygon": _mask_polygon_points(mask),
            }
        )

    return {
        "all_class_options": class_names,
        "instances": instances,
        # dict.fromkeys de-duplicates while keeping a deterministic order.
        "classes_in_image": list(dict.fromkeys(inst["class_name"] for inst in instances)),
    }
# Text shown around the Gradio interface; description and article are rendered as HTML.
title = "VADE DETECTRON BABY"
# The model line is derived from MODEL_NAME so it cannot drift from the loaded config.
description = (
    "demo for Detectron2. To use it, simply upload your image, or click one of the "
    "examples to load them. Read more at the links below."
    f"<br><b>Model: {MODEL_NAME}</b>"
)
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2012.07177'>Simple Copy-Paste is a Strong Data Augmentation Method for Instance Segmentation</a> | <a href='https://github.com/facebookresearch/detectron2/blob/main/MODEL_ZOO.md'>Detectron model ZOO</a></p>"
import gradio as gr

# Web UI: one PIL image in, the JSON dict from infer_and_get_json out.
# Top-level gr.Image / gr.JSON replace the gr.inputs / gr.outputs namespaces,
# which were removed in Gradio 4.
gr.Interface(
    infer_and_get_json,
    [gr.Image(type="pil", label="Input")],
    gr.JSON(label="Output"),
    title=title,
    description=description,
    article=article,
    examples=[
        ["city1.jpg"],
    ],
).launch()