akhaliq3 committed
Commit · 4a7bfa8
1 Parent(s): 506da10
app file
app.py ADDED
@@ -0,0 +1,224 @@
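# Gradio demo app: DeepLab2 panoptic segmentation on Cityscapes, served from
# an exported TensorFlow SavedModel.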
import collections
import os
import tempfile
from matplotlib import gridspec
from matplotlib import pyplot as plt
import numpy as np
from PIL import Image
import urllib.request  # imported explicitly so urllib.request.urlretrieve below resolves
import tensorflow as tf
import gradio as gr
from subprocess import call
import sys
import requests

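# Fetch two example street-scene images to offer as clickable examples in the
# Gradio interface.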
url1 = 'https://cdn.pixabay.com/photo/2014/09/07/21/52/city-438393_1280.jpg'
r = requests.get(url1, allow_redirects=True)
open("city1.jpg", 'wb').write(r.content)
url2 = 'https://cdn.pixabay.com/photo/2016/02/19/11/36/canal-1209808_1280.jpg'
r = requests.get(url2, allow_redirects=True)
open("city2.jpg", 'wb').write(r.content)

DatasetInfo = collections.namedtuple(
    'DatasetInfo',
    'num_classes, label_divisor, thing_list, colormap, class_names')

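# `label_divisor` packs each panoptic label as
# semantic_id * label_divisor + instance_id, so one integer per pixel encodes
# both the class and the instance.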
def _cityscapes_label_colormap():
  """Creates a label colormap used in CITYSCAPES segmentation benchmark.

  See more about the CITYSCAPES dataset at https://www.cityscapes-dataset.com/
  M. Cordts, et al. "The Cityscapes Dataset for Semantic Urban Scene
  Understanding." CVPR. 2016.

  Returns:
    A 2-D numpy array with each row being a mapped RGB color (in uint8 range).
  """
  colormap = np.zeros((256, 3), dtype=np.uint8)
  colormap[0] = [128, 64, 128]
  colormap[1] = [244, 35, 232]
  colormap[2] = [70, 70, 70]
  colormap[3] = [102, 102, 156]
  colormap[4] = [190, 153, 153]
  colormap[5] = [153, 153, 153]
  colormap[6] = [250, 170, 30]
  colormap[7] = [220, 220, 0]
  colormap[8] = [107, 142, 35]
  colormap[9] = [152, 251, 152]
  colormap[10] = [70, 130, 180]
  colormap[11] = [220, 20, 60]
  colormap[12] = [255, 0, 0]
  colormap[13] = [0, 0, 142]
  colormap[14] = [0, 0, 70]
  colormap[15] = [0, 60, 100]
  colormap[16] = [0, 80, 100]
  colormap[17] = [0, 0, 230]
  colormap[18] = [119, 11, 32]
  return colormap

def _cityscapes_class_names():
  return ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
          'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
          'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
          'bicycle')

def cityscapes_dataset_information():
  return DatasetInfo(
      num_classes=19,
      label_divisor=1000,
      thing_list=tuple(range(11, 19)),
      colormap=_cityscapes_label_colormap(),
      class_names=_cityscapes_class_names())

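# Classes 11-18 ('person' through 'bicycle') are the countable "thing"
# classes; each instance of these gets a slightly perturbed shade of its
# class color so adjacent instances remain distinguishable.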
def perturb_color(color, noise, used_colors, max_trials=50, random_state=None):
  """Perturbs the color with some noise.

  If `used_colors` is not None, we will return a color that has not appeared
  in it before.

  Args:
    color: A numpy array with three elements [R, G, B].
    noise: Integer, specifying the amount of perturbing noise (in uint8 range).
    used_colors: A set, used to keep track of used colors.
    max_trials: An integer, maximum trials to generate random color.
    random_state: An optional np.random.RandomState. If passed, will be used
      to generate random numbers.

  Returns:
    A perturbed color that has not appeared in used_colors.
  """
  if random_state is None:
    random_state = np.random
  for _ in range(max_trials):
    random_color = color + random_state.randint(
        low=-noise, high=noise + 1, size=3)
    random_color = np.clip(random_color, 0, 255)
    if tuple(random_color) not in used_colors:
      used_colors.add(tuple(random_color))
      return random_color
  print('Max trials reached and a duplicate color will be used. Please '
        'consider increasing noise in `perturb_color()`.')
  return random_color

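# Worked example of the panoptic encoding: with label_divisor=1000, label
# 13007 decodes to semantic class 13007 // 1000 = 13 ('car') and instance id
# 13007 % 1000 = 7.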
def color_panoptic_map(panoptic_prediction, dataset_info, perturb_noise):
  """Helper method to colorize output panoptic map.

  Args:
    panoptic_prediction: A 2D numpy array, panoptic prediction from deeplab
      model.
    dataset_info: A DatasetInfo object, dataset associated to the model.
    perturb_noise: Integer, the amount of noise (in uint8 range) added to each
      instance of the same semantic class.

  Returns:
    colored_panoptic_map: A 3D numpy array with last dimension of 3, colored
      panoptic prediction map.
    used_colors: A dictionary mapping semantic_ids to a set of colors used
      in `colored_panoptic_map`.
  """
  if panoptic_prediction.ndim != 2:
    raise ValueError('Expect 2-D panoptic prediction. Got {}'.format(
        panoptic_prediction.shape))
  semantic_map = panoptic_prediction // dataset_info.label_divisor
  instance_map = panoptic_prediction % dataset_info.label_divisor
  height, width = panoptic_prediction.shape
  colored_panoptic_map = np.zeros((height, width, 3), dtype=np.uint8)
  used_colors = collections.defaultdict(set)
  # Use a fixed seed to reproduce the same visualization.
  random_state = np.random.RandomState(0)
  unique_semantic_ids = np.unique(semantic_map)
  for semantic_id in unique_semantic_ids:
    semantic_mask = semantic_map == semantic_id
    if semantic_id in dataset_info.thing_list:
      # For `thing` class, we will add a small amount of random noise to its
      # correspondingly predefined semantic segmentation colormap.
      unique_instance_ids = np.unique(instance_map[semantic_mask])
      for instance_id in unique_instance_ids:
        instance_mask = np.logical_and(semantic_mask,
                                       instance_map == instance_id)
        random_color = perturb_color(
            dataset_info.colormap[semantic_id],
            perturb_noise,
            used_colors[semantic_id],
            random_state=random_state)
        colored_panoptic_map[instance_mask] = random_color
    else:
      # For `stuff` class, we use the defined semantic color.
      colored_panoptic_map[semantic_mask] = dataset_info.colormap[semantic_id]
      used_colors[semantic_id].add(tuple(dataset_info.colormap[semantic_id]))
  return colored_panoptic_map, used_colors

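# The visualization is a 2x2 matplotlib grid: input image, colored panoptic
# map, the overlay of the two, and a legend whose rows list every instance
# color used per class.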
def vis_segmentation(image,
                     panoptic_prediction,
                     dataset_info,
                     perturb_noise=60):
  """Visualizes input image, segmentation map and overlay view."""
  plt.figure(figsize=(30, 20))
  grid_spec = gridspec.GridSpec(2, 2)

  ax = plt.subplot(grid_spec[0])
  plt.imshow(image)
  plt.axis('off')
  ax.set_title('input image', fontsize=20)

  ax = plt.subplot(grid_spec[1])
  panoptic_map, used_colors = color_panoptic_map(panoptic_prediction,
                                                 dataset_info, perturb_noise)
  plt.imshow(panoptic_map)
  plt.axis('off')
  ax.set_title('panoptic map', fontsize=20)

  ax = plt.subplot(grid_spec[2])
  plt.imshow(image)
  plt.imshow(panoptic_map, alpha=0.7)
  plt.axis('off')
  ax.set_title('panoptic overlay', fontsize=20)

  ax = plt.subplot(grid_spec[3])
  max_num_instances = max(len(color) for color in used_colors.values())
  # RGBA image as legend.
  legend = np.zeros((len(used_colors), max_num_instances, 4), dtype=np.uint8)
  class_names = []
  for i, semantic_id in enumerate(sorted(used_colors)):
    legend[i, :len(used_colors[semantic_id]), :3] = np.array(
        list(used_colors[semantic_id]))
    legend[i, :len(used_colors[semantic_id]), 3] = 255
    if semantic_id < dataset_info.num_classes:
      class_names.append(dataset_info.class_names[semantic_id])
    else:
      class_names.append('ignore')
  plt.imshow(legend, interpolation='nearest')
  ax.yaxis.tick_left()
  plt.yticks(range(len(legend)), class_names, fontsize=15)
  plt.xticks([], [])
  ax.tick_params(width=0.0, grid_linewidth=0.0)
  plt.grid('off')
  return plt

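# Shell helper used below to extract the downloaded model tarball.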
def run_cmd(command):
  try:
    print(command)
    call(command, shell=True)
  except KeyboardInterrupt:
    print("Process interrupted")
    sys.exit(1)

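# The DeepLab2 Cityscapes checkpoints published as SavedModels; only
# MODEL_NAME is downloaded and served. Swapping in any other entry from
# _MODELS should work the same way, since they share the download URL pattern
# and dataset metadata.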
MODEL_NAME = 'resnet50_os32_panoptic_deeplab_cityscapes_crowd_trainfine_saved_model'

_MODELS = ('resnet50_os32_panoptic_deeplab_cityscapes_crowd_trainfine_saved_model',
           'resnet50_beta_os32_panoptic_deeplab_cityscapes_trainfine_saved_model',
           'wide_resnet41_os16_panoptic_deeplab_cityscapes_trainfine_saved_model',
           'swidernet_sac_1_1_1_os16_panoptic_deeplab_cityscapes_trainfine_saved_model',
           'swidernet_sac_1_1_3_os16_panoptic_deeplab_cityscapes_trainfine_saved_model',
           'swidernet_sac_1_1_4.5_os16_panoptic_deeplab_cityscapes_trainfine_saved_model',
           'axial_swidernet_1_1_1_os16_axial_deeplab_cityscapes_trainfine_saved_model',
           'axial_swidernet_1_1_3_os16_axial_deeplab_cityscapes_trainfine_saved_model',
           'axial_swidernet_1_1_4.5_os16_axial_deeplab_cityscapes_trainfine_saved_model',
           'max_deeplab_s_backbone_os16_axial_deeplab_cityscapes_trainfine_saved_model',
           'max_deeplab_l_backbone_os16_axial_deeplab_cityscapes_trainfine_saved_model')

_DOWNLOAD_URL_PATTERN = 'https://storage.googleapis.com/gresearch/tf-deeplab/saved_model/%s.tar.gz'

_MODEL_NAME_TO_URL_AND_DATASET = {
    model: (_DOWNLOAD_URL_PATTERN % model, cityscapes_dataset_information())
    for model in _MODELS
}

MODEL_URL, DATASET_INFO = _MODEL_NAME_TO_URL_AND_DATASET[MODEL_NAME]

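# Download the chosen model tarball into a temporary directory, extract it,
# and load the TensorFlow SavedModel it contains.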
model_dir = tempfile.mkdtemp()
download_path = os.path.join(model_dir, MODEL_NAME + '.gz')
urllib.request.urlretrieve(MODEL_URL, download_path)
run_cmd("tar -xzvf " + download_path + " -C " + model_dir)
LOADED_MODEL = tf.saved_model.load(os.path.join(model_dir, MODEL_NAME))

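# The loaded SavedModel is called directly on a uint8 RGB image tensor and
# returns a dict of outputs; this demo uses only output['panoptic_pred'],
# whose leading batch dimension is stripped with [0].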
def inference(image):
  image = image.resize(size=(512, 512))
  im = np.array(image)
  output = LOADED_MODEL(tf.cast(im, tf.uint8))
  return vis_segmentation(im, output['panoptic_pred'][0], DATASET_INFO)

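# Note: gr.inputs.Image / gr.outputs.Image and the "plot" output type are the
# Gradio 2.x API this commit targets; newer Gradio versions expose gr.Image
# and gr.Plot instead.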
title = "Deeplab2"
description = "Demo for DeepLab2. To use it, simply upload your image, or click one of the examples to load it. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2106.09748'>DeepLab2: A TensorFlow Library for Deep Labeling</a> | <a href='https://github.com/google-research/deeplab2'>Github Repo</a></p>"

gr.Interface(
    inference,
    [gr.inputs.Image(type="pil", label="Input")],
    gr.outputs.Image(type="plot", label="Output"),
    title=title,
    description=description,
    article=article,
    examples=[
        ["city1.jpg"],
        ["city2.jpg"]
    ]).launch()