Alessio Grancini committed on
Commit
e84b793
·
verified ·
1 Parent(s): 1091c75

Update image_segmenter.py

Browse files
Files changed (1) hide show
  1. image_segmenter.py +71 -92
image_segmenter.py CHANGED
@@ -2,124 +2,103 @@ import cv2
2
  import numpy as np
3
  from ultralytics import YOLO
4
  import random
5
- import spaces
6
- import os
7
  import torch
8
 
9
  class ImageSegmenter:
10
- def __init__(self, model_type="yolov8s-seg", device="cpu"):
11
- self.device = device
12
- self.model = YOLO(model_type).to(self.device)
 
 
 
13
  self.is_show_bounding_boxes = True
14
  self.is_show_segmentation_boundary = False
15
  self.is_show_segmentation = False
16
  self.confidence_threshold = 0.5
17
  self.cls_clr = {}
 
 
18
  self.bb_thickness = 2
19
  self.bb_clr = (255, 0, 0)
 
 
20
  self.masks = {}
21
- self.model = None
22
-
23
- # Ensure model directory exists
24
- os.makedirs('models', exist_ok=True)
25
-
26
- # Check if model file exists, if not download it
27
- model_path = os.path.join('models', f'{model_type}.pt')
28
- if not os.path.exists(model_path):
29
- print(f"Downloading {model_type} model...")
30
- self.model = YOLO(model_type)
31
- self.model.export()
32
- print("Model downloaded successfully")
33
 
34
  def get_cls_clr(self, cls_id):
35
  if cls_id in self.cls_clr:
36
  return self.cls_clr[cls_id]
 
 
37
  r = random.randint(50, 200)
38
  g = random.randint(50, 200)
39
  b = random.randint(50, 200)
40
  self.cls_clr[cls_id] = (r, g, b)
41
  return (r, g, b)
42
 
43
- @spaces.GPU
44
  def predict(self, image):
45
- try:
46
- # Initialize model if needed
47
- if self.model is None:
48
- print("Loading YOLO model...")
49
- model_path = os.path.join('models', f'{self.model_type}.pt')
50
- # Force CPU mode for YOLO initialization
51
- self.model = YOLO(model_path)
52
- self.model.to('cpu') # Explicitly move to CPU
53
- print("Model loaded successfully")
54
-
55
- # Ensure image is in correct format
56
- if isinstance(image, np.ndarray):
57
- image = image.copy()
58
- else:
59
- raise ValueError("Input image must be a numpy array")
60
-
61
- # Make prediction using CPU
62
- predictions = self.model.predict(image, device='cpu')
63
-
64
- # Process results
65
- objects_data = []
66
-
67
- if len(predictions) == 0 or not predictions[0].boxes:
68
- return image, objects_data
69
-
70
- cls_ids = predictions[0].boxes.cls.numpy() # Changed from cpu().numpy()
71
- bounding_boxes = predictions[0].boxes.xyxy.int().numpy()
72
- cls_conf = predictions[0].boxes.conf.numpy()
73
-
74
- if predictions[0].masks is not None:
75
- seg_mask_boundary = predictions[0].masks.xy
76
- seg_mask = predictions[0].masks.data.numpy() # Changed from cpu().numpy()
77
- else:
78
- seg_mask_boundary, seg_mask = [], np.array([])
79
-
80
- for id, cls in enumerate(cls_ids):
81
- if cls_conf[id] <= self.confidence_threshold:
82
- continue
83
-
84
- cls_clr = self.get_cls_clr(int(cls))
85
-
86
- if seg_mask.size > 0:
87
- self.masks[id] = seg_mask[id]
88
-
89
- if self.is_show_segmentation:
90
- alpha = 0.8
91
- colored_mask = np.expand_dims(seg_mask[id], 0).repeat(3, axis=0)
92
- colored_mask = np.moveaxis(colored_mask, 0, -1)
93
-
94
- if image.shape[:2] != seg_mask[id].shape[:2]:
95
- colored_mask = cv2.resize(colored_mask, (image.shape[1], image.shape[0]))
96
-
97
- masked = np.ma.MaskedArray(image, mask=colored_mask, fill_value=cls_clr)
98
- image_overlay = masked.filled()
99
- image = cv2.addWeighted(image, 1 - alpha, image_overlay, alpha, 0)
100
-
101
- if self.is_show_bounding_boxes:
102
  (x1, y1, x2, y2) = bounding_boxes[id]
103
- cls_name = self.model.names[int(cls)]
104
  cls_confidence = cls_conf[id]
105
- disp_str = f"{cls_name} {cls_confidence:.2f}"
106
  cv2.rectangle(image, (x1, y1), (x2, y2), cls_clr, self.bb_thickness)
