# fight-object_detection / Fight_detec_func.py
# Committed by: KillD00zer
# Last commit: f8bfcd7 "Update Fight_detec_func.py" (verified)
import tensorflow as tf
from frame_slicer import extract_video_frames
import cv2
import os
import numpy as np
import matplotlib.pyplot as plt
# Configuration
import os
# Directory containing this script; the model file and the debug-output
# folder are both resolved relative to it rather than to the CWD.
_SCRIPT_DIR = os.path.dirname(__file__)

# Keras model weights loaded by fight_detec().
MODEL_PATH = os.path.join(_SCRIPT_DIR, "final_model_2.h5")

# Number of frames sampled per video, and the frame size passed to
# extract_video_frames (also used as the spatial shape of blank frames).
N_FRAMES = 30
IMG_SIZE = (96, 96)

# Folder where debug frames / visualizations are written.
RESULT_PATH = os.path.join(_SCRIPT_DIR, "results")
def _to_uint8_image(frame):
    """Convert a single frame to a ``uint8`` image for saving/plotting.

    Frames whose maximum is <= 1.0 are treated as normalized floats in
    [0, 1] and scaled by 255; anything else is assumed to already be on a
    0-255 scale. Values are clipped before the cast so out-of-range pixels
    saturate instead of wrapping around in ``uint8``.
    """
    frame = np.asarray(frame)
    if frame.size and np.max(frame) <= 1.0:
        frame = frame * 255.0
    return np.clip(frame, 0, 255).astype(np.uint8)


def _fit_frame_count(frames, n_frames, blank_shape):
    """Return ``frames`` adjusted to exactly ``n_frames`` entries along axis 0.

    - exactly ``n_frames``: returned unchanged;
    - no frames at all: ``n_frames`` all-zero float32 frames of ``blank_shape``;
    - too few: padded by repeating the last frame;
    - too many: ``n_frames`` frames picked at evenly spaced indices so the
      whole clip stays covered (the old padding code passed a negative
      count to ``np.repeat`` here and raised).
    """
    count = frames.shape[0]
    if count == n_frames:
        return frames
    if count == 0:
        return np.zeros((n_frames, *blank_shape), dtype=np.float32)
    if count < n_frames:
        padding = np.repeat(frames[-1:], n_frames - count, axis=0)
        return np.concatenate((frames, padding), axis=0)
    keep = np.linspace(0, count - 1, num=n_frames).round().astype(int)
    return frames[keep]


