File size: 8,722 Bytes
e06363f da07a7d e06363f da07a7d e06363f f8bfcd7 da07a7d e06363f adc4b03 e06363f da07a7d e06363f da07a7d e06363f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 |
import tensorflow as tf
from frame_slicer import extract_video_frames
import cv2
import os
import numpy as np
import matplotlib.pyplot as plt
# Configuration
import os
# Frame-sampling settings handed to extract_video_frames: how many frames to
# take per video, and the target frame size.
N_FRAMES = 30
IMG_SIZE = (96, 96)

# Model weights and debug outputs are located relative to this script's
# directory rather than the current working directory.
MODEL_PATH = os.path.join(os.path.dirname(__file__), "final_model_2.h5")
RESULT_PATH = os.path.join(os.path.dirname(__file__), "results")
# Default decision threshold on the model's raw output score.
_DEFAULT_THRESHOLD = 0.61
# Confidence mapping: base + scale * |score - threshold|, clamped to [0, 100].
_CONFIDENCE_BASE = 50.0
_CONFIDENCE_SCALE = 150.0
# Successfully loaded Keras models keyed by path, so repeated calls do not
# reload the weights from disk every time.
_MODEL_CACHE = {}


def _load_model(debug):
    """Return the Keras model at MODEL_PATH (cached after first success), or None on failure."""
    cached = _MODEL_CACHE.get(MODEL_PATH)
    if cached is not None:
        return cached
    if not os.path.exists(MODEL_PATH):
        print(f"Error: Model file not found at {MODEL_PATH}")
        return None
    try:
        # compile=False: optimizer/loss state is not needed for inference.
        model = tf.keras.models.load_model(MODEL_PATH, compile=False)
        if debug:
            print("\nModel loaded successfully. Input shape:", model.input_shape)
    except Exception as e:
        print(f"Model loading failed: {e}")
        return None
    # Only successful loads are cached, so a missing/broken file is retried next call.
    _MODEL_CACHE[MODEL_PATH] = model
    return model


def _to_uint8(frame):
    """Convert a frame to uint8 for display/saving, scaling up first if it looks normalized to [0, 1]."""
    if np.max(frame) <= 1.0:
        frame = frame * 255
    # Clip so slightly out-of-range values don't wrap around when cast.
    return np.clip(frame, 0, 255).astype(np.uint8)


def _save_debug_frame(frames):
    """Write the first frame to RESULT_PATH/debug_frame.jpg unless there are no frames or it is blank."""
    if frames.shape[0] == 0 or np.all(frames[0] == 0):
        print("Skipping debug frame save (first frame blank or no frames).")
        return
    sample_frame = _to_uint8(frames[0])
    try:
        os.makedirs(RESULT_PATH, exist_ok=True)
        debug_frame_path = os.path.join(RESULT_PATH, "debug_frame.jpg")
        # Frames are assumed to be RGB; OpenCV expects BGR when writing.
        written = cv2.imwrite(debug_frame_path, cv2.cvtColor(sample_frame, cv2.COLOR_RGB2BGR))
    except Exception as e:
        print(f"Failed to save debug frame: {e}")
        return
    # cv2.imwrite reports many failures through its return value instead of raising.
    if written:
        print(f"Debug frame saved to {debug_frame_path}")
    else:
        print(f"Failed to save debug frame: cv2.imwrite returned False for {debug_frame_path}")


def _extract_frames(video_path, debug):
    """Call extract_video_frames(video_path, N_FRAMES, IMG_SIZE); return its frames or None on failure."""
    frames = extract_video_frames(video_path, N_FRAMES, IMG_SIZE)
    if frames is None:
        print(f"Frame extraction returned None for {video_path}")
        return None
    if debug:
        # Frames are indexed as (frame, H, W, channel); a blank frame is all zeros.
        blank_frames = np.all(frames == 0, axis=(1, 2, 3)).sum()
        if blank_frames > 0:
            print(f"Warning: {blank_frames} blank frames detected")
        _save_debug_frame(frames)
    return frames


def _fit_frame_count(frames, n_frames, img_size):
    """Return exactly n_frames frames: trim extras, or pad by repeating the last frame (zeros if none)."""
    count = frames.shape[0]
    if count == n_frames:
        return frames
    if count > n_frames:
        # Padding here would need a negative repeat count, which np.repeat rejects.
        print(f"Warning: Expected {n_frames} frames, got {count}. Trimming...")
        return frames[:n_frames]
    print(f"Warning: Expected {n_frames} frames, got {count}. Padding...")
    if count == 0:
        frames = np.zeros((n_frames, *img_size, 3), dtype=np.float32)
    else:
        padding = np.repeat(frames[-1:], n_frames - count, axis=0)
        frames = np.concatenate((frames, padding), axis=0)
    print(f"Frames padded to shape: {frames.shape}")
    return frames