107
- cv2.rectangle(image, (x1, y1), (x1+len(disp_str)*9, y1+15), cls_clr, -1)
108
  cv2.putText(image, disp_str, (x1+5, y1+10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
109
 
110
- if len(seg_mask_boundary) > 0 and self.is_show_segmentation_boundary:
111
- cv2.polylines(image, [np.array(seg_mask_boundary[id], dtype=np.int32)],
112
- isClosed=True, color=cls_clr, thickness=2)
 
 
113
 
 
114
  (x1, y1, x2, y2) = bounding_boxes[id]
115
- center = (x1+(x2-x1)//2, y1+(y2-y1)//2)
116
- objects_data.append([int(cls), self.model.names[int(cls)], center,
117
- self.masks.get(id, None), cls_clr])
118
-
119
- return image, objects_data
120
-
121
- except Exception as e:
122
- print(f"Error in predict: {str(e)}")
123
- import traceback
124
- print(traceback.format_exc())
125
- raise
 
2
  import numpy as np
3
  from ultralytics import YOLO
4
  import random
 
 
5
  import torch
6
 
7
  class ImageSegmenter:
8
+ def __init__(self, model_type="yolov8s-seg") -> None:
9
+
10
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
+ self.model = YOLO('models/'+ model_type +'.pt')
12
+ self.model.to(self.device)
13
+
14
  self.is_show_bounding_boxes = True
15
  self.is_show_segmentation_boundary = False
16
  self.is_show_segmentation = False
17
  self.confidence_threshold = 0.5
18
  self.cls_clr = {}
19
+
20
+ # params
21
  self.bb_thickness = 2
22
  self.bb_clr = (255, 0, 0)
23
+
24
+ # variables
25
  self.masks = {}
26
+
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  def get_cls_clr(self, cls_id):
29
  if cls_id in self.cls_clr:
30
  return self.cls_clr[cls_id]
31
+
32
+ # gen rand color
33
  r = random.randint(50, 200)
34
  g = random.randint(50, 200)
35
  b = random.randint(50, 200)
36
  self.cls_clr[cls_id] = (r, g, b)
37
  return (r, g, b)
38
 
 
39
  def predict(self, image):
40
+ # params
41
+ objects_data = []
42
+ image = image.copy()
43
+ predictions = self.model.predict(image)
44
+
45
+ cls_ids = predictions[0].boxes.cls.cpu().numpy()
46
+ bounding_boxes = predictions[0].boxes.xyxy.int().cpu().numpy()
47
+ cls_conf = predictions[0].boxes.conf.cpu().numpy()
48
+ # segmentation
49
+ if predictions[0].masks:
50
+ seg_mask_boundary = predictions[0].masks.xy
51
+ seg_mask = predictions[0].masks.data.cpu().numpy()
52
+ else:
53
+ seg_mask_boundary, seg_mask = [], np.array([])
54
+
55
+ for id, cls in enumerate(cls_ids):
56
+ cls_clr = self.get_cls_clr(cls)
57
+
58
+ # draw filled segmentation region
59
+ if seg_mask.any() and cls_conf[id] > self.confidence_threshold:
60
+
61
+ self.masks[id] = seg_mask[id]
62
+
63
+ if self.is_show_segmentation:
64
+ alpha = 0.8
65
+
66
+ # converting the mask from 1 channel to 3 channels
67
+ colored_mask = np.expand_dims(seg_mask[id], 0).repeat(3, axis=0)
68
+ colored_mask = np.moveaxis(colored_mask, 0, -1)
69
+
70
+ # Resize the mask to match the image size, if necessary
71
+ if image.shape[:2] != seg_mask[id].shape[:2]:
72
+ colored_mask = cv2.resize(colored_mask, (image.shape[1], image.shape[0]))
73
+
74
+ # filling the mased area with class color
75
+ masked = np.ma.MaskedArray(image, mask=colored_mask, fill_value=cls_clr)
76
+ image_overlay = masked.filled()
77
+ image = cv2.addWeighted(image, 1 - alpha, image_overlay, alpha, 0)
78
+
79
+ # draw bounding box with class name and score
80
+ if self.is_show_bounding_boxes and cls_conf[id] > self.confidence_threshold:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  (x1, y1, x2, y2) = bounding_boxes[id]
82
+ cls_name = self.model.names[cls]
83
  cls_confidence = cls_conf[id]
84
+ disp_str = cls_name +' '+ str(round(cls_confidence, 2))
85
  cv2.rectangle(image, (x1, y1), (x2, y2), cls_clr, self.bb_thickness)
86
+ cv2.rectangle(image, (x1, y1), (x1+(len(disp_str)*9), y1+15), cls_clr, -1)
87
  cv2.putText(image, disp_str, (x1+5, y1+10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
88
 
89
+
90
+ # draw segmentation boundary
91
+ if len(seg_mask_boundary) and self.is_show_segmentation_boundary and cls_conf[id] > self.confidence_threshold:
92
+ cv2.polylines(image, [np.array(seg_mask_boundary[id], dtype=np.int32)], isClosed=True, color=cls_clr, thickness=2)
93
+
94
 
95
+ # object variables
96
  (x1, y1, x2, y2) = bounding_boxes[id]
97
+ center = x1+(x2-x1)//2, y1+(y2-y1)//2
98
+ objects_data.append([cls, self.model.names[cls], center, self.masks[id], cls_clr])
99
+
100
+ return image, objects_data
101
+
102
+
103
+
104
+