|
|
|
import tensorflow as tf |
|
from frame_slicer import extract_video_frames |
|
import cv2 |
|
import os |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
import os |
|
# Number of frames sampled per video and the spatial frame size; both are
# passed to extract_video_frames and used to build zero padding frames.
N_FRAMES = 30
IMG_SIZE = (96, 96)

# The model file and the debug-output directory both sit next to this module.
MODEL_PATH, RESULT_PATH = (
    os.path.join(os.path.dirname(__file__), entry)
    for entry in ("final_model_2.h5", "results")
)
|
|
|
def fight_detec(video_path: str, debug: bool = True):
    """Detect a fight in a video.

    Args:
        video_path: Path to the video file to classify.
        debug: When True, print diagnostics and write debug images (a sample
            frame and a grid of up to 10 frames) under ``RESULT_PATH``.

    Returns:
        A ``(result_string, score)`` tuple.

        On success, ``result_string`` is ``"FIGHT (xx.x% confidence)"`` or
        ``"NORMAL (xx.x% confidence)"``, and ``score`` is the raw model output
        as a float.

        On failure, ``result_string`` starts with ``"Error:"`` or
        ``"Prediction error:"``, and ``score`` is ``None``.
    """

    class FightDetector:
        """Holds the loaded Keras model and classifies a single video."""

        # Decision threshold on the raw model score: scores >= THRESHOLD are
        # reported as FIGHT. The value's origin is not documented here
        # (presumably tuned on validation data — confirm).
        THRESHOLD = 0.61

        def __init__(self, debug=False):
            self.debug = debug
            self.model = self._load_model()

        def _load_model(self):
            """Return the model stored at MODEL_PATH, or None if it cannot be loaded."""
            if not os.path.exists(MODEL_PATH):
                print(f"Error: Model file not found at {MODEL_PATH}")
                return None
            try:
                # The model is only used for predict(), so skip compiling it.
                model = tf.keras.models.load_model(MODEL_PATH, compile=False)
                if self.debug:
                    print("\nModel loaded successfully. Input shape:", model.input_shape)
                return model
            except Exception as e:
                print(f"Model loading failed: {e}")
                return None

        @staticmethod
        def _to_uint8(frame):
            """Convert a frame to uint8 for saving or display.

            A frame whose maximum is <= 1.0 is treated as normalized to [0, 1]
            and scaled by 255. Values are clipped to [0, 255] before the cast,
            so out-of-range pixels saturate instead of wrapping around.
            """
            frame = np.asarray(frame)
            if np.max(frame) <= 1.0:
                frame = frame * 255
            return np.clip(frame, 0, 255).astype(np.uint8)

        def _extract_frames(self, video_path):
            """Extract frames via extract_video_frames; return None on failure.

            In debug mode, also report blank frames and save the first frame
            as ``debug_frame.jpg`` under RESULT_PATH.
            """
            frames = extract_video_frames(video_path, N_FRAMES, IMG_SIZE)
            if frames is None:
                print(f"Frame extraction returned None for {video_path}")
                return None
            frames = np.asarray(frames)

            if not self.debug:
                return frames

            # Count all-zero frames per leading-axis entry. Unlike an
            # axis=(1, 2, 3) reduction, this works for empty or non-4D output.
            blank_frames = sum(1 for frame in frames if not np.any(frame))
            if blank_frames > 0:
                print(f"Warning: {blank_frames} blank frames detected")

            if len(frames) > 0 and np.any(frames[0]):
                sample_frame = self._to_uint8(frames[0])
                try:
                    os.makedirs(RESULT_PATH, exist_ok=True)
                    debug_frame_path = os.path.join(RESULT_PATH, 'debug_frame.jpg')
                    # Frames are treated as RGB; cv2.imwrite expects BGR.
                    bgr_frame = cv2.cvtColor(sample_frame, cv2.COLOR_RGB2BGR)
                    # imwrite signals failure by returning False, not by raising.
                    if cv2.imwrite(debug_frame_path, bgr_frame):
                        print(f"Debug frame saved to {debug_frame_path}")
                    else:
                        print(f"Failed to save debug frame: cv2.imwrite returned False for {debug_frame_path}")
                except Exception as e:
                    print(f"Failed to save debug frame: {e}")
            else:
                print("Skipping debug frame save (first frame blank or no frames).")

            return frames

        @staticmethod
        def _fit_frame_count(frames):
            """Return ``frames`` with exactly N_FRAMES entries along axis 0.

            Extra frames are dropped from the end. Missing frames are filled by
            repeating the last frame, or with all-zero frames when there are
            none at all.
            """
            count = frames.shape[0]
            if count == N_FRAMES:
                return frames

            if count > N_FRAMES:
                print(f"Warning: Expected {N_FRAMES} frames, got {count}. Truncating...")
                frames = frames[:N_FRAMES]
                print(f"Frames truncated to shape: {frames.shape}")
                return frames

            print(f"Warning: Expected {N_FRAMES} frames, got {count}. Padding...")
            if count == 0:
                frames = np.zeros((N_FRAMES, *IMG_SIZE, 3), dtype=np.float32)
            else:
                padding = np.repeat(frames[-1:], N_FRAMES - count, axis=0)
                frames = np.concatenate((frames, padding), axis=0)
            print(f"Frames padded to shape: {frames.shape}")
            return frames

        def predict(self, video_path):
            """Classify one video and return ``(result_string, score_or_None)``."""
            if not os.path.exists(video_path):
                print(f"Error: Video not found at {video_path}")
                return "Error: Video not found", None

            try:
                frames = self._extract_frames(video_path)
                if frames is None:
                    return "Error: Frame extraction failed", None

                frames = self._fit_frame_count(frames)

                if not np.any(frames):
                    print("Error: All frames are blank after processing/padding.")
                    return "Error: All frames are blank", None

                # Add a batch axis of size 1. [0][0] takes the first output of
                # that single sample (presumably a (1, 1) score — confirm).
                prediction = self.model.predict(frames[np.newaxis, ...], verbose=0)[0][0]

                threshold = self.THRESHOLD
                is_fight = prediction >= threshold
                result = "FIGHT" if is_fight else "NORMAL"

                # Map the distance from the threshold to a displayed confidence.
                # It is 50% at the threshold, rises 1.5 points per 0.01 of
                # margin, and is capped at 100%.
                confidence = min(abs(prediction - threshold) * 150 + 50, 100)
                result_string = f"{result} ({confidence:.1f}% confidence)"

                if self.debug:
                    print(f"Raw Prediction Score: {prediction:.4f}")
                    self._debug_visualization(frames, prediction, result_string, video_path)

                return result_string, float(prediction)

            except Exception as e:
                # API boundary: failures are reported to the caller as a
                # string rather than raised.
                print(f"Prediction error: {str(e)}")
                return f"Prediction error: {str(e)}", None

        def _debug_visualization(self, frames, score, result, video_path):
            """Print the decision and save a grid of up to 10 frames as a PNG.

            The PNG is written to ``RESULT_PATH/<video name>_prediction_result.png``.
            Plotting failures are printed rather than raised.
            """
            print("\n--- Debug Visualization ---")
            print(f"Prediction Score: {score:.4f}")
            print(f"Decision: {result}")

            try:
                import matplotlib.pyplot as plt
            except ImportError:
                print("Matplotlib not found. Skipping debug visualization plot.")
            else:
                fig = None
                try:
                    fig = plt.figure(figsize=(15, 5))
                    for i in range(min(10, len(frames))):
                        plt.subplot(2, 5, i + 1)
                        plt.imshow(self._to_uint8(frames[i]))
                        plt.title(f"Frame {i}\nMean: {frames[i].mean():.2f}")
                        plt.axis('off')
                    plt.suptitle(f"Video: {os.path.basename(video_path)}\nPrediction: {result} (Raw Score: {score:.4f})")
                    plt.tight_layout(rect=[0, 0.03, 1, 0.95])

                    os.makedirs(RESULT_PATH, exist_ok=True)
                    base_name = os.path.splitext(os.path.basename(video_path))[0]
                    save_path = os.path.join(RESULT_PATH, f"{base_name}_prediction_result.png")
                    fig.savefig(save_path)
                    print(f"Debug visualization saved to: {save_path}")
                except Exception as e:
                    print(f"Error during debug visualization: {e}")
                finally:
                    # Release the figure even if plotting or saving failed, so
                    # repeated calls do not accumulate open figures.
                    if fig is not None:
                        plt.close(fig)

            print("--- End Debug Visualization ---")

    # NOTE(review): the model is reloaded from disk on every call. Callers that
    # classify many videos may want to cache the loaded model.
    detector = FightDetector(debug=debug)
    if detector.model is None:
        return "Error: Model loading failed", None

    return detector.predict(video_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|