chongzhou commited on
Commit
4c7dc89
·
1 Parent(s): 3344f75

solve the concurrent visiting problem with gradio session

Browse files
Files changed (1) hide show
  1. app.py +105 -126
app.py CHANGED
@@ -100,63 +100,41 @@ description_b = """ # Instructions for box mode
100
 
101
  css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
102
 
103
- global_points = []
104
- global_point_label = []
105
- global_box = []
106
- global_image = None
107
- global_image_with_prompt = None
108
-
109
-
110
- def reset():
111
- global global_points
112
- global global_point_label
113
- global global_box
114
- global global_image
115
- global global_image_with_prompt
116
- global_points = []
117
- global_point_label = []
118
- global_box = []
119
- global_image = None
120
- global_image_with_prompt = None
121
- return None
122
-
123
-
124
- def reset_all():
125
- global global_points
126
- global global_point_label
127
- global global_box
128
- global global_image
129
- global global_image_with_prompt
130
- global_points = []
131
- global_point_label = []
132
- global_box = []
133
- global_image = None
134
- global_image_with_prompt = None
135
- return None, None
136
-
137
-
138
- def clear():
139
- global global_points
140
- global global_point_label
141
- global global_box
142
- global global_image
143
- global global_image_with_prompt
144
- global_points = []
145
- global_point_label = []
146
- global_box = []
147
- global_image_with_prompt = copy.deepcopy(global_image)
148
- return global_image
149
-
150
-
151
- def on_image_upload(image, input_size=1024):
152
- global global_points
153
- global global_point_label
154
- global global_box
155
- global global_image
156
- global global_image_with_prompt
157
- global_points = []
158
- global_point_label = []
159
- global_box = []
160
 
161
  input_size = int(input_size)
162
  w, h = image.size
@@ -164,13 +142,13 @@ def on_image_upload(image, input_size=1024):
164
  new_w = int(w * scale)
165
  new_h = int(h * scale)
166
  image = image.resize((new_w, new_h))
167
- global_image = copy.deepcopy(image)
168
- global_image_with_prompt = copy.deepcopy(image)
169
  print("Image changed")
170
  # nd_image = np.array(global_image)
171
  # predictor.set_image(nd_image)
172
 
173
- return image
174
 
175
 
176
  def convert_box(xyxy):
@@ -186,52 +164,48 @@ def convert_box(xyxy):
186
 
187
 
188
  def segment_with_points(
189
- label,
190
- evt: gr.SelectData,
191
- input_size=1024,
192
- better_quality=False,
193
- withContours=True,
194
- use_retina=True,
195
- mask_random_color=False,
 
196
  ):
197
- global global_points
198
- global global_point_label
199
- global global_image
200
- global global_image_with_prompt
201
-
202
  x, y = evt.index[0], evt.index[1]
203
  point_radius, point_color = 5, (97, 217, 54) if label == "Positive" else (237, 34, 13)
204
- global_points.append([x, y])
205
- global_point_label.append(1 if label == "Positive" else 0)
206
 
207
- print(f'global_points: {global_points}')
208
- print(f'global_point_label: {global_point_label}')
209
 
210
- draw = ImageDraw.Draw(global_image_with_prompt)
211
  draw.ellipse(
212
  [(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)],
213
  fill=point_color,
214
  )
215
- image = global_image_with_prompt
216
 
217
- nd_image = np.array(global_image)
218
  predictor.set_image(nd_image)
219
 
220
  if ENABLE_ONNX:
221
- global_points_np = np.array(global_points)[None]
222
- global_point_label_np = np.array(global_point_label)[None]
223
  masks, scores, _ = predictor.predict(
224
- point_coords=global_points_np,
225
- point_labels=global_point_label_np,
226
  )
227
  masks = masks.squeeze(0)
228
  scores = scores.squeeze(0)
229
  else:
230
- global_points_np = np.array(global_points)
231
- global_point_label_np = np.array(global_point_label)
232
  masks, scores, logits = predictor.predict(
233
- point_coords=global_points_np,
234
- point_labels=global_point_label_np,
235
  num_multimask_outputs=4,
236
  use_stability_score=True
237
  )
@@ -254,10 +228,11 @@ def segment_with_points(
254
  withContours=withContours,
255
  )
256
 
257
- return seg
258
 
259
 
260
  def segment_with_box(
 
261
  evt: gr.SelectData,
262
  input_size=1024,
263
  better_quality=False,
@@ -265,46 +240,41 @@ def segment_with_box(
265
  use_retina=True,
266
  mask_random_color=False,
267
  ):
268
- global global_box
269
- global global_image
270
- global global_image
271
- global global_image_with_prompt
272
-
273
  x, y = evt.index[0], evt.index[1]
274
  point_radius, point_color, box_outline = 5, (97, 217, 54), 5
275
  box_color = (0, 255, 0)
276
 
277
- if len(global_box) == 0:
278
- global_box.append([x, y])
279
- elif len(global_box) == 1:
280
- global_box.append([x, y])
281
- elif len(global_box) == 2:
282
- global_image_with_prompt = copy.deepcopy(global_image)
283
- global_box = [[x, y]]
284
 
285
- print(f'global_box: {global_box}')
286
 
287
- draw = ImageDraw.Draw(global_image_with_prompt)
288
  draw.ellipse(
289
  [(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)],
290
  fill=point_color,
291
  )
292
- image = global_image_with_prompt
293
 
294
- if len(global_box) == 2:
295
- global_box = convert_box(global_box)
296
- xy = (global_box[0][0], global_box[0][1], global_box[1][0], global_box[1][1])
297
  draw.rectangle(
298
  xy,
299
  outline=box_color,
300
  width=box_outline
301
  )
302
 
303
- global_box_np = np.array(global_box)
304
- nd_image = np.array(global_image)
305
  predictor.set_image(nd_image)
306
  if ENABLE_ONNX:
307
- point_coords = global_box_np.reshape(2, 2)[None]
308
  point_labels = np.array([2, 3])[None]
309
  masks, _, _ = predictor.predict(
310
  point_coords=point_coords,
@@ -313,7 +283,7 @@ def segment_with_box(
313
  annotations = masks[:, 0, :, :]
314
  else:
315
  masks, scores, _ = predictor.predict(
316
- box=global_box_np,
317
  num_multimask_outputs=1,
318
  )
319
  annotations = masks
@@ -329,13 +299,22 @@ def segment_with_box(
329
  use_retina=use_retina,
330
  withContours=withContours,
331
  )
332
- return seg
333
- return image
 
334
 
335
  img_p = gr.Image(label="Input with points", type="pil")
336
  img_b = gr.Image(label="Input with box", type="pil")
337
 
338
  with gr.Blocks(css=css, title="EdgeSAM") as demo:
 
 
 
 
 
 
 
 
339
  with gr.Row():
340
  with gr.Column(scale=1):
341
  # Title
@@ -365,8 +344,8 @@ with gr.Blocks(css=css, title="EdgeSAM") as demo:
365
  gr.Markdown("Try some of the examples below ⬇️")
366
  gr.Examples(
367
  examples=examples,
368
- inputs=[img_p],
369
- outputs=[img_p],
370
  examples_per_page=8,
371
  fn=on_image_upload,
372
  run_on_click=True
@@ -388,8 +367,8 @@ with gr.Blocks(css=css, title="EdgeSAM") as demo:
388
  gr.Markdown("Try some of the examples below ⬇️")
389
  gr.Examples(
390
  examples=examples,
391
- inputs=[img_b],
392
- outputs=[img_b],
393
  examples_per_page=8,
394
  fn=on_image_upload,
395
  run_on_click=True
@@ -400,19 +379,19 @@ with gr.Blocks(css=css, title="EdgeSAM") as demo:
400
  gr.Markdown(
401
  "<center><img src='https://visitor-badge.laobi.icu/badge?page_id=chongzhou/edgesam' alt='visitors'></center>")
402
 
403
- img_p.upload(on_image_upload, img_p, [img_p])
404
- img_p.select(segment_with_points, [add_or_remove], img_p)
405
 
406
- clear_btn_p.click(clear, outputs=[img_p])
407
- reset_btn_p.click(reset, outputs=[img_p])
408
- tab_p.select(fn=reset_all, outputs=[img_p, img_b])
409
 
410
- img_b.upload(on_image_upload, img_b, [img_b])
411
- img_b.select(segment_with_box, outputs=[img_b])
412
 
413
- clear_btn_b.click(clear, outputs=[img_b])
414
- reset_btn_b.click(reset, outputs=[img_b])
415
- tab_b.select(fn=reset_all, outputs=[img_p, img_b])
416
 
417
  demo.queue()
418
  # demo.launch(server_name=args.server_name, server_port=args.port)
 
100
 
101
  css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
102
 
103
+
104
+ def reset(session_state):
105
+ session_state['coord_list'] = []
106
+ session_state['label_list'] = []
107
+ session_state['box_list'] = []
108
+ session_state['ori_image'] = None
109
+ session_state['image_with_prompt'] = None
110
+ return None, session_state
111
+
112
+
113
+ def reset_all(session_state):
114
+ session_state['coord_list'] = []
115
+ session_state['label_list'] = []
116
+ session_state['box_list'] = []
117
+ session_state['ori_image'] = None
118
+ session_state['image_with_prompt'] = None
119
+ return None, None, session_state
120
+
121
+
122
+ def clear(session_state):
123
+ session_state['coord_list'] = []
124
+ session_state['label_list'] = []
125
+ session_state['box_list'] = []
126
+ session_state['image_with_prompt'] = copy.deepcopy(session_state['ori_image'])
127
+ return session_state['ori_image'], session_state
128
+
129
+
130
+ def on_image_upload(
131
+ image,
132
+ session_state,
133
+ input_size=1024
134
+ ):
135
+ session_state['coord_list'] = []
136
+ session_state['label_list'] = []
137
+ session_state['box_list'] = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
  input_size = int(input_size)
140
  w, h = image.size
 
142
  new_w = int(w * scale)
143
  new_h = int(h * scale)
144
  image = image.resize((new_w, new_h))
145
+ session_state['ori_image'] = copy.deepcopy(image)
146
+ session_state['image_with_prompt'] = copy.deepcopy(image)
147
  print("Image changed")
148
  # nd_image = np.array(global_image)
149
  # predictor.set_image(nd_image)
150
 
151
+ return image, session_state
152
 
153
 
154
  def convert_box(xyxy):
 
164
 
165
 
166
  def segment_with_points(
167
+ label,
168
+ session_state,
169
+ evt: gr.SelectData,
170
+ input_size=1024,
171
+ better_quality=False,
172
+ withContours=True,
173
+ use_retina=True,
174
+ mask_random_color=False,
175
  ):
 
 
 
 
 
176
  x, y = evt.index[0], evt.index[1]
177
  point_radius, point_color = 5, (97, 217, 54) if label == "Positive" else (237, 34, 13)
178
+ session_state['coord_list'].append([x, y])
179
+ session_state['label_list'].append(1 if label == "Positive" else 0)
180
 
181
+ print(f"coord_list: {session_state['coord_list']}")
182
+ print(f"label_list: {session_state['label_list']}")
183
 
184
+ draw = ImageDraw.Draw(session_state['image_with_prompt'])
185
  draw.ellipse(
186
  [(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)],
187
  fill=point_color,
188
  )
189
+ image = session_state['image_with_prompt']
190
 
191
+ nd_image = np.array(session_state['ori_image'])
192
  predictor.set_image(nd_image)
193
 
194
  if ENABLE_ONNX:
195
+ coord_np = np.array(session_state['coord_list'])[None]
196
+ label_np = np.array(session_state['label_list'])[None]
197
  masks, scores, _ = predictor.predict(
198
+ point_coords=coord_np,
199
+ point_labels=label_np,
200
  )
201
  masks = masks.squeeze(0)
202
  scores = scores.squeeze(0)
203
  else:
204
+ coord_np = np.array(session_state['coord_list'])
205
+ label_np = np.array(session_state['label_list'])
206
  masks, scores, logits = predictor.predict(
207
+ point_coords=coord_np,
208
+ point_labels=label_np,
209
  num_multimask_outputs=4,
210
  use_stability_score=True
211
  )
 
228
  withContours=withContours,
229
  )
230
 
231
+ return seg, session_state
232
 
233
 
234
  def segment_with_box(
235
+ session_state,
236
  evt: gr.SelectData,
237
  input_size=1024,
238
  better_quality=False,
 
240
  use_retina=True,
241
  mask_random_color=False,
242
  ):
 
 
 
 
 
243
  x, y = evt.index[0], evt.index[1]
244
  point_radius, point_color, box_outline = 5, (97, 217, 54), 5
245
  box_color = (0, 255, 0)
246
 
247
+ if len(session_state['box_list']) == 0:
248
+ session_state['box_list'].append([x, y])
249
+ elif len(session_state['box_list']) == 1:
250
+ session_state['box_list'].append([x, y])
251
+ elif len(session_state['box_list']) == 2:
252
+ session_state['image_with_prompt'] = copy.deepcopy(session_state['ori_image'])
253
+ session_state['box_list'] = [[x, y]]
254
 
255
+ print(f"box_list: {session_state['box_list']}")
256
 
257
+ draw = ImageDraw.Draw(session_state['image_with_prompt'])
258
  draw.ellipse(
259
  [(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)],
260
  fill=point_color,
261
  )
262
+ image = session_state['image_with_prompt']
263
 
264
+ if len(session_state['box_list']) == 2:
265
+ box = convert_box(session_state['box_list'])
266
+ xy = (box[0][0], box[0][1], box[1][0], box[1][1])
267
  draw.rectangle(
268
  xy,
269
  outline=box_color,
270
  width=box_outline
271
  )
272
 
273
+ box_np = np.array(box)
274
+ nd_image = np.array(session_state['ori_image'])
275
  predictor.set_image(nd_image)
276
  if ENABLE_ONNX:
277
+ point_coords = box_np.reshape(2, 2)[None]
278
  point_labels = np.array([2, 3])[None]
279
  masks, _, _ = predictor.predict(
280
  point_coords=point_coords,
 
283
  annotations = masks[:, 0, :, :]
284
  else:
285
  masks, scores, _ = predictor.predict(
286
+ box=box_np,
287
  num_multimask_outputs=1,
288
  )
289
  annotations = masks
 
299
  use_retina=use_retina,
300
  withContours=withContours,
301
  )
302
+ return seg, session_state
303
+ return image, session_state
304
+
305
 
306
  img_p = gr.Image(label="Input with points", type="pil")
307
  img_b = gr.Image(label="Input with box", type="pil")
308
 
309
  with gr.Blocks(css=css, title="EdgeSAM") as demo:
310
+ session_state = gr.State({
311
+ 'coord_list': [],
312
+ 'label_list': [],
313
+ 'box_list': [],
314
+ 'ori_image': None,
315
+ 'image_with_prompt': None
316
+ })
317
+
318
  with gr.Row():
319
  with gr.Column(scale=1):
320
  # Title
 
344
  gr.Markdown("Try some of the examples below ⬇️")
345
  gr.Examples(
346
  examples=examples,
347
+ inputs=[img_p, session_state],
348
+ outputs=[img_p, session_state],
349
  examples_per_page=8,
350
  fn=on_image_upload,
351
  run_on_click=True
 
367
  gr.Markdown("Try some of the examples below ⬇️")
368
  gr.Examples(
369
  examples=examples,
370
+ inputs=[img_b, session_state],
371
+ outputs=[img_b, session_state],
372
  examples_per_page=8,
373
  fn=on_image_upload,
374
  run_on_click=True
 
379
  gr.Markdown(
380
  "<center><img src='https://visitor-badge.laobi.icu/badge?page_id=chongzhou/edgesam' alt='visitors'></center>")
381
 
382
+ img_p.upload(on_image_upload, [img_p, session_state], [img_p, session_state])
383
+ img_p.select(segment_with_points, [add_or_remove, session_state], [img_p, session_state])
384
 
385
+ clear_btn_p.click(clear, [session_state], [img_p, session_state])
386
+ reset_btn_p.click(reset, [session_state], [img_p, session_state])
387
+ tab_p.select(fn=reset_all, inputs=[session_state], outputs=[img_p, img_b, session_state])
388
 
389
+ img_b.upload(on_image_upload, [img_b, session_state], [img_b, session_state])
390
+ img_b.select(segment_with_box, [session_state], [img_b, session_state])
391
 
392
+ clear_btn_b.click(clear, [session_state], [img_b, session_state])
393
+ reset_btn_b.click(reset, [session_state], [img_b, session_state])
394
+ tab_b.select(fn=reset_all, inputs=[session_state], outputs=[img_p, img_b, session_state])
395
 
396
  demo.queue()
397
  # demo.launch(server_name=args.server_name, server_port=args.port)