# fight-object_detection / objec_detect_yolo.py
# KillD00zer's picture
# Update objec_detect_yolo.py
# b2f3c2c verified
import cv2
import numpy as np
import os
from ultralytics import YOLO
import time
from typing import Tuple, Set, List
def _sanitize_for_filename(text: str) -> str:
    """Return *text* with every character except alphanumerics, '-' and '_' replaced by '_'."""
    return "".join(c if c.isalnum() or c in ("-", "_") else "_" for c in text)


def _class_label(class_names, cls_id: int):
    """Return the label for *cls_id*, or None when the model has no such class.

    Works whether *class_names* is a mapping keyed by class id or a sequence
    indexed by it, so non-contiguous ids in a mapping are handled correctly.
    """
    if cls_id < 0:
        # A negative index would silently wrap around on a sequence.
        return None
    try:
        return class_names[cls_id]
    except (KeyError, IndexError):
        return None


def _open_video_writer(output_path: str, fps: float, frame_size: Tuple[int, int]):
    """Open a cv2.VideoWriter for *output_path*, trying 'mp4v' first and then 'avc1'.

    Raises:
        ValueError: If neither codec yields an opened writer.
    """
    writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, frame_size)
    if writer.isOpened():
        return writer
    writer.release()  # Drop the failed writer before retrying on the same path

    print("[WARNING] mp4v codec failed, trying avc1...")
    writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"avc1"), fps, frame_size)
    if writer.isOpened():
        return writer
    writer.release()
    raise ValueError("Failed to initialize VideoWriter with mp4v or avc1 codec.")


