hugohabicht01 commited on
Commit
d528373
·
1 Parent(s): a894bc6

download sam from huggingface

Browse files
Files changed (1) hide show
  1. blurnonymize.py +14 -44
blurnonymize.py CHANGED
@@ -7,7 +7,6 @@ import matplotlib.patches as patches
7
  import matplotlib.pyplot as plt
8
  import numpy as np
9
  import torch
10
- from sam2.build_sam import build_sam2
11
  from sam2.sam2_image_predictor import SAM2ImagePredictor
12
  from utils import *
13
 
@@ -160,38 +159,8 @@ class ImageBlurnonymizer:
160
  if self.predictor is not None and not force:
161
  return
162
 
163
- self.download_weights()
164
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
165
- sam = build_sam2(f"file://{self.model_cfg_name}", f"file://{self.checkpoint_name}", device=self.device)
166
- self.predictor = SAM2ImagePredictor(sam)
167
-
168
- def download_weights(self):
169
- # TODO: check whether these files already exist, if not, download them
170
- # files names are in self.checkpoint_name and self.model_cfg_name
171
-
172
- print(f"Current working directory: {os.getcwd()}")
173
- print(f"Files in current dir: {os.listdir()}")
174
- print(f"Files in root dir: {os.listdir('/')}")
175
-
176
- print(f"Files in /data: {os.listdir('/data')}")
177
-
178
- checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt"
179
- cfg_url = "https://raw.githubusercontent.com/facebookresearch/sam2/refs/heads/main/sam2/configs/sam2.1/sam2.1_hiera_s.yaml"
180
-
181
- config_dir = "/data/configs/sam2.1"
182
-
183
- if not os.path.exists(config_dir):
184
- os.makedirs(config_dir)
185
-
186
- if not os.path.exists(self.checkpoint_name):
187
- print(
188
- f"Downloading checkpoint from {checkpoint_url} to {self.checkpoint_name}"
189
- )
190
- torch.hub.download_url_to_file(checkpoint_url, self.checkpoint_name)
191
-
192
- if not os.path.exists(self.model_cfg_name):
193
- print(f"Downloading config from {cfg_url} to {self.model_cfg_name}")
194
- torch.hub.download_url_to_file(cfg_url, self.model_cfg_name)
195
 
196
  @staticmethod
197
  def _smoothen_mask(mask: np.ndarray):
@@ -276,20 +245,21 @@ class ImageBlurnonymizer:
276
  # Ensure points are valid coordinates (e.g., non-negative)
277
  points = [[max(0, p[0]), max(0, p[1])] for p in points]
278
 
279
- self.predictor.set_image(image)
280
- masks, scores, _ = self.predictor.predict(
281
- box=np.array(bbox), # Predictor might expect numpy array
282
- point_coords=np.array(points),
283
- point_labels=np.ones(len(points)), # Label 1 for inclusion
284
- multimask_output=True,
285
- )
 
286
 
287
- # Sort masks by score and select the best one
288
- sorted_ind = np.argsort(scores)[::-1]
289
- best_mask = masks[sorted_ind[0]]
290
- best_score = scores[sorted_ind[0]]
291
 
292
- return self._smoothen_mask(best_mask), best_score
293
 
294
  def censor_image_blur(
295
  self,
 
7
  import matplotlib.pyplot as plt
8
  import numpy as np
9
  import torch
 
10
  from sam2.sam2_image_predictor import SAM2ImagePredictor
11
  from utils import *
12
 
 
159
  if self.predictor is not None and not force:
160
  return
161
 
 
162
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
163
+ self.predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
  @staticmethod
166
  def _smoothen_mask(mask: np.ndarray):
 
245
  # Ensure points are valid coordinates (e.g., non-negative)
246
  points = [[max(0, p[0]), max(0, p[1])] for p in points]
247
 
248
+ with torch.inference_mode(), torch.autocast(self.device, dtype=torch.bfloat16):
249
+ self.predictor.set_image(image)
250
+ masks, scores, _ = self.predictor.predict(
251
+ box=np.array(bbox), # Predictor might expect numpy array
252
+ point_coords=np.array(points),
253
+ point_labels=np.ones(len(points)), # Label 1 for inclusion
254
+ multimask_output=True,
255
+ )
256
 
257
+ # Sort masks by score and select the best one
258
+ sorted_ind = np.argsort(scores)[::-1]
259
+ best_mask = masks[sorted_ind[0]]
260
+ best_score = scores[sorted_ind[0]]
261
 
262
+ return self._smoothen_mask(best_mask), best_score
263
 
264
  def censor_image_blur(
265
  self,