LPX55 committed
Commit db5e884 · verified · 1 Parent(s): 8d48719

Update app.py

Files changed (1): app.py (+52 -36)
app.py CHANGED
@@ -8,43 +8,35 @@ from PIL import Image
 from sam2.build_sam import build_sam2
 from sam2.sam2_image_predictor import SAM2ImagePredictor
 
+# Remove all CUDA-specific configurations
+torch.autocast(device_type="cpu", dtype=torch.float32).__enter__()
+
 def preprocess_image(image):
     return image, gr.State([]), gr.State([]), image
 
 def get_point(point_type, tracking_points, trackings_input_label, first_frame_path, evt: gr.SelectData):
     print(f"You selected {evt.value} at {evt.index} from {evt.target}")
-
     tracking_points.value.append(evt.index)
     print(f"TRACKING POINT: {tracking_points.value}")
-
     if point_type == "include":
         trackings_input_label.value.append(1)
     elif point_type == "exclude":
         trackings_input_label.value.append(0)
     print(f"TRACKING INPUT LABEL: {trackings_input_label.value}")
-
     transparent_background = Image.open(first_frame_path).convert('RGBA')
     w, h = transparent_background.size
-
     fraction = 0.02
     radius = int(fraction * min(w, h))
-
     transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
-
     for index, track in enumerate(tracking_points.value):
         if trackings_input_label.value[index] == 1:
             cv2.circle(transparent_layer, track, radius, (0, 255, 0, 255), -1)
         else:
             cv2.circle(transparent_layer, track, radius, (255, 0, 0, 255), -1)
-
     transparent_layer = Image.fromarray(transparent_layer, 'RGBA')
     selected_point_map = Image.alpha_composite(transparent_background, transparent_layer)
-
     return tracking_points, trackings_input_label, selected_point_map
 
-# Remove all CUDA-specific configurations
-torch.autocast(device_type="cpu", dtype=torch.float32).__enter__()
-
 def show_mask(mask, ax, random_color=False, borders=True):
     if random_color:
         color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
@@ -54,9 +46,9 @@ def show_mask(mask, ax, random_color=False, borders=True):
     mask = mask.astype(np.uint8)
     mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
     if borders:
-        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
+        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
         contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
-        mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
+        mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
     ax.imshow(mask_image)
 
 def show_points(coords, labels, ax, marker_size=375):
@@ -73,65 +65,82 @@ def show_box(box, ax):
 def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):
     combined_images = []
     mask_images = []
-
     for i, (mask, score) in enumerate(zip(masks, scores)):
         plt.figure(figsize=(10, 10))
         plt.imshow(image)
         show_mask(mask, plt.gca(), borders=borders)
         plt.axis('off')
-
         combined_filename = f"combined_image_{i+1}.jpg"
         plt.savefig(combined_filename, format='jpg', bbox_inches='tight')
         combined_images.append(combined_filename)
         plt.close()
-
         mask_image = np.zeros_like(image, dtype=np.uint8)
         mask_layer = (mask > 0).astype(np.uint8) * 255
         for c in range(3):
             mask_image[:, :, c] = mask_layer
-
         mask_filename = f"mask_image_{i+1}.png"
         Image.fromarray(mask_image).save(mask_filename)
         mask_images.append(mask_filename)
-
     return combined_images, mask_images
 
-def sam_process(input_image, checkpoint, tracking_points, trackings_input_label):
-    image = Image.open(input_image)
-    image = np.array(image.convert("RGB"))
+def expand_contract_mask(mask, px, expand=True):
+    kernel = np.ones((px, px), np.uint8)
+    if expand:
+        return cv2.dilate(mask, kernel, iterations=1)
+    else:
+        return cv2.erode(mask, kernel, iterations=1)
+
+def feather_mask(mask, feather_size=10):
+    feathered_mask = mask.copy()
+    feathered_region = mask > 0
+    feathered_region = cv2.dilate(feathered_region.astype(np.uint8), np.ones((feather_size, feather_size), np.uint8), iterations=1)
+    feathered_region = feathered_region.astype(bool) & (~mask.astype(bool))
+
+    for i in range(1, feather_size + 1):
+        weight = i / (feather_size + 1)
+        feathered_mask[feathered_region] = feathered_mask[feathered_region] * (1 - weight) + weight
+
+    return feathered_mask
+
+def process_mask(mask, expand_contract_px, expand, feathering_enabled, feather_size):
+    if expand_contract_px > 0:
+        mask = expand_contract_mask(mask, expand_contract_px, expand)
+    if feathering_enabled:
+        mask = feather_mask(mask, feather_size)
+    return mask
 
+def sam_process(input_image, checkpoint, tracking_points, trackings_input_label, expand_contract_px, expand, feathering_enabled, feather_size):
+    image = Image.open(input_image)
+    image = np.array(image.convert("RGB"))
     checkpoint_map = {
         "tiny": ("./checkpoints/sam2_hiera_tiny.pt", "sam2_hiera_t.yaml"),
         "small": ("./checkpoints/sam2_hiera_small.pt", "sam2_hiera_s.yaml"),
         "base-plus": ("./checkpoints/sam2_hiera_base_plus.pt", "sam2_hiera_b+.yaml"),
         "large": ("./checkpoints/sam2_hiera_large.pt", "sam2_hiera_l.yaml")
     }
-
     sam2_checkpoint, model_cfg = checkpoint_map[checkpoint]
-
     # Use CPU for both model and computations
     sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cpu")
     predictor = SAM2ImagePredictor(sam2_model)
     predictor.set_image(image)
-
     input_point = np.array(tracking_points.value)
     input_label = np.array(trackings_input_label.value)
-
     masks, scores, logits = predictor.predict(
         point_coords=input_point,
         point_labels=input_label,
         multimask_output=False,
     )
-
     sorted_ind = np.argsort(scores)[::-1]
     masks = masks[sorted_ind]
     scores = scores[sorted_ind]
-
-    results, mask_results = show_masks(image, masks, scores,
-                                       point_coords=input_point,
-                                       input_labels=input_label,
+    processed_masks = []
+    for mask in masks:
+        processed_mask = process_mask(mask, expand_contract_px, expand == "Expand", feathering_enabled, feather_size)
+        processed_masks.append(processed_mask)
+    results, mask_results = show_masks(image, processed_masks, scores,
+                                       point_coords=input_point,
+                                       input_labels=input_label,
                                        borders=True)
-
     return results[0], mask_results[0]
 
 with gr.Blocks() as demo:
@@ -149,36 +158,43 @@ with gr.Blocks() as demo:
             point_type = gr.Radio(label="point type", choices=["include", "exclude"], value="include")
             clear_points_btn = gr.Button("Clear Points")
             checkpoint = gr.Dropdown(label="Checkpoint", choices=["tiny", "small", "base-plus", "large"], value="base-plus")
+            with gr.Row():
+                expand_contract_px = gr.Slider(minimum=0, maximum=50, value=0, label="Expand/Contract (pixels)")
+                expand = gr.Radio(["Expand", "Contract"], value="Expand", label="Action")
+            with gr.Row():
+                feathering_enabled = gr.Checkbox(value=False, label="Enable Feathering")
+                feather_size = gr.Slider(minimum=1, maximum=50, value=10, label="Feathering Size", visible=False)
             submit_btn = gr.Button("Submit")
         with gr.Column():
             output_result = gr.Image()
             output_result_mask = gr.Image()
-
     clear_points_btn.click(
         fn=preprocess_image,
        inputs=input_image,
        outputs=[first_frame_path, tracking_points, trackings_input_label, points_map],
        queue=False
    )
-
    points_map.upload(
        fn=preprocess_image,
        inputs=[points_map],
        outputs=[first_frame_path, tracking_points, trackings_input_label, input_image],
        queue=False
    )
-
    points_map.select(
        fn=get_point,
        inputs=[point_type, tracking_points, trackings_input_label, first_frame_path],
        outputs=[tracking_points, trackings_input_label, points_map],
        queue=False
    )
-
    submit_btn.click(
        fn=sam_process,
-        inputs=[input_image, checkpoint, tracking_points, trackings_input_label],
+        inputs=[input_image, checkpoint, tracking_points, trackings_input_label, expand_contract_px, expand, feathering_enabled, feather_size],
        outputs=[output_result, output_result_mask]
    )
+    feathering_enabled.change(
+        fn=lambda enabled: gr.update(visible=enabled),
+        inputs=[feathering_enabled],
+        outputs=[feather_size]
+    )
 
 demo.launch(show_api=False, show_error=True)
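To sanity-check the new post-processing outside the Gradio UI, the three mask helpers can be exercised on a synthetic mask. Below is a minimal, self-contained sketch; the helpers are repeated from the diff so it runs standalone, and the test mask and parameter values are illustrative, not from the commit:

import numpy as np
import cv2

# Mask helpers as added in this commit, repeated so the check runs standalone.
def expand_contract_mask(mask, px, expand=True):
    kernel = np.ones((px, px), np.uint8)
    if expand:
        return cv2.dilate(mask, kernel, iterations=1)
    else:
        return cv2.erode(mask, kernel, iterations=1)

def feather_mask(mask, feather_size=10):
    feathered_mask = mask.copy()
    feathered_region = mask > 0
    feathered_region = cv2.dilate(feathered_region.astype(np.uint8), np.ones((feather_size, feather_size), np.uint8), iterations=1)
    feathered_region = feathered_region.astype(bool) & (~mask.astype(bool))
    for i in range(1, feather_size + 1):
        weight = i / (feather_size + 1)
        feathered_mask[feathered_region] = feathered_mask[feathered_region] * (1 - weight) + weight
    return feathered_mask

def process_mask(mask, expand_contract_px, expand, feathering_enabled, feather_size):
    if expand_contract_px > 0:
        mask = expand_contract_mask(mask, expand_contract_px, expand)
    if feathering_enabled:
        mask = feather_mask(mask, feather_size)
    return mask

# Synthetic 64x64 float mask with a centered square "object".
mask = np.zeros((64, 64), dtype=np.float32)
mask[16:48, 16:48] = 1.0

# Grow the mask by 5 px, then feather a band around it.
out = process_mask(mask, expand_contract_px=5, expand=True,
                   feathering_enabled=True, feather_size=10)
print(out.shape, float(out.min()), float(out.max()))  # (64, 64) 0.0 1.0

One behavior worth knowing: the feathering loop applies every weight to every pixel of the band, so after feather_size iterations the band converges toward 1 rather than ramping with distance from the mask edge; a distance-based variant is sketched below.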
 
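For a feather that actually ramps with distance, a common alternative (not what this commit implements) weights each background pixel by its distance to the mask using OpenCV's distance transform. A sketch, with a hypothetical helper name:

import numpy as np
import cv2

def feather_mask_by_distance(mask, feather_size=10):
    # distanceTransform assigns each pixel its L2 distance to the nearest
    # zero pixel, so zero out the object to measure distance *from* it.
    background = (mask == 0).astype(np.uint8)
    dist = cv2.distanceTransform(background, cv2.DIST_L2, 5)
    # Linear 1 -> 0 ramp over feather_size pixels outside the mask;
    # np.maximum keeps the interior at full opacity.
    ramp = np.clip(1.0 - dist / float(feather_size), 0.0, 1.0)
    return np.maximum(mask.astype(np.float32), ramp.astype(np.float32))

Dropped into process_mask in place of feather_mask, this would keep the UI controls unchanged while producing a smooth falloff.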