VisionScout / detection_model.py
DawnC's picture
Upload 6 files
611206a verified
raw
history blame
5.98 kB
from ultralytics import YOLO
from typing import Any, List, Dict, Optional
import torch
import numpy as np
import os
class DetectionModel:
"""Core detection model class for object detection using YOLOv8"""
# Model information dictionary
MODEL_INFO = {
"yolov8n.pt": {
"name": "YOLOv8n (Nano)",
"description": "Fastest model with smallest size (3.2M parameters). Best for speed-critical applications.",
"size_mb": 6,
"inference_speed": "Very Fast"
},
"yolov8m.pt": {
"name": "YOLOv8m (Medium)",
"description": "Balanced model with good accuracy-speed tradeoff (25.9M parameters). Recommended for general use.",
"size_mb": 25,
"inference_speed": "Medium"
},
"yolov8x.pt": {
"name": "YOLOv8x (XLarge)",
"description": "Most accurate but slower model (68.2M parameters). Best for accuracy-critical applications.",
"size_mb": 68,
"inference_speed": "Slower"
}
}
def __init__(self, model_name: str = 'yolov8m.pt', confidence: float = 0.25, iou: float = 0.45):
"""
Initialize the detection model
Args:
model_name: Model name or path, default is yolov8m.pt
confidence: Confidence threshold, default is 0.25
iou: IoU threshold for non-maximum suppression, default is 0.45
"""
self.model_name = model_name
self.confidence = confidence
self.iou = iou
self.model = None
self.class_names = {}
self.is_model_loaded = False
# Load model on initialization
self._load_model()
def _load_model(self):
"""Load the YOLO model"""
try:
print(f"Loading model: {self.model_name}")
self.model = YOLO(self.model_name)
self.class_names = self.model.names
self.is_model_loaded = True
print(f"Successfully loaded model: {self.model_name}")
print(f"Number of classes the model can recognize: {len(self.class_names)}")
except Exception as e:
print(f"Error occurred when loading the model: {e}")
self.is_model_loaded = False
def change_model(self, new_model_name: str) -> bool:
"""
Change the currently loaded model
Args:
new_model_name: Name of the new model to load
Returns:
bool: True if model changed successfully, False otherwise
"""
if self.model_name == new_model_name and self.is_model_loaded:
print(f"Model {new_model_name} is already loaded")
return True
print(f"Changing model from {self.model_name} to {new_model_name}")
# Unload current model to free memory
if self.model is not None:
del self.model
self.model = None
# Clean GPU memory if available
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Update model name and load new model
self.model_name = new_model_name
self._load_model()
return self.is_model_loaded
def reload_model(self):
"""Reload the model (useful for changing model or after error)"""
if self.model is not None:
del self.model
self.model = None
# Clean GPU memory if available
if torch.cuda.is_available():
torch.cuda.empty_cache()
self._load_model()
def detect(self, image_input: Any) -> Optional[Any]:
"""
Perform object detection on a single image
Args:
image_input: Image path (str), PIL Image, or numpy array
Returns:
Detection result object or None if error occurred
"""
if self.model is None or not self.is_model_loaded:
print("Model not found or not loaded. Attempting to reload...")
self._load_model()
if self.model is None or not self.is_model_loaded:
print("Failed to load model. Cannot perform detection.")
return None
try:
results = self.model(image_input, conf=self.confidence, iou=self.iou)
return results[0]
except Exception as e:
print(f"Error occurred during detection: {e}")
return None
def get_class_names(self, class_id: int) -> str:
"""Get class name for a given class ID"""
return self.class_names.get(class_id, "Unknown Class")
def get_supported_classes(self) -> Dict[int, str]:
"""Get all supported classes as a dictionary of {id: class_name}"""
return self.class_names
@classmethod
def get_available_models(cls) -> List[Dict]:
"""
Get list of available models with their information
Returns:
List of dictionaries containing model information
"""
models = []
for model_file, info in cls.MODEL_INFO.items():
models.append({
"model_file": model_file,
"name": info["name"],
"description": info["description"],
"size_mb": info["size_mb"],
"inference_speed": info["inference_speed"]
})
return models
@classmethod
def get_model_description(cls, model_name: str) -> str:
"""Get description for a specific model"""
if model_name in cls.MODEL_INFO:
info = cls.MODEL_INFO[model_name]
return f"{info['name']}: {info['description']} (Size: ~{info['size_mb']}MB, Speed: {info['inference_speed']})"
return "Model information not available"