def detection(path: str) -> Tuple[Set[str], str]:
    """
    Detect and track objects in a video with a YOLO model and save an annotated copy.

    The model weights are loaded from ``best.pt`` next to this script. Every frame is
    resized to 640x640, passed to ``model.track`` with a 0.7 confidence threshold, and
    written to an MP4 file inside a ``results`` directory next to this script. The file
    is first written under a temporary name and then renamed so that the final name
    contains the detected labels.

    Args:
        path (str): Path to the input video file.

    Returns:
        Tuple[Set[str], str]:
            - Set of unique detected object labels (empty if nothing was detected)
            - Path to the annotated output video. This is the temporary file path
              if renaming to the final name failed but the temporary file exists.

    Raises:
        FileNotFoundError: If the input video or the model file does not exist.
        ValueError: If the model cannot be loaded, the output directory cannot be
            created, the video cannot be opened, the video writer cannot be
            created, or no output file exists after a failed rename.
    """
    # Validate input file exists
    if not os.path.exists(path):
        raise FileNotFoundError(f"Video file not found: {path}")

    # --- Model Loading ---
    # Construct path relative to this script file
    model_path = os.path.join(os.path.dirname(__file__), "best.pt")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"YOLO model file not found at: {model_path}")
    try:
        model = YOLO(model_path)
        class_names = model.names  # Class id -> label lookup
        print(f"[INFO] YOLO model loaded from {model_path}. Class names: {class_names}")
    except Exception as e:
        raise ValueError(f"Failed to load YOLO model: {e}") from e

    # --- Output Path Setup ---
    base_name = os.path.splitext(os.path.basename(path))[0]
    # Sanitize basename to prevent issues with weird characters in filenames
    safe_base_name = _sanitize_for_filename(base_name)
    # Output directory lives next to this script
    output_dir = os.path.join(os.path.dirname(__file__), "results")
    try:
        os.makedirs(output_dir, exist_ok=True)
        if not os.path.isdir(output_dir):
            raise ValueError(f"Path exists but is not a directory: {output_dir}")
    except OSError as e:
        raise ValueError(f"Failed to create or access output directory '{output_dir}': {e}") from e
    temp_output_path = os.path.join(output_dir, f"{safe_base_name}_output_temp.mp4")
    print(f"[INFO] Temporary output will be saved to: {temp_output_path}")

    # --- Video Processing Setup ---
    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
        cap.release()
        raise ValueError(f"Failed to open video file: {path}")

    # Frames are processed and written at a fixed 640x640 resolution.
    frame_width, frame_height = 640, 640
    detected_labels: Set[str] = set()
    frame_count = 0
    out = None
    try:
        # Use source FPS if it is within 10..60, otherwise default to 30.
        # (A NaN or zero FPS fails the range check and also falls back to 30.)
        source_fps = cap.get(cv2.CAP_PROP_FPS)
        output_fps = source_fps if 10 <= source_fps <= 60 else 30.0

        try:
            out = _open_video_writer(temp_output_path, output_fps, (frame_width, frame_height))
        except Exception as e:
            raise ValueError(f"Failed to create VideoWriter: {e}") from e

        # --- Main Processing Loop ---
        start = time.time()
        print("[INFO] Video processing started...")
        while True:
            ret, frame = cap.read()
            if not ret:  # End of video or read error
                break
            frame_count += 1
            # Resize frame BEFORE passing to model
            resized_frame = cv2.resize(frame, (frame_width, frame_height))
            try:
                results = model.track(
                    source=resized_frame,
                    conf=0.7,       # Confidence threshold
                    persist=True,   # Maintain track IDs across frames
                    verbose=False,  # Suppress per-frame console output
                )
                if results and results[0] and results[0].boxes:
                    # Frame with bounding boxes and track IDs drawn on it
                    annotated_frame = results[0].plot()
                    for box in results[0].boxes:
                        if box.cls is None:
                            continue
                        cls_id = int(box.cls[0])
                        label = _class_label(class_names, cls_id)
                        if label is None:
                            print(f"[WARNING] Detected unknown class ID: {cls_id}")
                        else:
                            detected_labels.add(label)
                else:
                    # No detections: write the plain resized frame
                    annotated_frame = resized_frame
                out.write(annotated_frame)
            except Exception as e:
                # Best effort: a single bad frame must not abort the whole video.
                print(f"[ERROR] Error processing frame {frame_count}: {e}")
                # Write the unannotated frame to keep video timing consistent
                out.write(resized_frame)

        end = time.time()
        print(f"[INFO] Video processing finished. Processed {frame_count} frames.")
        print(f"[INFO] Total processing time: {end - start:.2f} seconds")
    finally:
        # --- Clean Up --- runs on success and on any error so handles never leak
        cap.release()
        if out is not None:
            out.release()
        try:
            cv2.destroyAllWindows()
        except cv2.error:
            # OpenCV builds without HighGUI support (e.g. headless packages)
            # raise here; no window is opened by this function, so ignore it.
            pass

    # --- Final Output Renaming ---
    # Labels are sanitized like the base name so that a label containing a path
    # separator or other special character cannot break the output path.
    labels_str = "_".join(_sanitize_for_filename(label) for label in sorted(detected_labels))
    # Limit length to avoid overly long filenames
    max_label_len = 50
    if len(labels_str) > max_label_len:
        labels_str = labels_str[:max_label_len] + "_etc"
    if not labels_str:  # Nothing was detected
        labels_str = "no_detections"
    final_output_path = os.path.join(output_dir, f"{safe_base_name}_{labels_str}_output.mp4")

    try:
        # os.replace overwrites an existing destination, so no separate delete
        # step (which could itself fail outside this handler) is needed.
        os.replace(temp_output_path, final_output_path)
    except OSError as e:
        print(f"[ERROR] Failed to rename {temp_output_path} to {final_output_path}: {e}")
        # Fallback: return the temp path if rename fails but file exists
        if os.path.exists(temp_output_path):
            print(f"[WARNING] Returning path to temporary file: {temp_output_path}")
            return detected_labels, temp_output_path
        raise ValueError("Output video generation failed. No output file found.") from e

    print(f"[INFO] Detected object labels: {detected_labels}")
    print(f"[INFO] Annotated video saved successfully at: {final_output_path}")
    return detected_labels, final_output_path
# Example usage (commented out for library use)
# if __name__ == "__main__":
#     test_video = input("Enter the local path to the video file: ").strip('"')
#     if os.path.exists(test_video):
#         try:
#             print(f"[INFO] Processing video: {test_video}")
#             labels, out_path = detection(test_video)
#             print(f"\nDetection Complete.")
#             print(f"Detected unique labels: {labels}")
#             print(f"Output video saved to: {out_path}")
#         except (FileNotFoundError, ValueError, Exception) as e:
#             print(f"\nAn error occurred: {e}")
#     else:
#         print(f"Error: Input video file not found - {test_video}")