# fight-object_detection / Fight_detec_func.py
# Last commit: adc4b03 (verified) by KillD00zer — "Update Fight_detec_func.py"
# File size at capture: 4.3 kB
import tensorflow as tf
import os
import numpy as np
import cv2
import logging
from datetime import datetime
import matplotlib.pyplot as plt
from frame_slicer import extract_video_frames # Make sure this module is in PYTHONPATH
# ----------------- Configuration ----------------- #
# Anchor every path to the directory that contains this module.
_BASE_DIR = os.path.dirname(__file__)

# Keras model file that FightDetector loads (uncompiled).
MODEL_PATH = os.path.join(_BASE_DIR, "training_output", "final_model_2.h5")

# Clip geometry handed to extract_video_frames: number of frames per clip and
# the per-frame size argument (presumably a resize target — see frame_slicer).
N_FRAMES = 30
IMG_SIZE = (96, 96)

# Model scores at or above this value are reported as FIGHT, below as NORMAL.
FIGHT_THRESHOLD = 0.61

# Output directory for the debug frame and result plots; created at import time
# so later writes into it do not fail on a missing directory.
RESULT_PATH = os.path.join(_BASE_DIR, "results")
os.makedirs(RESULT_PATH, exist_ok=True)

# ----------------- Logging Setup ----------------- #
# Root logging at INFO level with "[LEVEL] message" lines.
logging.basicConfig(format='[%(levelname)s] %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)
# ----------------- Main Detector Class ----------------- #
class FightDetector:
    """Classify a video clip as FIGHT or NORMAL with a Keras model.

    The model is loaded once in ``__init__``. If loading fails, ``self.model``
    is ``None`` and :meth:`predict` returns an error tuple instead of raising.
    """

    def __init__(self):
        self.model = self._load_model()

    def _load_model(self):
        """Load the model at ``MODEL_PATH`` without compiling it.

        Returns:
            The loaded model, or ``None`` if loading raised (the error and its
            traceback are logged).
        """
        try:
            model = tf.keras.models.load_model(MODEL_PATH, compile=False)
            logger.info(f"Model loaded successfully. Input shape: {model.input_shape}")
            return model
        except Exception as e:
            # Deliberately best-effort: callers check ``self.model is None``.
            logger.exception(f"Model loading failed: {e}")
            return None

    def _extract_frames(self, video_path, debug=True):
        """Extract ``N_FRAMES`` frames from *video_path* via ``extract_video_frames``.

        Logs a warning when some frames are entirely zero. When *debug* is
        true, the first frame is also written to ``RESULT_PATH/debug_frame.jpg``.

        Returns:
            The frames as a NumPy array, or ``None`` if ``extract_video_frames``
            returned ``None``.
        """
        frames = extract_video_frames(video_path, N_FRAMES, IMG_SIZE)
        if frames is None:
            return None
        frames = np.asarray(frames)
        # NOTE(review): the blank-frame check assumes frames is 4-D
        # (n, h, w, c), and the debug dump assumes RGB values in [0, 1] —
        # confirm against frame_slicer.
        blank_frames = np.all(frames == 0, axis=(1, 2, 3)).sum()
        if blank_frames > 0:
            logger.warning(f"{blank_frames} blank frames detected.")
        # Guard the empty case so predict() can report the frame-count error.
        if debug and len(frames) > 0:
            self._save_debug_frame(frames[0])
        return frames

    def _save_debug_frame(self, frame):
        """Write one frame (assumed RGB in [0, 1]) to RESULT_PATH as a JPEG."""
        # Clip before the uint8 cast so out-of-range values saturate instead of
        # wrapping around.
        scaled = np.asarray(frame, dtype=np.float32) * 255
        sample_frame = np.clip(scaled, 0, 255).astype(np.uint8)
        debug_frame_path = os.path.join(RESULT_PATH, 'debug_frame.jpg')
        # OpenCV expects BGR channel order when writing images.
        if cv2.imwrite(debug_frame_path, cv2.cvtColor(sample_frame, cv2.COLOR_RGB2BGR)):
            logger.info(f"Debug frame saved to: {debug_frame_path}")
        else:
            logger.warning(f"Could not write debug frame to: {debug_frame_path}")

    def predict(self, video_path, debug=True):
        """Classify the video at *video_path*.

        Args:
            video_path: Path to a local video file.
            debug: When true (the default), save the debug frame and the
                frame-grid visualization under ``RESULT_PATH``.

        Returns:
            ``(message, score)``. On success, *message* is
            ``"FIGHT (xx.x% confidence)"`` or ``"NORMAL (xx.x% confidence)"``
            and *score* is the raw model output. On failure, *message* starts
            with ``"Error:"`` or ``"Prediction error:"`` and *score* is ``None``.
        """
        # ``not video_path`` also covers None, which os.path.exists rejects with TypeError.
        if not video_path or not os.path.exists(video_path):
            return "Error: Video not found", None
        if self.model is None:
            return "Error: Model loading failed", None
        try:
            frames = self._extract_frames(video_path, debug=debug)
            if frames is None:
                return "Error: Frame extraction failed", None
            if frames.shape[0] != N_FRAMES:
                return f"Error: Expected {N_FRAMES} frames, got {frames.shape[0]}", None
            if np.all(frames == 0):
                return "Error: All frames are blank", None
            # Add a batch axis of 1. [0][0] takes the first output of that
            # single item — presumably the fight probability (confirm).
            prediction = self.model.predict(frames[np.newaxis, ...], verbose=0)[0][0]
            result = "FIGHT" if prediction >= FIGHT_THRESHOLD else "NORMAL"
            # Heuristic display confidence: 50% at the threshold, plus 1.5
            # points per 0.01 of distance from it, capped at 100%.
            confidence = min(abs(prediction - FIGHT_THRESHOLD) * 150 + 50, 100)
            logger.info(f"Prediction Score: {prediction:.4f}")
            logger.info(f"Decision: {result}")
            if debug:
                try:
                    self._debug_visualization(frames, prediction, result, video_path)
                except Exception:
                    # A failed debug plot must not discard a valid prediction.
                    logger.exception("Debug visualization failed")
            return f"{result} ({confidence:.1f}% confidence)", prediction
        except Exception as e:
            logger.exception("Prediction failed")
            return f"Prediction error: {str(e)}", None

    def _debug_visualization(self, frames, score, result, video_path):
        """Save a grid of up to 10 frames, titled with the decision and score.

        The PNG goes to ``RESULT_PATH/<video name>_result_<YYYYmmdd_HHMMSS>.png``.
        """
        fig = plt.figure(figsize=(15, 5))
        try:
            for i in range(min(10, len(frames))):
                plt.subplot(2, 5, i + 1)
                plt.imshow(frames[i])
                plt.title(f"Frame {i}\nMean: {frames[i].mean():.2f}")
                plt.axis('off')
            plt.suptitle(f"Prediction: {result} (Score: {score:.4f})")
            plt.tight_layout()
            base_name = os.path.splitext(os.path.basename(video_path))[0]
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            save_path = os.path.join(RESULT_PATH, f"{base_name}_result_{timestamp}.png")
            fig.savefig(save_path)
        finally:
            # Always release the figure so repeated calls don't pile up open figures.
            plt.close(fig)
        logger.info(f"Visualization saved to: {save_path}")


# ----------------- External Interface ----------------- #
def fight_detec(video_path: str, debug: bool = True):
    """Run fight detection on one local video file.

    Each call constructs a new FightDetector, so the model is loaded per call.

    Args:
        video_path: Path to the video to classify.
        debug: When true, save the debug frame and visualization under
            RESULT_PATH. When false, skip both.

    Returns:
        The ``(message, score)`` tuple from ``FightDetector.predict``, or
        ``("Error: Model loading failed", None)`` if the model could not load.
    """
    detector = FightDetector()
    if detector.model is None:
        return "Error: Model loading failed", None
    return detector.predict(video_path, debug=debug)
# ----------------- Optional CLI Entry ----------------- #
if __name__ == "__main__":
    # Interactive entry point: prompt for a path, run detection, log the outcome.
    raw_path = input("Enter the local path to the video file to detect fight: ")
    # Trim surrounding whitespace, then any surrounding quotes (e.g. added when a
    # path is copied from Windows or pasted from a shell).
    path = raw_path.strip().strip("\"'")
    logger.info(f"Loading video: {path}")
    # The raw score is already logged by FightDetector.predict.
    result, _score = fight_detec(path)
    logger.info(f"Result: {result}")