def _confidence(prediction, threshold):
    """Map distance from the threshold to a 0-100 confidence; the decision boundary maps to 50."""
    distance = abs(prediction - threshold)
    return min(max(distance * _CONFIDENCE_SCALE + _CONFIDENCE_BASE, 0.0), 100.0)


def _debug_visualization(frames, score, result, video_path):
    """Print the decision and save a grid of up to 10 frames to RESULT_PATH/<video>_prediction_result.png."""
    print("\n--- Debug Visualization ---")
    print(f"Prediction Score: {score:.4f}")
    print(f"Decision: {result}")
    try:
        # Also imported at module level; the local import keeps this helper self-contained.
        import matplotlib.pyplot as plt

        fig = plt.figure(figsize=(15, 5))
        try:
            for i in range(min(10, len(frames))):
                plt.subplot(2, 5, i + 1)
                plt.imshow(_to_uint8(frames[i]))
                # Mean is computed on the original (unconverted) frame values.
                plt.title(f"Frame {i}\nMean: {frames[i].mean():.2f}")
                plt.axis('off')
            plt.suptitle(f"Video: {os.path.basename(video_path)}\nPrediction: {result} (Raw Score: {score:.4f})")
            plt.tight_layout(rect=[0, 0.03, 1, 0.95])
            os.makedirs(RESULT_PATH, exist_ok=True)
            base_name = os.path.splitext(os.path.basename(video_path))[0]
            save_path = os.path.join(RESULT_PATH, f"{base_name}_prediction_result.png")
            fig.savefig(save_path)
        finally:
            # Close even when drawing/saving fails so figures don't pile up across calls.
            plt.close(fig)
        print(f"Debug visualization saved to: {save_path}")
    except ImportError:
        print("Matplotlib not found. Skipping debug visualization plot.")
    except Exception as e:
        print(f"Error during debug visualization: {e}")
    print("--- End Debug Visualization ---")


def fight_detec(video_path: str, debug: bool = True, threshold: float = _DEFAULT_THRESHOLD):
    """Detect whether a video contains a fight.

    Args:
        video_path: Path to the video file.
        debug: If True, print diagnostics and write a sample frame and a
            frame-grid plot under RESULT_PATH.
        threshold: Raw-score cutoff; scores >= threshold are labelled FIGHT.

    Returns:
        A ``(result_string, raw_score)`` tuple. On success result_string looks
        like ``"FIGHT (72.3% confidence)"`` or ``"NORMAL (...)"`` and raw_score
        is a float. On failure result_string starts with ``"Error:"`` or
        ``"Prediction error:"`` and raw_score is None.
    """
    model = _load_model(debug)
    if model is None:
        return "Error: Model loading failed", None
    if not os.path.exists(video_path):
        print(f"Error: Video not found at {video_path}")
        return "Error: Video not found", None
    try:
        frames = _extract_frames(video_path, debug)
        if frames is None:
            return "Error: Frame extraction failed", None
        frames = _fit_frame_count(frames, N_FRAMES, IMG_SIZE)
        if np.all(frames == 0):
            # Zero padding (or an all-black clip) leaves nothing to classify.
            print("Error: All frames are blank after processing/padding.")
            return "Error: All frames are blank", None
        # Add a batch axis and take the first output of the single batch row.
        prediction = float(model.predict(frames[np.newaxis, ...], verbose=0)[0][0])
        result = "FIGHT" if prediction >= threshold else "NORMAL"
        result_string = f"{result} ({_confidence(prediction, threshold):.1f}% confidence)"
        if debug:
            print(f"Raw Prediction Score: {prediction:.4f}")
            _debug_visualization(frames, prediction, result_string, video_path)
        return result_string, prediction
    except Exception as e:
        # Top-level boundary: report the failure to the caller instead of raising.
        print(f"Prediction error: {e}")
        return f"Prediction error: {e}", None
# # Example usage (commented out for library use)
# if __name__ == "__main__":
# # Example of how to call the function
# test_video = input("Enter the local path to the video file: ").strip('"')
# if os.path.exists(test_video):
# print(f"[INFO] Processing video: {test_video}")
# result, score = fight_detec(test_video, debug=True) # Enable debug for local testing
# print(f"\nFinal Result: {result}")
# if score is not None:
# print(f"Raw Score: {score:.4f}")
# else:
# print(f"Error: File not found - {test_video}")
|