File size: 4,300 Bytes
da07a7d adc4b03 da07a7d adc4b03 da07a7d adc4b03 da07a7d adc4b03 da07a7d adc4b03 da07a7d adc4b03 da07a7d adc4b03 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
import tensorflow as tf
import os
import numpy as np
import cv2
import logging
from datetime import datetime
import matplotlib.pyplot as plt
from frame_slicer import extract_video_frames # Make sure this module is in PYTHONPATH
# ----------------- Configuration ----------------- #
# All artifact paths are resolved relative to the directory holding this file.
_HERE = os.path.dirname(__file__)

# Trained Keras model consumed by FightDetector (loaded with compile=False).
MODEL_PATH = os.path.join(_HERE, "training_output", "final_model_2.h5")
# Number of frames sampled per clip, and the frame size handed to
# extract_video_frames.
N_FRAMES = 30
IMG_SIZE = (96, 96)
# Destination for debug frames and prediction visualizations.
RESULT_PATH = os.path.join(_HERE, "results")
# Model scores at or above this value are classified as FIGHT.
FIGHT_THRESHOLD = 0.61

# Created eagerly at import so later debug writes have a target directory.
os.makedirs(RESULT_PATH, exist_ok=True)
# ----------------- Logging Setup ----------------- #
# Configure the root logger at import: INFO and above, printed as
# "[LEVEL] message". (basicConfig is a no-op if the root logger already has
# handlers.)
logging.basicConfig(
    format='[%(levelname)s] %(message)s',
    level=logging.INFO,
)
# Module-scoped logger used throughout this file.
logger = logging.getLogger(name=__name__)
# ----------------- Main Detector Class ----------------- #
class FightDetector:
    """Classify a video clip as FIGHT or NORMAL with a pre-trained Keras model.

    The model is loaded once at construction from ``MODEL_PATH``. If loading
    fails, ``self.model`` is ``None`` and :meth:`predict` returns an error
    tuple instead of raising.
    """

    def __init__(self):
        self.model = self._load_model()

    def _load_model(self):
        """Load the Keras model from MODEL_PATH.

        Returns:
            The loaded model, or ``None`` if loading (or reading its input
            shape for the log message) raised. The failure is logged with
            its traceback.
        """
        try:
            model = tf.keras.models.load_model(MODEL_PATH, compile=False)
            logger.info(f"Model loaded successfully. Input shape: {model.input_shape}")
            return model
        except Exception as e:
            # Boundary catch: callers test for None rather than handling the
            # many exception types model deserialization can raise.
            logger.exception(f"Model loading failed: {e}")
            return None

    def _extract_frames(self, video_path, debug=True):
        """Extract N_FRAMES frames of size IMG_SIZE from *video_path*.

        Blank (all-zero) frames are counted and reported as a warning. When
        *debug* is true and at least one frame exists, the first frame is
        written to ``RESULT_PATH/debug_frame.jpg`` (best-effort).

        Returns:
            The array returned by ``extract_video_frames``, or ``None`` if it
            returned ``None``.
        """
        frames = extract_video_frames(video_path, N_FRAMES, IMG_SIZE)
        if frames is None:
            return None

        # NOTE(review): the axis=(1, 2, 3) reduction assumes a 4-D
        # (frames, H, W, C) array — confirm against frame_slicer.
        blank_frames = np.all(frames == 0, axis=(1, 2, 3)).sum()
        if blank_frames > 0:
            logger.warning(f"{blank_frames} blank frames detected.")

        # An empty clip has no first frame; let predict() report the count
        # instead of failing here with an IndexError.
        if debug and len(frames) > 0:
            try:
                self._save_debug_frame(frames[0])
            except Exception:
                # Debug output is best-effort; never fail a prediction over it.
                logger.exception("Failed to save debug frame")
        return frames

    def _save_debug_frame(self, frame):
        """Write one frame to ``RESULT_PATH/debug_frame.jpg``."""
        # The * 255 scaling implies frames are floats in [0, 1]; clip so any
        # out-of-range value saturates instead of wrapping around in uint8.
        sample_frame = np.clip(frame * 255.0, 0, 255).astype(np.uint8)
        debug_frame_path = os.path.join(RESULT_PATH, 'debug_frame.jpg')
        # Frames are treated as RGB; OpenCV expects BGR when writing.
        bgr_frame = cv2.cvtColor(sample_frame, cv2.COLOR_RGB2BGR)
        # cv2.imwrite reports failure via its return value, not an exception.
        if cv2.imwrite(debug_frame_path, bgr_frame):
            logger.info(f"Debug frame saved to: {debug_frame_path}")
        else:
            logger.warning(f"Could not write debug frame to: {debug_frame_path}")

    def predict(self, video_path, debug=True):
        """Classify the clip at *video_path*.

        Args:
            video_path: Path to a local video file.
            debug: When true (the default), write the debug frame and the
                prediction visualization to RESULT_PATH. Both are
                best-effort and never change the returned result.

        Returns:
            ``(message, score)``. On success *message* is
            ``"FIGHT (xx.x% confidence)"`` or ``"NORMAL (xx.x% confidence)"``
            and *score* is the raw model output. On failure *message*
            describes the problem and *score* is ``None``; failures are
            reported this way rather than raised.
        """
        if not os.path.exists(video_path):
            return "Error: Video not found", None
        if self.model is None:
            # Without this guard the failure surfaced as an opaque
            # "'NoneType' object has no attribute 'predict'" message.
            return "Error: Model loading failed", None
        try:
            frames = self._extract_frames(video_path, debug=debug)
            if frames is None:
                return "Error: Frame extraction failed", None
            if frames.shape[0] != N_FRAMES:
                return f"Error: Expected {N_FRAMES} frames, got {frames.shape[0]}", None
            if np.all(frames == 0):
                return "Error: All frames are blank", None

            # Add a batch axis; the first output element is taken as the score.
            prediction = self.model.predict(frames[np.newaxis, ...], verbose=0)[0][0]
            # A NaN/inf score would otherwise silently fall through to NORMAL.
            if not np.isfinite(prediction):
                return "Error: Model returned a non-finite score", None

            result = "FIGHT" if prediction >= FIGHT_THRESHOLD else "NORMAL"
            # Heuristic: 50% at the threshold, +1.5 points per 0.01 of distance
            # from it, capped at 100% (never below 50% since abs() >= 0).
            confidence = min(abs(prediction - FIGHT_THRESHOLD) * 150 + 50, 100)

            if debug:
                try:
                    self._debug_visualization(frames, prediction, result, video_path)
                except Exception:
                    # Best-effort artifact; keep the successful prediction.
                    logger.exception("Failed to save prediction visualization")
            return f"{result} ({confidence:.1f}% confidence)", prediction
        except Exception as e:
            logger.exception("Prediction failed")
            return f"Prediction error: {str(e)}", None

    def _debug_visualization(self, frames, score, result, video_path):
        """Log the score and decision, and save a grid of up to 10 frames.

        The figure is written to RESULT_PATH as
        ``<video basename>_result_<YYYYmmdd_HHMMSS>.png``.
        """
        logger.info(f"Prediction Score: {score:.4f}")
        logger.info(f"Decision: {result}")
        fig = plt.figure(figsize=(15, 5))
        try:
            for i, frame in enumerate(frames[:10]):
                ax = fig.add_subplot(2, 5, i + 1)
                ax.imshow(frame)
                ax.set_title(f"Frame {i}\nMean: {frame.mean():.2f}")
                ax.axis('off')
            fig.suptitle(f"Prediction: {result} (Score: {score:.4f})")
            fig.tight_layout()
            base_name = os.path.splitext(os.path.basename(video_path))[0]
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            save_path = os.path.join(RESULT_PATH, f"{base_name}_result_{timestamp}.png")
            fig.savefig(save_path)
        finally:
            # Always release the figure so repeated calls don't accumulate
            # open figures when drawing or saving fails.
            plt.close(fig)
        logger.info(f"Visualization saved to: {save_path}")