def fight_detec(video_path: str, debug: bool = True, threshold: float = 0.61):
    """Detect a fight in a video.

    Args:
        video_path: Path to the video file to classify.
        debug: When True, print diagnostics and write a sample frame plus a
            frame-grid visualization under ``RESULT_PATH``.
        threshold: Model score at or above which the clip is labelled
            "FIGHT". Defaults to 0.61, the previously hard-coded value.

    Returns:
        ``(result_string, raw_score)``. On success ``result_string`` looks
        like ``"FIGHT (73.5% confidence)"`` or ``"NORMAL (...)"`` and
        ``raw_score`` is the model output as a float. On failure
        ``result_string`` starts with ``"Error:"`` or ``"Prediction error:"``
        and ``raw_score`` is None.

    NOTE: the Keras model is loaded from ``MODEL_PATH`` on every call; cache
    it externally if you classify many videos in one process.
    """

    class FightDetector:
        """Per-call wrapper around the Keras model (closes over ``debug``/``threshold``)."""

        def __init__(self):
            self.model = self._load_model()

        def _load_model(self):
            """Load the model from MODEL_PATH, returning None on any failure."""
            if not os.path.exists(MODEL_PATH):
                print(f"Error: Model file not found at {MODEL_PATH}")
                return None
            try:
                # compile=False: optimizer state is not needed for inference.
                model = tf.keras.models.load_model(MODEL_PATH, compile=False)
                if debug:
                    print("\nModel loaded successfully. Input shape:", model.input_shape)
                return model
            except Exception as e:
                print(f"Model loading failed: {e}")
                return None

        def _extract_frames(self, video_path):
            """Return the extracted frames as an ndarray, or None if extraction failed."""
            frames = extract_video_frames(video_path, N_FRAMES, IMG_SIZE)
            if frames is None:
                print(f"Frame extraction returned None for {video_path}")
                return None
            # Accept list-like output too; no copy when it is already an ndarray.
            frames = np.asarray(frames)
            if debug:
                # A frame is "blank" when every value in it is zero.
                per_frame_axes = tuple(range(1, frames.ndim))
                blank_frames = np.all(frames == 0, axis=per_frame_axes).sum()
                if blank_frames > 0:
                    print(f"Warning: {blank_frames} blank frames detected")
                if frames.shape[0] > 0 and not np.all(frames[0] == 0):
                    sample_frame = _to_uint8_image(frames[0])
                    try:
                        os.makedirs(RESULT_PATH, exist_ok=True)
                        debug_frame_path = os.path.join(RESULT_PATH, 'debug_frame.jpg')
                        # NOTE(review): assumes frames are RGB ordered — confirm with frame_slicer.
                        cv2.imwrite(debug_frame_path, cv2.cvtColor(sample_frame, cv2.COLOR_RGB2BGR))
                        print(f"Debug frame saved to {debug_frame_path}")
                    except Exception as e:
                        print(f"Failed to save debug frame: {e}")
                else:
                    print("Skipping debug frame save (first frame blank or no frames).")
            return frames

        def predict(self, video_path):
            """Classify one video; returns ``(result_string, raw_score_or_None)``."""
            if not os.path.exists(video_path):
                print(f"Error: Video not found at {video_path}")
                return "Error: Video not found", None
            try:
                frames = self._extract_frames(video_path)
                if frames is None:
                    return "Error: Frame extraction failed", None

                frame_count = frames.shape[0]
                if frame_count != N_FRAMES:
                    action = "Padding" if frame_count < N_FRAMES else "Subsampling"
                    print(f"Warning: Expected {N_FRAMES} frames, got {frame_count}. {action}...")
                    frames = _fit_frame_count(frames, N_FRAMES, (*IMG_SIZE, 3))
                    print(f"Frames adjusted to shape: {frames.shape}")

                # Zero-frame input becomes all-zero padding above; refuse to classify it.
                if np.all(frames == 0):
                    print("Error: All frames are blank after processing/padding.")
                    return "Error: All frames are blank", None

                # Add a leading batch axis of size 1; the model's first output is the score.
                # NOTE(review): assumes model input is (1, N_FRAMES, H, W, 3) — check model.input_shape.
                prediction = float(self.model.predict(frames[np.newaxis, ...], verbose=0)[0][0])

                is_fight = prediction >= threshold
                result = "FIGHT" if is_fight else "NORMAL"
                # Heuristic confidence: 50% at the threshold, +1.5 points per 0.01 of
                # distance from it, capped at 100%.
                confidence = min(abs(prediction - threshold) * 150 + 50, 100)
                result_string = f"{result} ({confidence:.1f}% confidence)"

                if debug:
                    print(f"Raw Prediction Score: {prediction:.4f}")
                    self._debug_visualization(frames, prediction, result_string, video_path)
                return result_string, prediction
            except Exception as e:
                print(f"Prediction error: {str(e)}")
                return f"Prediction error: {str(e)}", None

        def _debug_visualization(self, frames, score, result, video_path):
            """Print the decision and save a grid of up to 10 frames under RESULT_PATH.

            Best-effort: failures are reported and swallowed so a debug-only
            problem never changes the prediction result.
            """
            print("\n--- Debug Visualization ---")
            print(f"Prediction Score: {score:.4f}")
            print(f"Decision: {result}")
            fig = None
            try:
                import matplotlib.pyplot as plt
                fig = plt.figure(figsize=(15, 5))
                for i in range(min(10, len(frames))):
                    plt.subplot(2, 5, i + 1)
                    plt.imshow(_to_uint8_image(frames[i]))
                    # Mean is computed on the original (unconverted) frame values.
                    plt.title(f"Frame {i}\nMean: {frames[i].mean():.2f}")
                    plt.axis('off')
                plt.suptitle(f"Video: {os.path.basename(video_path)}\nPrediction: {result} (Raw Score: {score:.4f})")
                plt.tight_layout(rect=[0, 0.03, 1, 0.95])
                os.makedirs(RESULT_PATH, exist_ok=True)
                base_name = os.path.splitext(os.path.basename(video_path))[0]
                save_path = os.path.join(RESULT_PATH, f"{base_name}_prediction_result.png")
                fig.savefig(save_path)
                print(f"Debug visualization saved to: {save_path}")
            except ImportError:
                print("Matplotlib not found. Skipping debug visualization plot.")
            except Exception as e:
                print(f"Error during debug visualization: {e}")
            finally:
                # Always release the figure, even when plotting/saving failed.
                if fig is not None:
                    plt.close(fig)
            print("--- End Debug Visualization ---")

    detector = FightDetector()
    if detector.model is None:
        return "Error: Model loading failed", None
    return detector.predict(video_path)
# # Example usage (commented out for library use)
# if __name__ == "__main__":
# # Example of how to call the function
# test_video = input("Enter the local path to the video file: ").strip('"')
# if os.path.exists(test_video):
# print(f"[INFO] Processing video: {test_video}")
# result, score = fight_detec(test_video, debug=True) # Enable debug for local testing
# print(f"\nFinal Result: {result}")
# if score is not None:
# print(f"Raw Score: {score:.4f}")
# else:
# print(f"Error: File not found - {test_video}")