File size: 4,192 Bytes
da07a7d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import tensorflow as tf
from frame_slicer import extract_video_frames
import cv2
import os
import numpy as np
import matplotlib.pyplot as plt
# Configuration
import os
MODEL_PATH = os.path.join(os.path.dirname(__file__), "trainnig_output", "final_model_2.h5")
N_FRAMES = 30
IMG_SIZE = (96, 96)
RESULT_PATH = os.path.join(os.path.dirname(__file__), "results") # Will be created if doesn't exist
# Raw-score cut-off: model outputs >= this value are labelled FIGHT.
FIGHT_THRESHOLD = 0.61

# Successfully loaded models keyed by file path. Loading a Keras model is
# expensive, so repeated fight_detec() calls reuse the cached instance instead
# of re-reading the file each time. Failed loads are not cached (retried later).
_MODEL_CACHE = {}


def fight_detec(video_path: str, debug: bool = True, threshold: float = FIGHT_THRESHOLD):
    """Detect a fight in a video and return the verdict with a confidence score.

    Args:
        video_path: Path to the video file to analyse.
        debug: If True, print diagnostics and write a sample frame
            (``debug_frame.jpg``) plus a ``<video>_prediction_result.png``
            plot of up to the first 10 frames into ``RESULT_PATH``.
        threshold: Raw-score cut-off; scores >= ``threshold`` are labelled
            ``FIGHT``, lower scores ``NORMAL``. Defaults to ``FIGHT_THRESHOLD``.

    Returns:
        A ``(message, score)`` tuple. On success ``message`` is
        ``"FIGHT (xx.x% confidence)"`` or ``"NORMAL (xx.x% confidence)"`` and
        ``score`` is the model's raw output; on any failure ``message``
        describes the error and ``score`` is ``None``.
    """

    class FightDetector:
        """Bundles model loading, frame extraction, prediction and debug output."""

        def __init__(self):
            self.model = self._load_model()

        def _load_model(self):
            """Return the (cached) Keras model, or None if it cannot be loaded."""
            try:
                model = _MODEL_CACHE.get(MODEL_PATH)
                if model is None:
                    model = tf.keras.models.load_model(MODEL_PATH, compile=False)
                    _MODEL_CACHE[MODEL_PATH] = model
                if debug:
                    print("\nModel loaded successfully. Input shape:", model.input_shape)
                return model
            except Exception as e:
                # Boundary: report and let fight_detec() return an error tuple.
                print(f"Model loading failed: {e}")
                return None

        def _extract_frames(self, video_path):
            """Fetch N_FRAMES frames (at IMG_SIZE); in debug mode also save the first one."""
            frames = extract_video_frames(video_path, N_FRAMES, IMG_SIZE)
            if frames is None:
                return None
            if debug:
                # Reducing over axes (1, 2, 3) implies a 4-D array, presumably
                # (frame, height, width, channel) -- TODO confirm.
                blank_frames = np.all(frames == 0, axis=(1, 2, 3)).sum()
                if blank_frames > 0:
                    print(f"Warning: {blank_frames} blank frames detected")
                # Assumes float RGB frames in [0, 1] (hence *255 and the RGB->BGR
                # swap for OpenCV) -- TODO confirm against extract_video_frames.
                sample_frame = (frames[0] * 255).astype(np.uint8)
                os.makedirs(RESULT_PATH, exist_ok=True)
                cv2.imwrite(os.path.join(RESULT_PATH, 'debug_frame.jpg'),
                            cv2.cvtColor(sample_frame, cv2.COLOR_RGB2BGR))
            return frames

        def predict(self, video_path):
            """Return ``(message, score)`` for ``video_path``; score is None on error."""
            if not os.path.exists(video_path):
                return "Error: Video not found", None
            try:
                frames = self._extract_frames(video_path)
                if frames is None:
                    return "Error: Frame extraction failed", None
                if frames.shape[0] != N_FRAMES:
                    return f"Error: Expected {N_FRAMES} frames, got {frames.shape[0]}", None
                if np.all(frames == 0):
                    return "Error: All frames are blank", None
                # Add a batch axis; [0][0] takes the single score from the model
                # output, presumably shaped (1, 1) -- TODO confirm the model head.
                prediction = self.model.predict(frames[np.newaxis, ...], verbose=0)[0][0]
                result = "FIGHT" if prediction >= threshold else "NORMAL"
                # Distance from the threshold mapped onto a 50-100% confidence figure.
                confidence = min(max(abs(prediction - threshold) * 150 + 50, 0), 100)
                message = f"{result} ({confidence:.1f}% confidence)"
            except Exception as e:
                return f"Prediction error: {e}", None

            if debug:
                try:
                    self._debug_visualization(frames, prediction, result, video_path)
                except Exception as e:
                    # Debug output is best-effort: a plotting/saving failure must
                    # not discard an otherwise valid prediction.
                    print(f"Debug visualization failed: {e}")
            return message, prediction

        def _debug_visualization(self, frames, score, result, video_path):
            """Print the score and save a 2x5 grid of the first (up to 10) frames."""
            print(f"\nPrediction Score: {score:.4f}")
            print(f"Decision: {result}")
            fig = plt.figure(figsize=(15, 5))
            try:
                for idx, frame in enumerate(frames[:10]):
                    plt.subplot(2, 5, idx + 1)
                    plt.imshow(frame)
                    plt.title(f"Frame {idx}\nMean: {frame.mean():.2f}")
                    plt.axis('off')
                plt.suptitle(f"Prediction: {result} (Score: {score:.4f})")
                plt.tight_layout()
                base_name = os.path.splitext(os.path.basename(video_path))[0]
                os.makedirs(RESULT_PATH, exist_ok=True)
                save_path = os.path.join(RESULT_PATH, f"{base_name}_prediction_result.png")
                fig.savefig(save_path)
            finally:
                # Release the figure even if drawing or saving fails.
                plt.close(fig)
            print(f"Visualization saved to: {save_path}")

    detector = FightDetector()
    if detector.model is None:
        return "Error: Model loading failed", None
    return detector.predict(video_path)
# # Entry point
# path0 = input("Enter the local path to the video file to detect fight: ")
# path = path0.strip('"') # Remove extra quotes if copied from Windows
# print(f"[INFO] Loading video: {path}")
# fight_detec(path)
|