waleko committed on
Commit
90b2862
·
1 Parent(s): a100834

try first version full

Browse files
Files changed (1) hide show
  1. app.py +42 -35
app.py CHANGED
@@ -17,22 +17,25 @@ model = load_model(config, dino_checkpoint, device)
17
  box_threshold = 0.35
18
  text_threshold = 0.25
19
 
 
20
def show_mask(mask, ax, random_color=False):
    """Render a segmentation mask on a matplotlib axis as a translucent overlay.

    A random RGB color (alpha 0.6) is used when ``random_color`` is true;
    otherwise a fixed dodger-blue RGBA color is applied.
    """
    if random_color:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        rgba = np.array([30/255, 144/255, 255/255, 0.6])
    height, width = mask.shape[-2:]
    # Broadcast the (h, w) mask against the 4-channel color to get an RGBA image.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
28
 
29
def show_box(box, ax, label = None):
    """Draw an (x0, y0, x1, y1) bounding box on *ax*, with an optional caption.

    The box outline is red with a fully transparent face; the caption, if
    given, is anchored at the top-left corner on a red background.
    """
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    rect = plt.Rectangle((x0, y0), w, h, edgecolor='red', facecolor=(0,0,0,0), lw=2)
    ax.add_patch(rect)
    if label is not None:
        ax.text(x0, y0, label, fontsize=12, color='white', backgroundcolor='red', ha='left', va='top')
35
 
 
36
  def extract_object_with_transparent_background(image, masks):
37
  mask_expanded = np.expand_dims(masks[0], axis=-1)
38
  mask_expanded = np.repeat(mask_expanded, 3, axis=-1)
@@ -42,6 +45,7 @@ def extract_object_with_transparent_background(image, masks):
42
  rgba_segment[:, :, 3] = masks[0] * 255
43
  return rgba_segment
44
 
 
45
  def extract_remaining_image(image, masks):
46
  inverse_mask = np.logical_not(masks[0])
47
  inverse_mask_expanded = np.expand_dims(inverse_mask, axis=-1)
@@ -49,6 +53,7 @@ def extract_remaining_image(image, masks):
49
  remaining_image = image * inverse_mask_expanded
50
  return remaining_image
51
 
 
52
  def overlay_masks_boxes_on_image(image, masks, boxes, labels, show_masks, show_boxes):
53
  fig, ax = plt.subplots()
54
  ax.imshow(image)
@@ -64,9 +69,9 @@ def overlay_masks_boxes_on_image(image, masks, boxes, labels, show_masks, show_b
64
  plt.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0)
65
  plt.margins(0, 0)
66
 
67
- fig.canvas.draw()
68
  output_image = np.array(fig.canvas.buffer_rgba())
69
-
70
  plt.close(fig)
71
  return output_image
72
 
