Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,981 Bytes
611206a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
from ultralytics import YOLO
from typing import Any, List, Dict, Optional
import torch
import numpy as np
import os
class DetectionModel:
    """Core detection model class for object detection using YOLOv8.

    Wraps an ``ultralytics.YOLO`` model together with the confidence and IoU
    thresholds used at inference time, and exposes helpers for switching
    between the checkpoints described in ``MODEL_INFO``.

    State invariant kept by ``_load_model`` / ``_unload_model``:
    ``is_model_loaded`` is True only while ``model`` holds a successfully
    loaded model, and ``class_names`` is emptied whenever no model is loaded,
    so class lookups never report classes of a model that failed to load.
    """

    # Model information dictionary, keyed by checkpoint file name.
    # NOTE(review): size_mb for yolov8m/yolov8x (25, 68) equals their parameter
    # counts in millions, while the nano entry uses 6 for 3.2M parameters;
    # confirm these against the actual checkpoint file sizes.
    MODEL_INFO = {
        "yolov8n.pt": {
            "name": "YOLOv8n (Nano)",
            "description": "Fastest model with smallest size (3.2M parameters). Best for speed-critical applications.",
            "size_mb": 6,
            "inference_speed": "Very Fast"
        },
        "yolov8m.pt": {
            "name": "YOLOv8m (Medium)",
            "description": "Balanced model with good accuracy-speed tradeoff (25.9M parameters). Recommended for general use.",
            "size_mb": 25,
            "inference_speed": "Medium"
        },
        "yolov8x.pt": {
            "name": "YOLOv8x (XLarge)",
            "description": "Most accurate but slower model (68.2M parameters). Best for accuracy-critical applications.",
            "size_mb": 68,
            "inference_speed": "Slower"
        }
    }

    def __init__(self, model_name: str = 'yolov8m.pt', confidence: float = 0.25, iou: float = 0.45):
        """
        Initialize the detection model and load its weights immediately.

        Args:
            model_name: Model name or path, default is yolov8m.pt
            confidence: Confidence threshold, default is 0.25
            iou: IoU threshold for non-maximum suppression, default is 0.45

        Loading errors are printed, not raised; check ``is_model_loaded``.
        """
        self.model_name = model_name
        self.confidence = confidence
        self.iou = iou
        self.model = None
        self.class_names: Dict[int, str] = {}
        self.is_model_loaded = False
        # Load model on initialization
        self._load_model()

    def _load_model(self) -> None:
        """Load ``self.model_name`` into ``self.model``.

        Errors are printed rather than raised. On failure the instance is put
        into a clean unloaded state (``model`` None, ``class_names`` empty,
        ``is_model_loaded`` False) so nothing from a previously loaded model
        survives.
        """
        print(f"Loading model: {self.model_name}")
        try:
            # Build into locals first so a partial failure never leaves the
            # instance half-updated.
            model = YOLO(self.model_name)
            class_names = model.names
        except Exception as e:
            print(f"Error occurred when loading the model: {e}")
            self.model = None
            self.class_names = {}
            self.is_model_loaded = False
            return
        self.model = model
        self.class_names = class_names
        self.is_model_loaded = True
        print(f"Successfully loaded model: {self.model_name}")
        print(f"Number of classes the model can recognize: {len(self.class_names)}")

    def _unload_model(self) -> None:
        """Drop the current model, reset class names, and release cached GPU memory."""
        import gc  # local import: only needed when swapping/reloading models

        if self.model is not None:
            self.model = None
            # Collect reference cycles first: empty_cache() can only return
            # cached blocks that are no longer referenced by live tensors.
            gc.collect()
            # Clean GPU memory if available
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        self.class_names = {}
        self.is_model_loaded = False

    def change_model(self, new_model_name: str) -> bool:
        """
        Change the currently loaded model.

        The current model is unloaded before the new one is loaded, to keep
        peak memory low; if the new model fails to load, the instance is left
        unloaded (with ``model_name`` set to ``new_model_name``).

        Args:
            new_model_name: Name of the new model to load

        Returns:
            bool: True if model changed successfully, False otherwise
        """
        if self.model_name == new_model_name and self.is_model_loaded:
            print(f"Model {new_model_name} is already loaded")
            return True
        print(f"Changing model from {self.model_name} to {new_model_name}")
        self._unload_model()
        # Update model name and load new model
        self.model_name = new_model_name
        self._load_model()
        return self.is_model_loaded

    def reload_model(self) -> None:
        """Unload and reload ``self.model_name`` (useful after an error)."""
        self._unload_model()
        self._load_model()

    def detect(self, image_input: Any) -> Optional[Any]:
        """
        Perform object detection on a single image.

        If no model is loaded, one load attempt is made first.

        Args:
            image_input: Image path (str), PIL Image, or numpy array

        Returns:
            The first result object returned by the model, or None if the
            model could not be loaded or an error occurred during inference
            (including the model returning no results).
        """
        if self.model is None or not self.is_model_loaded:
            print("Model not found or not loaded. Attempting to reload...")
            self._load_model()
            if self.model is None or not self.is_model_loaded:
                print("Failed to load model. Cannot perform detection.")
                return None
        try:
            results = self.model(image_input, conf=self.confidence, iou=self.iou)
            return results[0]
        except Exception as e:
            print(f"Error occurred during detection: {e}")
            return None

    def get_class_names(self, class_id: int) -> str:
        """Return the class name for ``class_id``, or "Unknown Class" if absent."""
        return self.class_names.get(class_id, "Unknown Class")

    def get_supported_classes(self) -> Dict[int, str]:
        """Return a copy of all supported classes as ``{id: class_name}``.

        A copy is returned so callers cannot mutate the model's own mapping.
        """
        return dict(self.class_names)

    @classmethod
    def get_available_models(cls) -> List[Dict]:
        """
        Get list of available models with their information.

        Returns:
            One dictionary per entry of ``MODEL_INFO`` (in definition order)
            with keys model_file, name, description, size_mb, inference_speed.
        """
        return [
            {
                "model_file": model_file,
                "name": info["name"],
                "description": info["description"],
                "size_mb": info["size_mb"],
                "inference_speed": info["inference_speed"],
            }
            for model_file, info in cls.MODEL_INFO.items()
        ]

    @classmethod
    def get_model_description(cls, model_name: str) -> str:
        """Return a one-line human-readable description of ``model_name``."""
        info = cls.MODEL_INFO.get(model_name)
        if info is None:
            return "Model information not available"
        return f"{info['name']}: {info['description']} (Size: ~{info['size_mb']}MB, Speed: {info['inference_speed']})"
|