|
|
|
import cv2 |
|
import numpy as np |
|
import os |
|
from ultralytics import YOLO |
|
import time |
|
from typing import Tuple, Set, List |
|
|
|
def _sanitize_filename_part(text: str) -> str:
    """Return *text* with every character other than alphanumerics, '-' and '_' replaced by '_'.

    Applied to every component that ends up in an output filename so that
    arbitrary video names or class labels (spaces, '/', ':' ...) can never
    produce an invalid path or one that escapes the output directory.
    """
    return "".join(c if c.isalnum() or c in ("-", "_") else "_" for c in text)


def _open_video_writer(output_path: str, fps: float, frame_size: Tuple[int, int]) -> "cv2.VideoWriter":
    """Open an MP4 ``cv2.VideoWriter``, trying the 'mp4v' codec first and 'avc1' as a fallback.

    Raises:
        ValueError: If neither codec yields an opened writer.
    """
    writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, frame_size)
    if writer.isOpened():
        return writer
    # Release the failed writer before retrying so its handle is not leaked.
    writer.release()
    print("[WARNING] mp4v codec failed, trying avc1...")
    writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"avc1"), fps, frame_size)
    if writer.isOpened():
        return writer
    writer.release()
    raise ValueError("Failed to initialize VideoWriter with mp4v or avc1 codec.")


def _build_final_output_name(safe_base_name: str, labels: Set[str], max_label_len: int = 50) -> str:
    """Build the final output filename from the input base name and the detected labels.

    Labels are sorted, joined with '_' and sanitised. The label part is
    truncated to *max_label_len* characters (with an '_etc' suffix) and
    becomes 'no_detections' when nothing was detected.
    """
    labels_str = _sanitize_filename_part("_".join(sorted(labels)))
    if len(labels_str) > max_label_len:
        labels_str = labels_str[:max_label_len] + "_etc"
    if not labels_str:
        labels_str = "no_detections"
    return f"{safe_base_name}_{labels_str}_output.mp4"


def detection(path: str) -> Tuple[Set[str], str]:
    """
    Detects and tracks objects in a video using a YOLO model, saving an annotated output video.

    Every frame is resized to 640x640 and passed to ``model.track`` (confidence
    threshold 0.7, tracker state persisted across frames). The result is written
    to an MP4 in the ``results`` directory next to this module. A frame whose
    inference fails is logged and written un-annotated. The output file name
    encodes the input name and the sorted set of detected labels.

    Args:
        path (str): Path to the input video file. Supports common video formats (mp4, avi, etc.)

    Returns:
        Tuple[Set[str], str]:
            - Set of unique detected object labels (e.g., {'Gun', 'Knife'})
            - Path to the output annotated video with detection boxes and tracking IDs.
              If renaming to the final name fails, the temporary output path is returned.

    Raises:
        FileNotFoundError: If the input video or the ``best.pt`` model file doesn't exist
        ValueError: If the model cannot be loaded, the video cannot be opened/processed,
            the output dir cannot be created, or no output file was produced
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Video file not found: {path}")

    # The model weights are expected to live next to this module.
    model_path = os.path.join(os.path.dirname(__file__), "best.pt")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"YOLO model file not found at: {model_path}")
    try:
        model = YOLO(model_path)
        class_names = model.names
        print(f"[INFO] YOLO model loaded from {model_path}. Class names: {class_names}")
    except Exception as e:
        raise ValueError(f"Failed to load YOLO model: {e}") from e

    base_name = os.path.splitext(os.path.basename(path))[0]
    safe_base_name = _sanitize_filename_part(base_name)

    output_dir = os.path.join(os.path.dirname(__file__), "results")
    try:
        # makedirs(exist_ok=True) still raises FileExistsError (an OSError) when the
        # path exists but is not a directory, so no separate isdir check is needed.
        os.makedirs(output_dir, exist_ok=True)
    except OSError as e:
        raise ValueError(f"Failed to create or access output directory '{output_dir}': {e}") from e

    temp_output_path = os.path.join(output_dir, f"{safe_base_name}_output_temp.mp4")
    print(f"[INFO] Temporary output will be saved to: {temp_output_path}")

    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
        raise ValueError(f"Failed to open video file: {path}")

    # Keep the source frame rate when it is within 10-60 fps; otherwise
    # (including a reported 0) fall back to 30 fps.
    source_fps = cap.get(cv2.CAP_PROP_FPS)
    output_fps = source_fps if 10 <= source_fps <= 60 else 30.0

    # Every frame is resized to this size before inference and writing, so the
    # writer must be opened with exactly the same dimensions.
    frame_width, frame_height = 640, 640

    try:
        out = _open_video_writer(temp_output_path, output_fps, (frame_width, frame_height))
    except Exception as e:
        cap.release()
        raise ValueError(f"Failed to create VideoWriter: {e}") from e

    # A set de-duplicates as we go instead of storing one entry per box per frame.
    detected_labels: Set[str] = set()
    frame_count = 0
    start = time.time()
    print("[INFO] Video processing started...")

    # try/finally guarantees both handles are released even if reading or
    # resizing a frame raises (those calls sit outside the per-frame handler).
    # No cv2.imshow windows are created, so cv2.destroyAllWindows() is not
    # called: it raises on headless OpenCV builds.
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame_count += 1

            resized_frame = cv2.resize(frame, (frame_width, frame_height))

            try:
                results = model.track(
                    source=resized_frame,
                    conf=0.7,
                    persist=True,  # keep tracker state across successive frames
                    verbose=False,
                )

                if results and results[0] and results[0].boxes:
                    annotated_frame = results[0].plot()
                    for box in results[0].boxes:
                        if box.cls is None:
                            continue
                        cls_id = int(box.cls[0])
                        if 0 <= cls_id < len(class_names):
                            detected_labels.add(class_names[cls_id])
                        else:
                            print(f"[WARNING] Detected unknown class ID: {cls_id}")
                else:
                    annotated_frame = resized_frame

                out.write(annotated_frame)
            except Exception as e:
                # Best effort: one bad frame should not abort the whole video;
                # write it un-annotated to keep the output timeline intact.
                print(f"[ERROR] Error processing frame {frame_count}: {e}")
                out.write(resized_frame)
    finally:
        cap.release()
        out.release()

    elapsed = time.time() - start
    print(f"[INFO] Video processing finished. Processed {frame_count} frames.")
    print(f"[INFO] Total processing time: {elapsed:.2f} seconds")
    if frame_count == 0:
        print("[WARNING] No frames could be read from the input video; the output video is empty.")

    final_output_name = _build_final_output_name(safe_base_name, detected_labels)
    final_output_path = os.path.join(output_dir, final_output_name)

    try:
        # os.replace overwrites an existing destination in a single step on all
        # platforms, unlike os.remove() followed by os.rename().
        os.replace(temp_output_path, final_output_path)
    except OSError as e:
        print(f"[ERROR] Failed to rename {temp_output_path} to {final_output_path}: {e}")
        if os.path.exists(temp_output_path):
            print(f"[WARNING] Returning path to temporary file: {temp_output_path}")
            return detected_labels, temp_output_path
        raise ValueError("Output video generation failed. No output file found.") from e

    print(f"[INFO] Detected object labels: {detected_labels}")
    print(f"[INFO] Annotated video saved successfully at: {final_output_path}")
    return detected_labels, final_output_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|