|
|
|
|
|
from collections import defaultdict |
|
|
|
import cv2 |
|
|
|
from ultralytics import YOLO |
|
from ultralytics.utils import ASSETS_URL, DEFAULT_CFG_DICT, DEFAULT_SOL_DICT, LOGGER |
|
from ultralytics.utils.checks import check_imshow, check_requirements |
|
|
|
|
|
class BaseSolution:
    """
    A base class for managing Ultralytics Solutions.

    This class provides core functionality for various Ultralytics Solutions, including model loading, object
    tracking, and region initialization.

    Attributes:
        LineString (shapely.geometry.LineString): Class for creating line string geometries.
        Polygon (shapely.geometry.Polygon): Class for creating polygon geometries.
        Point (shapely.geometry.Point): Class for creating point geometries.
        prep (Callable): `shapely.prepared.prep`, stored for use by solutions that need prepared geometries.
        CFG (Dict): Per-instance configuration built from the solution defaults, the model defaults and kwargs.
        region (List[Tuple[int, int]]): List of coordinate tuples defining a region of interest.
        line_width (int): Width of lines used in visualizations (2 when not configured).
        model (ultralytics.YOLO): Loaded YOLO model instance.
        names (Dict[int, str]): Dictionary mapping class indices to class names.
        track_add_args (Dict): Extra keyword arguments forwarded to `model.track`.
        env_check (bool): Flag indicating whether the environment supports image display.
        track_history (collections.defaultdict): Dictionary to store tracking history for each object.
        tracks: Raw results returned by the latest `model.track` call.
        track_data: OBB or boxes results of the latest tracked frame.
        boxes: Bounding boxes (xyxy) of the latest tracked frame, or an empty list.
        clss (List): Class indices of the latest tracked frame.
        track_ids (List[int]): Track identifiers of the latest tracked frame.
        track_line (List[Tuple[float, float]]): History of center points for the most recently stored track.
        r_s: Shapely geometry (Polygon or LineString) built from `region` by `initialize_region`.

    Methods:
        extract_tracks: Apply object tracking and extract tracks from an input image.
        store_tracking_history: Store object tracking history for a given track ID and bounding box.
        initialize_region: Initialize the counting region and line segment based on configuration.
        display_output: Display the results of processing, including showing frames or saving results.

    Examples:
        >>> solution = BaseSolution(model="yolov8n.pt", region=[(0, 0), (100, 0), (100, 100), (0, 100)])
        >>> solution.initialize_region()
        >>> image = cv2.imread("image.jpg")
        >>> solution.extract_tracks(image)
        >>> solution.display_output(image)
    """

    def __init__(self, IS_CLI=False, **kwargs):
        """
        Initialize the `BaseSolution` class with configuration settings and the YOLO model.

        Args:
            IS_CLI (bool): When truthy and no `source` is configured, a demo video is downloaded from
                `ASSETS_URL` and used as the source.
            **kwargs (Any): Configuration overrides applied on top of `DEFAULT_SOL_DICT` and
                `DEFAULT_CFG_DICT` for this instance only.
        """
        check_requirements("shapely>=2.0.0")
        from shapely.geometry import LineString, Point, Polygon
        from shapely.prepared import prep

        # Expose shapely helpers on the instance so subclasses do not need their own imports
        self.LineString = LineString
        self.Polygon = Polygon
        self.Point = Point
        self.prep = prep

        # Per-frame state, filled in by the methods below
        self.annotator = None
        self.tracks = None
        self.track_data = None
        self.boxes = []
        self.clss = []
        self.track_ids = []
        self.track_line = None
        self.r_s = None

        # Merge configuration into fresh dicts. The shared module-level defaults are deliberately NOT
        # updated in place, otherwise overrides from one instance would leak into every later instance.
        sol_cfg = {**DEFAULT_SOL_DICT, **kwargs}
        self.CFG = {**sol_cfg, **DEFAULT_CFG_DICT, **kwargs}  # kwargs re-applied last so they always win
        LOGGER.info(f"Ultralytics Solutions: ✅ {sol_cfg}")

        self.region = self.CFG["region"]
        configured_width = self.CFG["line_width"]
        self.line_width = configured_width if configured_width is not None else 2

        # Load the model, falling back to the default weights when none are configured
        if self.CFG["model"] is None:
            self.CFG["model"] = "yolo11n.pt"
        self.model = YOLO(self.CFG["model"])
        self.names = self.model.names

        # Subset of the configuration that is forwarded verbatim to `model.track`
        self.track_add_args = {
            key: self.CFG[key] for key in ("verbose", "iou", "conf", "device", "max_det", "half", "tracker")
        }

        if IS_CLI and self.CFG["source"] is None:
            # str() keeps the substring check valid when the model is given as a pathlib.Path
            is_pose = "-pose" in str(self.CFG["model"])
            d_s = "solution_ci_pose_demo.mp4" if is_pose else "solutions_ci_demo.mp4"
            LOGGER.warning(f"⚠️ WARNING: source not provided. using default source {ASSETS_URL}/{d_s}")
            from ultralytics.utils.downloads import safe_download

            safe_download(f"{ASSETS_URL}/{d_s}")
            self.CFG["source"] = d_s

        self.env_check = check_imshow(warn=True)
        self.track_history = defaultdict(list)

    def extract_tracks(self, im0):
        """
        Apply object tracking and extract tracks from an input image or frame.

        Updates `self.tracks`, `self.track_data`, `self.boxes`, `self.clss` and `self.track_ids`. When no
        tracked objects are available, a warning is logged and the three lists are reset to empty.

        Args:
            im0 (ndarray): The input image or frame.

        Examples:
            >>> solution = BaseSolution()
            >>> frame = cv2.imread("path/to/image.jpg")
            >>> solution.extract_tracks(frame)
        """
        self.tracks = self.model.track(source=im0, persist=True, classes=self.CFG["classes"], **self.track_add_args)

        # Prefer oriented boxes when the result has them, otherwise use regular boxes
        result = self.tracks[0]
        self.track_data = result.obb or result.boxes

        if self.track_data and self.track_data.id is not None:
            self.boxes = self.track_data.xyxy.cpu()
            self.clss = self.track_data.cls.cpu().tolist()
            self.track_ids = self.track_data.id.int().cpu().tolist()
        else:
            LOGGER.warning("WARNING ⚠️ no tracks found!")
            self.boxes, self.clss, self.track_ids = [], [], []

    def store_tracking_history(self, track_id, box):
        """
        Store the tracking history of an object.

        Appends the center point of the bounding box to the history of `track_id` and exposes that history
        as `self.track_line`. At most 30 points are kept per track; the oldest point is dropped first.

        Args:
            track_id (int): The unique identifier for the tracked object.
            box (List[float]): The bounding box coordinates of the object in the format [x1, y1, x2, y2].

        Examples:
            >>> solution = BaseSolution()
            >>> solution.store_tracking_history(1, [100, 200, 300, 400])
        """
        self.track_line = self.track_history[track_id]
        center = ((box[0] + box[2]) / 2, (box[1] + box[3]) / 2)
        self.track_line.append(center)
        if len(self.track_line) > 30:
            self.track_line.pop(0)

    def initialize_region(self):
        """
        Initialize the counting region and line segment based on configuration settings.

        Uses a default rectangular region when none is configured. Three or more points produce a Polygon,
        fewer points produce a LineString. The geometry is stored in `self.r_s`.
        """
        if self.region is None:
            self.region = [(20, 400), (1080, 400), (1080, 360), (20, 360)]
        geometry_cls = self.Polygon if len(self.region) >= 3 else self.LineString
        self.r_s = geometry_cls(self.region)

    def display_output(self, im0):
        """
        Display the processed frame in a window when display is enabled and supported.

        Args:
            im0 (numpy.ndarray): The input image or frame that has been processed and annotated.

        Examples:
            >>> solution = BaseSolution()
            >>> frame = cv2.imread("path/to/image.jpg")
            >>> solution.display_output(frame)

        Notes:
            - This method will only display output if the 'show' configuration is set to True and the
              environment supports image display.
            - Pressing the 'q' key closes the display windows.
        """
        if not (self.CFG.get("show") and self.env_check):
            return

        cv2.imshow("Ultralytics Solutions", im0)
        if cv2.waitKey(1) & 0xFF == ord("q"):
            # Actually close the window on 'q', as documented
            cv2.destroyAllWindows()
            return
|
|