vadepark committed
Commit 1a7a852 · 1 Parent(s): 32df274

Update app.py

Files changed (1)
  1. app.py +82 -39
app.py CHANGED
@@ -1,72 +1,115 @@
-try:
-    import detectron2
-except:
-    import os
-
-    os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
-
-from matplotlib.pyplot import axis
-import gradio as gr
-import requests
+from PIL import Image, ImageDraw
 import numpy as np
-from torch import nn
 import requests
-
 import torch
-
 from detectron2 import model_zoo
 from detectron2.engine import DefaultPredictor
 from detectron2.config import get_cfg
-from detectron2.utils.visualizer import Visualizer
+from detectron2.utils.visualizer import Visualizer, GenericMask, _create_text_labels
 from detectron2.data import MetadataCatalog
 
-url1 = 'https://cdn.pixabay.com/photo/2014/09/07/21/52/city-438393_1280.jpg'
-r = requests.get(url1, allow_redirects=True)
-open("city1.jpg", 'wb').write(r.content)
-url2 = 'https://cdn.pixabay.com/photo/2016/02/19/11/36/canal-1209808_1280.jpg'
-r = requests.get(url2, allow_redirects=True)
-open("city2.jpg", 'wb').write(r.content)
+MODEL_NAME = 'COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml'
 
-model_name = 'COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml'
-
-# model = model_zoo.get(model_name, trained=True)
 
 cfg = get_cfg()
 # add project-specific config (e.g., TensorMask) here if you're not running a model in detectron2's core library
-cfg.merge_from_file(model_zoo.get_config_file(model_name))
+cfg.merge_from_file(model_zoo.get_config_file(MODEL_NAME))
 cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # set threshold for this model
 # Find a model from detectron2's model zoo. You can use the https://dl.fbaipublicfiles... url as well
-cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(model_name)
+cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(MODEL_NAME)
 
 if not torch.cuda.is_available():
     cfg.MODEL.DEVICE = 'cpu'
 
-predictor = DefaultPredictor(cfg)
-
 
-def inference(image):
-    img = np.array(image.resize((1024, 1024)))
-    outputs = predictor(img)
+example_image_url = "https://i.ibb.co/0QFxwjR/0a2e59fa-7990-43dc-b060-e8413468d113.jpg"
+r = requests.get(example_image_url, allow_redirects=True)
+open("city1.jpg", 'wb').write(r.content)
 
-    v = Visualizer(img, MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2)
-    out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
+predictor = DefaultPredictor(cfg)
 
-    return out.get_image()
 
+def infer_and_get_json(image: Image.Image) -> dict:
+    img_width, img_height = image.size
+    image_real_size = (img_height, img_width)
+    np_image = np.array(image)
+    outputs = predictor(np_image)
+
+    predictions = outputs["instances"]
+    boxes = predictions.pred_boxes if predictions.has("pred_boxes") else None
+    scores = predictions.scores if predictions.has("scores") else None
+    v = Visualizer(image, MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2)
+    class_names = v.metadata.thing_classes
+    classes = predictions.pred_classes.tolist() if predictions.has("pred_classes") else None
+    masks = np.asarray(predictions.pred_masks)
+    masks = [GenericMask(x, image_real_size[0], image_real_size[1]) for x in masks]
+
+    num_instances = len(predictions)
+    # labels = _create_text_labels(classes, scores, class_names)
+
+    # (optional) draw the predictions locally for debugging; kept commented out:
+    # pil_image = Image.open(img_path)
+    # img_draw = ImageDraw.Draw(pil_image)
+    classes_in_image_set = set()
+    return_dict = {
+        "all_class_options": class_names,
+        "instances": []
+    }
+
+    for i in range(num_instances):
+        # get a polygon for the instance (keeps the last polygon when the mask has several)
+        mask = masks[i]
+        all_polygons_for_instance = mask.polygons
+        polygon_to_draw_raw = None
+        for polygon_raw in all_polygons_for_instance:
+            polygon_to_draw_raw = polygon_raw
+
+        polygon_wrong_form = polygon_to_draw_raw.reshape(-1, 2).tolist()
+        polygon = []
+        for point in polygon_wrong_form:
+            polygon.append((int(point[0]), int(point[1])))
+
+        # get other info about the instance
+        score = scores[i].item()
+        class_index = classes[i]
+        instance_name = class_names[class_index]
+        # easy_label = labels[i]
+        box_raw = boxes[i]
+        box_list_float = box_raw.tensor.tolist()[0]
+        box_list_int = [int(x) for x in box_list_float]
+        classes_in_image_set.add(instance_name)
+        instance_dict = {
+            "confidence": score,
+            "class_index": class_index,
+            "class_name": instance_name,
+            "bounding_box": box_list_int,
+            "polygon": polygon
+        }
+        return_dict["instances"].append(instance_dict)
+        # img_draw.polygon(polygon, outline="blue")
+        # img_draw.rectangle(box_list_int, outline="red")
+        # top_left_box = (box_list_int[0], box_list_int[1] - 10)
+        # img_draw.text(top_left_box, easy_label)
+        # pil_image.show()
+    return_dict["classes_in_image"] = list(classes_in_image_set)
+    return return_dict
 
-title = "Detectron2-MaskRCNN X101"
+
+title = "VADE DETECTRON BABY"
 description = "demo for Detectron2. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below.\
 </br><b>Model: COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml</b>"
 article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2012.07177'>Simple Copy-Paste is a Strong Data Augmentation Method for Instance Segmentation</a> | <a href='https://github.com/facebookresearch/detectron2/blob/main/MODEL_ZOO.md'>Detectron model ZOO</a></p>"
 
+import gradio as gr
 gr.Interface(
-    inference,
+    infer_and_get_json,
     [gr.inputs.Image(type="pil", label="Input")],
-    gr.outputs.Image(type="numpy", label="Output"),
+    gr.outputs.JSON(label="Output"),
     title=title,
     description=description,
     article=article,
     examples=[
-        ["city1.jpg"],
-        ["city2.jpg"]
-    ]).launch()
+        ["city1.jpg"],
+    ]).launch()
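
For downstream consumers, a minimal sketch of the JSON that the new infer_and_get_json function returns, assuming only the code above: the field names come from the function itself, while every value here is invented purely for illustration.

# Hypothetical shape of the dict returned by infer_and_get_json (all values made up):
example_output = {
    "all_class_options": ["person", "bicycle", "car"],       # COCO thing_classes (80 names in full, truncated here)
    "instances": [
        {
            "confidence": 0.97,                               # detection score, at least SCORE_THRESH_TEST (0.5)
            "class_index": 2,                                 # index into all_class_options
            "class_name": "car",
            "bounding_box": [431, 210, 618, 335],             # absolute pixel [x0, y0, x1, y1], cast to int
            "polygon": [(431, 250), (440, 212), (610, 214)],  # mask outline as (x, y) points, truncated here
        },
    ],
    "classes_in_image": ["car"],                              # de-duplicated class names found in the image
}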