Spaces:
Running
on
Zero
Running
on
Zero
hugohabicht01
commited on
Commit
·
d528373
1
Parent(s):
a894bc6
download sam from huggingface
Browse files- blurnonymize.py +14 -44
blurnonymize.py
CHANGED
@@ -7,7 +7,6 @@ import matplotlib.patches as patches
|
|
7 |
import matplotlib.pyplot as plt
|
8 |
import numpy as np
|
9 |
import torch
|
10 |
-
from sam2.build_sam import build_sam2
|
11 |
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
12 |
from utils import *
|
13 |
|
@@ -160,38 +159,8 @@ class ImageBlurnonymizer:
|
|
160 |
if self.predictor is not None and not force:
|
161 |
return
|
162 |
|
163 |
-
self.download_weights()
|
164 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
165 |
-
|
166 |
-
self.predictor = SAM2ImagePredictor(sam)
|
167 |
-
|
168 |
-
def download_weights(self):
|
169 |
-
# TODO: check whether these files already exist, if not, download them
|
170 |
-
# files names are in self.checkpoint_name and self.model_cfg_name
|
171 |
-
|
172 |
-
print(f"Current working directory: {os.getcwd()}")
|
173 |
-
print(f"Files in current dir: {os.listdir()}")
|
174 |
-
print(f"Files in root dir: {os.listdir('/')}")
|
175 |
-
|
176 |
-
print(f"Files in /data: {os.listdir('/data')}")
|
177 |
-
|
178 |
-
checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt"
|
179 |
-
cfg_url = "https://raw.githubusercontent.com/facebookresearch/sam2/refs/heads/main/sam2/configs/sam2.1/sam2.1_hiera_s.yaml"
|
180 |
-
|
181 |
-
config_dir = "/data/configs/sam2.1"
|
182 |
-
|
183 |
-
if not os.path.exists(config_dir):
|
184 |
-
os.makedirs(config_dir)
|
185 |
-
|
186 |
-
if not os.path.exists(self.checkpoint_name):
|
187 |
-
print(
|
188 |
-
f"Downloading checkpoint from {checkpoint_url} to {self.checkpoint_name}"
|
189 |
-
)
|
190 |
-
torch.hub.download_url_to_file(checkpoint_url, self.checkpoint_name)
|
191 |
-
|
192 |
-
if not os.path.exists(self.model_cfg_name):
|
193 |
-
print(f"Downloading config from {cfg_url} to {self.model_cfg_name}")
|
194 |
-
torch.hub.download_url_to_file(cfg_url, self.model_cfg_name)
|
195 |
|
196 |
@staticmethod
|
197 |
def _smoothen_mask(mask: np.ndarray):
|
@@ -276,20 +245,21 @@ class ImageBlurnonymizer:
|
|
276 |
# Ensure points are valid coordinates (e.g., non-negative)
|
277 |
points = [[max(0, p[0]), max(0, p[1])] for p in points]
|
278 |
|
279 |
-
self.
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
|
|
286 |
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
|
292 |
-
|
293 |
|
294 |
def censor_image_blur(
|
295 |
self,
|
|
|
7 |
import matplotlib.pyplot as plt
|
8 |
import numpy as np
|
9 |
import torch
|
|
|
10 |
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
11 |
from utils import *
|
12 |
|
|
|
159 |
if self.predictor is not None and not force:
|
160 |
return
|
161 |
|
|
|
162 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
163 |
+
self.predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
|
165 |
@staticmethod
|
166 |
def _smoothen_mask(mask: np.ndarray):
|
|
|
245 |
# Ensure points are valid coordinates (e.g., non-negative)
|
246 |
points = [[max(0, p[0]), max(0, p[1])] for p in points]
|
247 |
|
248 |
+
with torch.inference_mode(), torch.autocast(self.device, dtype=torch.bfloat16):
|
249 |
+
self.predictor.set_image(image)
|
250 |
+
masks, scores, _ = self.predictor.predict(
|
251 |
+
box=np.array(bbox), # Predictor might expect numpy array
|
252 |
+
point_coords=np.array(points),
|
253 |
+
point_labels=np.ones(len(points)), # Label 1 for inclusion
|
254 |
+
multimask_output=True,
|
255 |
+
)
|
256 |
|
257 |
+
# Sort masks by score and select the best one
|
258 |
+
sorted_ind = np.argsort(scores)[::-1]
|
259 |
+
best_mask = masks[sorted_ind[0]]
|
260 |
+
best_score = scores[sorted_ind[0]]
|
261 |
|
262 |
+
return self._smoothen_mask(best_mask), best_score
|
263 |
|
264 |
def censor_image_blur(
|
265 |
self,
|