# ----------------- External Interface ----------------- #
def fight_detec(video_path: str, debug: bool = True):
    """Run fight detection on a single video file.

    A new FightDetector is constructed, and so the model is loaded again,
    on every call.

    Args:
        video_path: Path to a local video file.
        debug: When true, write the debug frame and visualization images to
            RESULT_PATH; when false, skip them.

    Returns:
        ``(message, score)`` as returned by ``FightDetector.predict``, or
        ``("Error: Model loading failed", None)`` if the model could not be
        loaded.
    """
    detector = FightDetector()
    if detector.model is None:
        return "Error: Model loading failed", None
    return detector.predict(video_path, debug=debug)
# ----------------- Optional CLI Entry ----------------- #
def clean_input_path(raw: str) -> str:
    """Normalize a file path typed or pasted at the interactive prompt.

    Strips surrounding whitespace and any surrounding double or single
    quotes. Examples are the double quotes added by Windows "Copy as path"
    and the single quotes some terminals add on drag-and-drop.
    """
    return raw.strip().strip('"\'').strip()


if __name__ == "__main__":
    raw_path = input("Enter the local path to the video file to detect fight: ")
    path = clean_input_path(raw_path)
    logger.info(f"Loading video: {path}")
    result, score = fight_detec(path)
    # fight_detec signals failure with a None score; surface it at ERROR level.
    if score is None:
        logger.error(f"Result: {result}")
    else:
        logger.info(f"Result: {result}")
|