DawnC commited on
Commit
de894d3
·
verified ·
1 Parent(s): 1487b33

Upload 2 files

Browse files

Improved filter classes function

Files changed (2) hide show
  1. image_processor.py +1 -1
  2. visualization_helper.py +6 -1
image_processor.py CHANGED
@@ -141,7 +141,7 @@ class ImageProcessor:
141
  )
142
 
143
  result_image = VisualizationHelper.visualize_detection(
144
- temp_path, result, color_mapper=self.color_mapper, figsize=(12, 12), return_pil=True
145
  )
146
 
147
  result_text = EvaluationMetrics.format_detection_summary(viz_data)
 
141
  )
142
 
143
  result_image = VisualizationHelper.visualize_detection(
144
+ temp_path, result, color_mapper=self.color_mapper, figsize=(12, 12), return_pil=True, filter_classes=filter_classes
145
  )
146
 
147
  result_text = EvaluationMetrics.format_detection_summary(viz_data)
visualization_helper.py CHANGED
@@ -12,7 +12,8 @@ class VisualizationHelper:
12
  @staticmethod
13
  def visualize_detection(image: Any, result: Any, color_mapper: Optional[Any] = None,
14
  figsize: Tuple[int, int] = (12, 12),
15
- return_pil: bool = False) -> Optional[Image.Image]:
 
16
  """
17
  Visualize detection results on a single image
18
 
@@ -73,6 +74,10 @@ class VisualizationHelper:
73
  for box, cls, conf in zip(boxes, classes, confs):
74
  x1, y1, x2, y2 = box
75
  cls_id = int(cls)
 
 
 
 
76
  cls_name = names[cls_id]
77
 
78
  # Get color for this class
 
12
  @staticmethod
13
  def visualize_detection(image: Any, result: Any, color_mapper: Optional[Any] = None,
14
  figsize: Tuple[int, int] = (12, 12),
15
+ return_pil: bool = False,
16
+ filter_classes: Optional[List[int]] = None) -> Optional[Image.Image]:
17
  """
18
  Visualize detection results on a single image
19
 
 
74
  for box, cls, conf in zip(boxes, classes, confs):
75
  x1, y1, x2, y2 = box
76
  cls_id = int(cls)
77
+
78
+ if filter_classes and cls_id not in filter_classes:
79
+ continue
80
+
81
  cls_name = names[cls_id]
82
 
83
  # Get color for this class