@@ -74,12 +79,12 @@ def overlay_masks_boxes_on_image(image, masks, boxes, labels, show_masks, show_b
74
  def detect_objects(image, prompt, show_masks=True, show_boxes=True, crop_options="No crop"):
75
  image_source, image = load_image(image)
76
  predictor.set_image(image_source)
77
-
78
  boxes, logits, phrases = predict(
79
- model=model,
80
- image=image,
81
- caption=prompt,
82
- box_threshold=box_threshold,
83
  text_threshold=text_threshold,
84
  device=device
85
  )
@@ -91,20 +96,23 @@ def detect_objects(image, prompt, show_masks=True, show_boxes=True, crop_options
91
  labels = [f"{phrase} {logit:.2f}" for phrase, logit in zip(phrases, logits)]
92
 
93
  masks_list = []
 
 
 
94
 
95
- for input_box, label in zip(boxes, labels):
96
  x1, y1, x2, y2 = input_box
97
  width = x2 - x1
98
  height = y2 - y1
99
  avg_size = (width + height) / 2
100
- d = avg_size * 0.1
101
-
102
  center_point = np.array([(x1 + x2) / 2, (y1 + y2) / 2])
103
  points = []
104
- points.append([center_point[0], center_point[1] - d])
105
- points.append([center_point[0], center_point[1] + d])
106
- points.append([center_point[0] - d, center_point[1]])
107
- points.append([center_point[0] + d, center_point[1]])
108
  input_point = np.array(points)
109
  input_label = np.array([1] * len(input_point))
110
 
@@ -122,25 +130,24 @@ def detect_objects(image, prompt, show_masks=True, show_boxes=True, crop_options
122
  multimask_output=False
123
  )
124
  masks_list.append(masks)
125
-
126
- if crop_options == "Crop":
127
  composite_image = np.zeros_like(image_source)
128
- for masks in masks_list:
129
- rgba_segment = extract_object_with_transparent_background(image_source, masks)
130
- composite_image = np.maximum(composite_image, rgba_segment[:, :, :3])
131
- output_image = overlay_masks_boxes_on_image(composite_image, masks_list, boxes, labels, show_masks, show_boxes)
132
- elif crop_options == "Inverse Crop":
133
- remaining_image = image_source.copy()
134
- for masks in masks_list:
135
- remaining_image = extract_remaining_image(remaining_image, masks)
136
- output_image = overlay_masks_boxes_on_image(remaining_image, masks_list, boxes, labels, show_masks, show_boxes)
137
- else:
138
- output_image = overlay_masks_boxes_on_image(image_source, masks_list, boxes, labels, show_masks, show_boxes)
139
-
140
- output_image_path = 'output_image.jpeg'
141
- plt.imsave(output_image_path, output_image)
142
-
143
- return [{"tmp": output_image_path}, [output_image_path]]
144
 
145
 
146
  app = gr.Interface(
 
17
  box_threshold = 0.35
18
  text_threshold = 0.25
19
 
20
+
21
def show_mask(mask, ax, random_color=False):
    """Overlay a segmentation mask on a matplotlib axis.

    When ``random_color`` is true a random RGB color with alpha 0.6 is
    drawn; otherwise a fixed dodger-blue RGBA color is used.
    """
    if random_color:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        rgba = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    height, width = mask.shape[-2:]
    # Broadcast (h, w) mask against the 4-channel color -> (h, w, 4) RGBA image.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
29
 
30
+
31
def show_box(box, ax, label=None):
    """Draw an (x0, y0, x1, y1) bounding box on *ax*, with an optional caption.

    The outline is red over a transparent face; the caption, if provided,
    sits at the top-left corner on a red background.
    """
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    rect = plt.Rectangle((x0, y0), w, h, edgecolor='red', facecolor=(0, 0, 0, 0), lw=2)
    ax.add_patch(rect)
    if label is not None:
        ax.text(x0, y0, label, fontsize=12, color='white', backgroundcolor='red', ha='left', va='top')
37
 
38
+
39
  def extract_object_with_transparent_background(image, masks):
40
  mask_expanded = np.expand_dims(masks[0], axis=-1)
41
  mask_expanded = np.repeat(mask_expanded, 3, axis=-1)
 
45
  rgba_segment[:, :, 3] = masks[0] * 255
46
  return rgba_segment
47
 
48
+
49
  def extract_remaining_image(image, masks):
50
  inverse_mask = np.logical_not(masks[0])
51
  inverse_mask_expanded = np.expand_dims(inverse_mask, axis=-1)
 
53
  remaining_image = image * inverse_mask_expanded
54
  return remaining_image
55
 
56
+
57
  def overlay_masks_boxes_on_image(image, masks, boxes, labels, show_masks, show_boxes):
58
  fig, ax = plt.subplots()
59
  ax.imshow(image)
 
69
  plt.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0)
70
  plt.margins(0, 0)
71
 
72
+ fig.canvas.draw()
73
  output_image = np.array(fig.canvas.buffer_rgba())
74
+
75
  plt.close(fig)
76
  return output_image
77
 
 
79
  def detect_objects(image, prompt, show_masks=True, show_boxes=True, crop_options="No crop"):
80
  image_source, image = load_image(image)
81
  predictor.set_image(image_source)
82
+
83
  boxes, logits, phrases = predict(
84
+ model=model,
85
+ image=image,
86
+ caption=prompt,
87
+ box_threshold=box_threshold,
88
  text_threshold=text_threshold,
89
  device=device
90
  )
 
96
  labels = [f"{phrase} {logit:.2f}" for phrase, logit in zip(phrases, logits)]
97
 
98
  masks_list = []
99
+ res_json = {"prompt": prompt, "objects": []}
100
+
101
+ output_image_paths = []
102
 
103
+ for i, (input_box, label) in enumerate(zip(boxes, labels)):
104
  x1, y1, x2, y2 = input_box
105
  width = x2 - x1
106
  height = y2 - y1
107
  avg_size = (width + height) / 2
108
+ d = avg_size * 0.1
109
+
110
  center_point = np.array([(x1 + x2) / 2, (y1 + y2) / 2])
111
  points = []
112
+ points.append([center_point[0], center_point[1] - d])
113
+ points.append([center_point[0], center_point[1] + d])
114
+ points.append([center_point[0] - d, center_point[1]])
115
+ points.append([center_point[0] + d, center_point[1]])
116
  input_point = np.array(points)
117
  input_label = np.array([1] * len(input_point))
118
 
 
130
  multimask_output=False
131
  )
132
  masks_list.append(masks)
133
+
 
134
  composite_image = np.zeros_like(image_source)
135
+ rgba_segment = extract_object_with_transparent_background(image_source, masks)
136
+ composite_image = np.maximum(composite_image, rgba_segment[:, :, :3])
137
+ cropped_image = composite_image[y1:y2, x1:x2, :]
138
+ output_image = overlay_masks_boxes_on_image(cropped_image, [], [], [], False, False)
139
+
140
+ output_image_path = f'output_image_{i}.jpeg'
141
+ plt.imsave(output_image_path, output_image)
142
+
143
+ output_image_paths.append(output_image_path)
144
+
145
+ # save object information in json
146
+ res_json["objects"].append(
147
+ {"label": label, "score": np.max(scores), "box": input_box.tolist(), "center": center_point.tolist(),
148
+ "avg_size": avg_size})
149
+
150
+ return [res_json, output_image_paths]
151
 
152
 
153
  app = gr.Interface(