Alessio Grancini commited on
Commit
4f0cfe1
·
verified ·
1 Parent(s): e38543b

Update image_segmenter.py

Browse files
Files changed (1) hide show
  1. image_segmenter.py +18 -16
image_segmenter.py CHANGED
@@ -3,14 +3,12 @@ import numpy as np
3
  from ultralytics import YOLO
4
  import random
5
  import torch
 
6
 
7
  class ImageSegmenter:
8
  def __init__(self, model_type="yolov8s-seg") -> None:
9
-
10
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
- self.model = YOLO('models/'+ model_type +'.pt')
12
- self.model.to(self.device)
13
-
14
  self.is_show_bounding_boxes = True
15
  self.is_show_segmentation_boundary = False
16
  self.is_show_segmentation = False
@@ -23,7 +21,9 @@ class ImageSegmenter:
23
 
24
  # variables
25
  self.masks = {}
26
-
 
 
27
 
28
  def get_cls_clr(self, cls_id):
29
  if cls_id in self.cls_clr:
@@ -36,15 +36,24 @@ class ImageSegmenter:
36
  self.cls_clr[cls_id] = (r, g, b)
37
  return (r, g, b)
38
 
 
39
  def predict(self, image):
 
 
 
 
 
40
  # params
41
  objects_data = []
42
  image = image.copy()
 
 
43
  predictions = self.model.predict(image)
44
 
45
  cls_ids = predictions[0].boxes.cls.cpu().numpy()
46
  bounding_boxes = predictions[0].boxes.xyxy.int().cpu().numpy()
47
  cls_conf = predictions[0].boxes.conf.cpu().numpy()
 
48
  # segmentation
49
  if predictions[0].masks:
50
  seg_mask_boundary = predictions[0].masks.xy
@@ -56,8 +65,7 @@ class ImageSegmenter:
56
  cls_clr = self.get_cls_clr(cls)
57
 
58
  # draw filled segmentation region
59
- if seg_mask.any() and cls_conf[id] > self.confidence_threshold:
60
-
61
  self.masks[id] = seg_mask[id]
62
 
63
  if self.is_show_segmentation:
@@ -71,7 +79,7 @@ class ImageSegmenter:
71
  if image.shape[:2] != seg_mask[id].shape[:2]:
72
  colored_mask = cv2.resize(colored_mask, (image.shape[1], image.shape[0]))
73
 
74
- # filling the mased area with class color
75
  masked = np.ma.MaskedArray(image, mask=colored_mask, fill_value=cls_clr)
76
  image_overlay = masked.filled()
77
  image = cv2.addWeighted(image, 1 - alpha, image_overlay, alpha, 0)
@@ -86,19 +94,13 @@ class ImageSegmenter:
86
  cv2.rectangle(image, (x1, y1), (x1+(len(disp_str)*9), y1+15), cls_clr, -1)
87
  cv2.putText(image, disp_str, (x1+5, y1+10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
88
 
89
-
90
  # draw segmentation boundary
91
  if len(seg_mask_boundary) and self.is_show_segmentation_boundary and cls_conf[id] > self.confidence_threshold:
92
  cv2.polylines(image, [np.array(seg_mask_boundary[id], dtype=np.int32)], isClosed=True, color=cls_clr, thickness=2)
93
 
94
-
95
  # object variables
96
  (x1, y1, x2, y2) = bounding_boxes[id]
97
  center = x1+(x2-x1)//2, y1+(y2-y1)//2
98
  objects_data.append([cls, self.model.names[cls], center, self.masks[id], cls_clr])
99
 
100
- return image, objects_data
101
-
102
-
103
-
104
-
 
3
  from ultralytics import YOLO
4
  import random
5
  import torch
6
+ import spaces
7
 
8
  class ImageSegmenter:
9
  def __init__(self, model_type="yolov8s-seg") -> None:
10
+ self.model_type = model_type
11
+ self.device = 'cuda' # ZeroGPU will always use CUDA
 
 
 
12
  self.is_show_bounding_boxes = True
13
  self.is_show_segmentation_boundary = False
14
  self.is_show_segmentation = False
 
21
 
22
  # variables
23
  self.masks = {}
24
+
25
+ # Model will be loaded in predict to work with ZeroGPU
26
+ self.model = None
27
 
28
  def get_cls_clr(self, cls_id):
29
  if cls_id in self.cls_clr:
 
36
  self.cls_clr[cls_id] = (r, g, b)
37
  return (r, g, b)
38
 
39
+ @spaces.GPU(duration=30) # Adjust duration based on your needs
40
  def predict(self, image):
41
+ # Load model if not loaded (will happen on first prediction)
42
+ if self.model is None:
43
+ self.model = YOLO('models/' + self.model_type + '.pt')
44
+ self.model.to(self.device)
45
+
46
  # params
47
  objects_data = []
48
  image = image.copy()
49
+
50
+ # Run prediction
51
  predictions = self.model.predict(image)
52
 
53
  cls_ids = predictions[0].boxes.cls.cpu().numpy()
54
  bounding_boxes = predictions[0].boxes.xyxy.int().cpu().numpy()
55
  cls_conf = predictions[0].boxes.conf.cpu().numpy()
56
+
57
  # segmentation
58
  if predictions[0].masks:
59
  seg_mask_boundary = predictions[0].masks.xy
 
65
  cls_clr = self.get_cls_clr(cls)
66
 
67
  # draw filled segmentation region
68
+ if seg_mask.any() and cls_conf[id] > self.confidence_threshold:
 
69
  self.masks[id] = seg_mask[id]
70
 
71
  if self.is_show_segmentation:
 
79
  if image.shape[:2] != seg_mask[id].shape[:2]:
80
  colored_mask = cv2.resize(colored_mask, (image.shape[1], image.shape[0]))
81
 
82
+ # filling the masked area with class color
83
  masked = np.ma.MaskedArray(image, mask=colored_mask, fill_value=cls_clr)
84
  image_overlay = masked.filled()
85
  image = cv2.addWeighted(image, 1 - alpha, image_overlay, alpha, 0)
 
94
  cv2.rectangle(image, (x1, y1), (x1+(len(disp_str)*9), y1+15), cls_clr, -1)
95
  cv2.putText(image, disp_str, (x1+5, y1+10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
96
 
 
97
  # draw segmentation boundary
98
  if len(seg_mask_boundary) and self.is_show_segmentation_boundary and cls_conf[id] > self.confidence_threshold:
99
  cv2.polylines(image, [np.array(seg_mask_boundary[id], dtype=np.int32)], isClosed=True, color=cls_clr, thickness=2)
100
 
 
101
  # object variables
102
  (x1, y1, x2, y2) = bounding_boxes[id]
103
  center = x1+(x2-x1)//2, y1+(y2-y1)//2
104
  objects_data.append([cls, self.model.names[cls], center, self.masks[id], cls_clr])
105
 
106
+ return image, objects_data