0jung committed · Commit 7af2a8c · 1 Parent(s): 76df5d9
Files changed (1)
  1. app.py +159 -154
app.py CHANGED
@@ -5,172 +5,174 @@ import matplotlib.pyplot as plt
 import numpy as np
 from PIL import Image
 import tensorflow as tf
-from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
-import requests
+from transformers import SegformerFeatureExtractor, TFSegformerForSemanticSegmentation
 
+import requests
 
 feature_extractor = SegformerFeatureExtractor.from_pretrained(
     "nvidia/segformer-b5-finetuned-ade-640-640"
 )
-model = SegformerForSemanticSegmentation.from_pretrained(
+model = TFSegformerForSemanticSegmentation.from_pretrained(
     "nvidia/segformer-b5-finetuned-ade-640-640"
 )
 
+
 def ade_palette():
     """ADE20K palette that maps each class to RGB values."""
     return [
-        [204, 87, 92],
-        [112, 185, 212],
-        [45, 189, 106],
-        [234, 123, 67],
-        [78, 56, 123],
-        [210, 32, 89],
-        [90, 180, 56],
-        [155, 102, 200],
-        [33, 147, 176],
-        [255, 183, 76],
-        [67, 123, 89],
-        [190, 60, 45],
-        [134, 112, 200],
-        [56, 45, 189],
-        [200, 56, 123],
-        [87, 92, 204],
-        [120, 56, 123],
-        [45, 78, 123],
-        [156, 200, 56],
-        [32, 90, 210],
-        [56, 123, 67],
-        [180, 56, 123],
-        [123, 67, 45],
-        [45, 134, 200],
-        [67, 56, 123],
-        [78, 123, 67],
-        [32, 210, 90],
-        [45, 56, 189],
-        [123, 56, 123],
-        [56, 156, 200],
-        [189, 56, 45],
-        [112, 200, 56],
-        [56, 123, 45],
-        [200, 32, 90],
-        [123, 45, 78],
-        [200, 156, 56],
-        [45, 67, 123],
-        [56, 45, 78],
-        [45, 56, 123],
-        [123, 67, 56],
-        [56, 78, 123],
-        [210, 90, 32],
-        [123, 56, 189],
-        [45, 200, 134],
-        [67, 123, 56],
-        [123, 45, 67],
-        [90, 32, 210],
-        [200, 45, 78],
-        [32, 210, 90],
-        [45, 123, 67],
-        [165, 42, 87],
-        [72, 145, 167],
-        [15, 158, 75],
-        [209, 89, 40],
-        [32, 21, 121],
-        [184, 20, 100],
-        [56, 135, 15],
-        [128, 92, 176],
-        [1, 119, 140],
-        [220, 151, 43],
-        [41, 97, 72],
-        [148, 38, 27],
-        [107, 86, 176],
-        [21, 26, 136],
-        [174, 27, 90],
-        [91, 96, 204],
-        [108, 50, 107],
-        [27, 45, 136],
-        [168, 200, 52],
-        [7, 102, 27],
-        [42, 93, 56],
-        [140, 52, 112],
-        [92, 107, 168],
-        [17, 118, 176],
-        [59, 50, 174],
-        [206, 40, 143],
-        [44, 19, 142],
-        [23, 168, 75],
-        [54, 57, 189],
-        [144, 21, 15],
-        [15, 176, 35],
-        [107, 19, 79],
-        [204, 52, 114],
-        [48, 173, 83],
-        [11, 120, 53],
-        [206, 104, 28],
-        [20, 31, 153],
-        [27, 21, 93],
-        [11, 206, 138],
-        [112, 30, 83],
-        [68, 91, 152],
-        [153, 13, 43],
-        [25, 114, 54],
-        [92, 27, 150],
-        [108, 42, 59],
-        [194, 77, 5],
-        [145, 48, 83],
-        [7, 113, 19],
-        [25, 92, 113],
-        [60, 168, 79],
-        [78, 33, 120],
-        [89, 176, 205],
-        [27, 200, 94],
-        [210, 67, 23],
-        [123, 89, 189],
-        [225, 56, 112],
-        [75, 156, 45],
-        [172, 104, 200],
-        [15, 170, 197],
-        [240, 133, 65],
-        [89, 156, 112],
-        [214, 88, 57],
-        [156, 134, 200],
-        [78, 57, 189],
-        [200, 78, 123],
-        [106, 120, 210],
-        [145, 56, 112],
-        [89, 120, 189],
-        [185, 206, 56],
-        [47, 99, 28],
-        [112, 189, 78],
-        [200, 112, 89],
-        [89, 145, 112],
-        [78, 106, 189],
-        [112, 78, 189],
-        [156, 112, 78],
-        [28, 210, 99],
-        [78, 89, 189],
-        [189, 78, 57],
-        [112, 200, 78],
-        [189, 47, 78],
-        [205, 112, 57],
-        [78, 145, 57],
-        [200, 78, 112],
-        [99, 89, 145],
-        [200, 156, 78],
-        [57, 78, 145],
-        [78, 57, 99],
-        [57, 78, 145],
-        [145, 112, 78],
-        [78, 89, 145],
-        [210, 99, 28],
-        [145, 78, 189],
-        [57, 200, 136],
-        [89, 156, 78],
-        [145, 78, 99],
-        [99, 28, 210],
-        [189, 78, 47],
-        [28, 210, 99],
-        [141, 21, 43],
+        [215, 252, 54],
+        [219, 99, 20],
+        [30, 125, 246],
+        [21, 211, 22],
+        [117, 165, 201],
+        [122, 2, 6],
+        [52, 144, 140],
+        [136, 36, 114],
+        [208, 249, 44],
+        [210, 245, 157],
+        [48, 222, 84],
+        [175, 182, 112],
+        [117, 9, 240],
+        [153, 38, 30],
+        [75, 225, 231],
+        [232, 170, 70],
+        [154, 35, 115],
+        [45, 61, 35],
+        [73, 144, 2],
+        [54, 80, 136],
+        [143, 200, 212],
+        [75, 104, 98],
+        [17, 211, 27],
+        [205, 195, 241],
+        [234, 251, 104],
+        [33, 174, 95],
+        [160, 174, 99],
+        [141, 26, 157],
+        [84, 247, 88],
+        [19, 248, 198],
+        [4, 217, 155],
+        [204, 163, 16],
+        [148, 209, 143],
+        [211, 97, 65],
+        [19, 4, 131],
+        [40, 196, 45],
+        [39, 64, 20],
+        [166, 107, 50],
+        [108, 103, 78],
+        [188, 11, 213],
+        [24, 156, 152],
+        [230, 162, 223],
+        [30, 126, 220],
+        [74, 10, 238],
+        [186, 128, 227],
+        [83, 188, 220],
+        [9, 132, 231],
+        [96, 99, 79],
+        [196, 139, 187],
+        [117, 122, 171],
+        [0, 156, 220],
+        [243, 249, 189],
+        [243, 245, 211],
+        [103, 146, 83],
+        [237, 144, 197],
+        [35, 151, 20],
+        [15, 61, 139],
+        [78, 223, 132],
+        [120, 49, 9],
+        [67, 160, 234],
+        [183, 244, 210],
+        [245, 161, 139],
+        [57, 70, 189],
+        [105, 150, 31],
+        [219, 85, 49],
+        [206, 81, 97],
+        [30, 171, 92],
+        [251, 42, 67],
+        [121, 183, 220],
+        [221, 33, 43],
+        [8, 96, 100],
+        [76, 149, 53],
+        [29, 201, 129],
+        [7, 213, 227],
+        [143, 93, 153],
+        [205, 35, 110],
+        [37, 94, 142],
+        [131, 157, 110],
+        [215, 166, 147],
+        [164, 94, 252],
+        [179, 108, 233],
+        [35, 157, 209],
+        [145, 252, 241],
+        [155, 60, 40],
+        [70, 25, 44],
+        [53, 83, 133],
+        [150, 42, 191],
+        [142, 245, 58],
+        [150, 198, 69],
+        [0, 139, 86],
+        [123, 212, 143],
+        [210, 166, 191],
+        [148, 194, 130],
+        [35, 213, 154],
+        [203, 139, 93],
+        [59, 86, 45],
+        [9, 50, 169],
+        [207, 118, 246],
+        [200, 82, 65],
+        [37, 75, 120],
+        [237, 99, 63],
+        [168, 145, 190],
+        [225, 48, 16],
+        [17, 184, 115],
+        [224, 124, 15],
+        [148, 167, 47],
+        [162, 25, 116],
+        [154, 90, 36],
+        [185, 247, 43],
+        [183, 138, 202],
+        [64, 96, 117],
+        [187, 140, 140],
+        [121, 116, 188],
+        [252, 251, 162],
+        [85, 50, 40],
+        [209, 241, 228],
+        [30, 41, 95],
+        [246, 217, 64],
+        [151, 149, 197],
+        [117, 42, 205],
+        [26, 248, 30],
+        [28, 224, 232],
+        [228, 89, 96],
+        [198, 44, 113],
+        [220, 68, 218],
+        [59, 85, 210],
+        [24, 230, 191],
+        [145, 192, 181],
+        [132, 189, 92],
+        [47, 29, 128],
+        [11, 245, 204],
+        [182, 79, 207],
+        [42, 64, 187],
+        [72, 4, 37],
+        [105, 67, 133],
+        [86, 27, 200],
+        [243, 211, 40],
+        [150, 136, 40],
+        [3, 192, 172],
+        [34, 96, 149],
+        [32, 108, 56],
+        [128, 10, 137],
+        [94, 211, 108],
+        [78, 250, 243],
+        [6, 74, 205],
+        [6, 7, 38],
+        [161, 26, 40],
+        [145, 254, 27],
+        [119, 145, 127],
+        [13, 82, 153],
     ]
 
+
 labels_list = []
 
 with open(r'labels.txt', 'r') as fp:
@@ -179,6 +181,7 @@ with open(r'labels.txt', 'r') as fp:
 
 colormap = np.asarray(ade_palette())
 
+
 def label_to_color_image(label):
     if label.ndim != 2:
         raise ValueError("Expect 2-D input label")
@@ -187,6 +190,7 @@ def label_to_color_image(label):
         raise ValueError("label value too large.")
     return colormap[label]
 
+
 def draw_plot(pred_img, seg):
     fig = plt.figure(figsize=(20, 15))
 
@@ -208,6 +212,7 @@ def draw_plot(pred_img, seg):
     ax.tick_params(width=0.0, labelsize=25)
     return fig
 
+
 def sepia(input_img):
     input_img = Image.fromarray(input_img)
 
@@ -234,11 +239,11 @@ def sepia(input_img):
     fig = draw_plot(pred_img, seg)
     return fig
 
+
 demo = gr.Interface(fn=sepia,
                     inputs=gr.Image(shape=(400, 600)),
                     outputs=['plot'],
                     examples=["image-1.jpg", "image-2.jpg", "image-3.jpg", "image-4.jpeg", "image-5.jpg"],
                     allow_flagging='never')
 
-
 demo.launch()
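The visible hunks only show the model swap from `SegformerForSemanticSegmentation` to `TFSegformerForSemanticSegmentation` plus the new palette; the body of `sepia` is mostly outside the diff context. Below is a minimal sketch of how the TensorFlow variant is typically driven in this kind of app, assuming the standard transformers TF SegFormer flow (the `segment` function name and local variable names are illustrative, not taken from the commit):

import numpy as np
import tensorflow as tf
from PIL import Image
from transformers import SegformerFeatureExtractor, TFSegformerForSemanticSegmentation

# Same checkpoint as the commit; classes mirror the ones the diff switches to.
feature_extractor = SegformerFeatureExtractor.from_pretrained(
    "nvidia/segformer-b5-finetuned-ade-640-640"
)
model = TFSegformerForSemanticSegmentation.from_pretrained(
    "nvidia/segformer-b5-finetuned-ade-640-640"
)

def segment(img: Image.Image) -> np.ndarray:
    """Run SegFormer and return an (H, W) array of ADE20K class ids (hypothetical helper)."""
    inputs = feature_extractor(images=img, return_tensors="tf")
    outputs = model(**inputs)
    # logits: (batch, num_labels, H/4, W/4) -> channels-last for tf.image.resize
    logits = tf.transpose(outputs.logits, [0, 2, 3, 1])
    logits = tf.image.resize(logits, img.size[::-1])  # PIL size is (W, H)
    seg = tf.math.argmax(logits, axis=-1)[0]
    return seg.numpy()

The resulting class-id map can then be colored via `colormap[seg]` (the `ade_palette()` array defined above) and blended with the input image, which is what `draw_plot` visualizes in the app.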