hugohabicht01 commited on
Commit
5ea22b8
·
1 Parent(s): 3edebe6

fix to_tuple error and force sam on cuda

Browse files
Files changed (2) hide show
  1. blurnonymize.py +2 -4
  2. utils.py +4 -0
blurnonymize.py CHANGED
@@ -152,13 +152,12 @@ class ImageBlurnonymizer:
152
  self.init_sam()
153
 
154
  def init_sam(self, force=False):
155
- return
156
  # only initialize SAM if it hasn't been initialized yet
157
  if self.predictor is not None and not force:
158
  return
159
 
160
  # self.device = "cuda" if torch.cuda.is_available() else "cpu"
161
- self.device = "cpu"
162
  self.predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2.1-hiera-small", device=self.device)
163
 
164
  @staticmethod
@@ -244,8 +243,7 @@ class ImageBlurnonymizer:
244
  # Ensure points are valid coordinates (e.g., non-negative)
245
  points = [[max(0, p[0]), max(0, p[1])] for p in points]
246
 
247
- # with torch.inference_mode(), torch.autocast(self.device, dtype=torch.bfloat16):
248
- with torch.inference_mode(), torch.autocast(self.device):
249
  self.predictor.set_image(image)
250
  masks, scores, _ = self.predictor.predict(
251
  box=np.array(bbox), # Predictor might expect numpy array
 
152
  self.init_sam()
153
 
154
  def init_sam(self, force=False):
 
155
  # only initialize SAM if it hasn't been initialized yet
156
  if self.predictor is not None and not force:
157
  return
158
 
159
  # self.device = "cuda" if torch.cuda.is_available() else "cpu"
160
+ self.device = "cuda"
161
  self.predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2.1-hiera-small", device=self.device)
162
 
163
  @staticmethod
 
243
  # Ensure points are valid coordinates (e.g., non-negative)
244
  points = [[max(0, p[0]), max(0, p[1])] for p in points]
245
 
246
+ with torch.inference_mode(), torch.autocast(self.device, dtype=torch.bfloat16):
 
247
  self.predictor.set_image(image)
248
  masks, scores, _ = self.predictor.predict(
249
  box=np.array(bbox), # Predictor might expect numpy array
utils.py CHANGED
@@ -88,6 +88,10 @@ class BoundingBox(BaseModel):
88
  """Creates a BoundingBox instance from a label and a list of coordinates."""
89
  return BoundingBox(label=label, x_min=box[0], y_min=box[1], x_max=box[2], y_max=box[3])
90
 
 
 
 
 
91
  def parse_json_response(out: str) -> list[dict]:
92
  """Extracts and parses JSON content from a string.
93
 
 
88
  """Creates a BoundingBox instance from a label and a list of coordinates."""
89
  return BoundingBox(label=label, x_min=box[0], y_min=box[1], x_max=box[2], y_max=box[3])
90
 
91
+ def to_tuple(self) -> tuple[int, int, int, int]:
92
+ """Converts the BoundingBox instance to a tuple of coordinates."""
93
+ return (self.x_min, self.y_min, self.x_max, self.y_max)
94
+
95
  def parse_json_response(out: str) -> list[dict]:
96
  """Extracts and parses JSON content from a string.
97