Update Fight_detec_func.py
Browse files- Fight_detec_func.py +97 -88
Fight_detec_func.py
CHANGED
@@ -1,103 +1,112 @@
|
|
1 |
import tensorflow as tf
|
2 |
-
from frame_slicer import extract_video_frames
|
3 |
-
import cv2
|
4 |
import os
|
5 |
import numpy as np
|
|
|
|
|
|
|
6 |
import matplotlib.pyplot as plt
|
7 |
|
8 |
-
#
|
9 |
-
|
10 |
-
|
|
|
11 |
N_FRAMES = 30
|
12 |
IMG_SIZE = (96, 96)
|
13 |
-
RESULT_PATH = os.path.join(os.path.dirname(__file__), "results")
|
|
|
14 |
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
if frames is None:
|
35 |
-
return None
|
36 |
-
|
37 |
-
if
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
def predict(self, video_path):
|
49 |
-
if not os.path.exists(video_path):
|
50 |
-
return "Error: Video not found", None
|
51 |
-
|
52 |
-
try:
|
53 |
-
frames = self._extract_frames(video_path)
|
54 |
-
if frames is None:
|
55 |
-
return "Error: Frame extraction failed", None
|
56 |
-
|
57 |
-
if frames.shape[0] != N_FRAMES:
|
58 |
-
return f"Error: Expected {N_FRAMES} frames, got {frames.shape[0]}", None
|
59 |
-
|
60 |
-
if np.all(frames == 0):
|
61 |
-
return "Error: All frames are blank", None
|
62 |
-
|
63 |
-
prediction = self.model.predict(frames[np.newaxis, ...], verbose=0)[0][0]
|
64 |
-
result = "FIGHT" if prediction >= 0.61 else "NORMAL"
|
65 |
-
confidence = min(max(abs(prediction - 0.61) * 150 + 50, 0), 100)
|
66 |
-
|
67 |
-
if debug:
|
68 |
-
self._debug_visualization(frames, prediction, result, video_path)
|
69 |
-
|
70 |
-
return f"{result} ({confidence:.1f}% confidence)", prediction
|
71 |
-
|
72 |
-
except Exception as e:
|
73 |
-
return f"Prediction error: {str(e)}", None
|
74 |
-
|
75 |
-
def _debug_visualization(self, frames, score, result, video_path):
|
76 |
-
print(f"\nPrediction Score: {score:.4f}")
|
77 |
-
print(f"Decision: {result}")
|
78 |
-
plt.figure(figsize=(15, 5))
|
79 |
-
for i in range(min(10, len(frames))):
|
80 |
-
plt.subplot(2, 5, i+1)
|
81 |
-
plt.imshow(frames[i])
|
82 |
-
plt.title(f"Frame {i}\nMean: {frames[i].mean():.2f}")
|
83 |
-
plt.axis('off')
|
84 |
-
plt.suptitle(f"Prediction: {result} (Score: {score:.4f})")
|
85 |
-
plt.tight_layout()
|
86 |
-
|
87 |
-
# Save the visualization
|
88 |
-
base_name = os.path.splitext(os.path.basename(video_path))[0]
|
89 |
-
save_path = os.path.join(RESULT_PATH, f"{base_name}_prediction_result.png")
|
90 |
-
plt.savefig(save_path)
|
91 |
-
plt.close()
|
92 |
-
print(f"Visualization saved to: {save_path}")
|
93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
detector = FightDetector()
|
95 |
if detector.model is None:
|
96 |
return "Error: Model loading failed", None
|
97 |
return detector.predict(video_path)
|
98 |
|
99 |
-
#
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
|
|
|
|
|
1 |
import tensorflow as tf
|
|
|
|
|
2 |
import os
|
3 |
import numpy as np
|
4 |
+
import cv2
|
5 |
+
import logging
|
6 |
+
from datetime import datetime
|
7 |
import matplotlib.pyplot as plt
|
8 |
|
9 |
+
from frame_slicer import extract_video_frames # Make sure this module is in PYTHONPATH
|
10 |
+
|
11 |
+
# ----------------- Configuration ----------------- #
|
12 |
+
MODEL_PATH = os.path.join(os.path.dirname(__file__), "training_output", "final_model_2.h5")
|
13 |
N_FRAMES = 30
|
14 |
IMG_SIZE = (96, 96)
|
15 |
+
RESULT_PATH = os.path.join(os.path.dirname(__file__), "results")
|
16 |
+
FIGHT_THRESHOLD = 0.61
|
17 |
|
18 |
+
os.makedirs(RESULT_PATH, exist_ok=True)
|
19 |
+
|
20 |
+
# ----------------- Logging Setup ----------------- #
|
21 |
+
logging.basicConfig(level=logging.INFO, format='[%(levelname)s] %(message)s')
|
22 |
+
logger = logging.getLogger(__name__)
|
23 |
+
|
24 |
+
# ----------------- Main Detector Class ----------------- #
|
25 |
+
class FightDetector:
|
26 |
+
def __init__(self):
|
27 |
+
self.model = self._load_model()
|
28 |
+
|
29 |
+
def _load_model(self):
|
30 |
+
try:
|
31 |
+
model = tf.keras.models.load_model(MODEL_PATH, compile=False)
|
32 |
+
logger.info(f"Model loaded successfully. Input shape: {model.input_shape}")
|
33 |
+
return model
|
34 |
+
except Exception as e:
|
35 |
+
logger.error(f"Model loading failed: {e}")
|
36 |
+
return None
|
37 |
+
|
38 |
+
def _extract_frames(self, video_path):
|
39 |
+
frames = extract_video_frames(video_path, N_FRAMES, IMG_SIZE)
|
40 |
+
if frames is None:
|
41 |
+
return None
|
42 |
+
|
43 |
+
blank_frames = np.all(frames == 0, axis=(1, 2, 3)).sum()
|
44 |
+
if blank_frames > 0:
|
45 |
+
logger.warning(f"{blank_frames} blank frames detected.")
|
46 |
+
|
47 |
+
sample_frame = (frames[0] * 255).astype(np.uint8)
|
48 |
+
debug_frame_path = os.path.join(RESULT_PATH, 'debug_frame.jpg')
|
49 |
+
cv2.imwrite(debug_frame_path, cv2.cvtColor(sample_frame, cv2.COLOR_RGB2BGR))
|
50 |
+
logger.info(f"Debug frame saved to: {debug_frame_path}")
|
51 |
+
|
52 |
+
return frames
|
53 |
+
|
54 |
+
def predict(self, video_path):
|
55 |
+
if not os.path.exists(video_path):
|
56 |
+
return "Error: Video not found", None
|
57 |
+
|
58 |
+
try:
|
59 |
+
frames = self._extract_frames(video_path)
|
60 |
if frames is None:
|
61 |
+
return "Error: Frame extraction failed", None
|
62 |
+
|
63 |
+
if frames.shape[0] != N_FRAMES:
|
64 |
+
return f"Error: Expected {N_FRAMES} frames, got {frames.shape[0]}", None
|
65 |
+
|
66 |
+
if np.all(frames == 0):
|
67 |
+
return "Error: All frames are blank", None
|
68 |
+
|
69 |
+
prediction = self.model.predict(frames[np.newaxis, ...], verbose=0)[0][0]
|
70 |
+
result = "FIGHT" if prediction >= FIGHT_THRESHOLD else "NORMAL"
|
71 |
+
confidence = min(max(abs(prediction - FIGHT_THRESHOLD) * 150 + 50, 0), 100)
|
72 |
+
|
73 |
+
self._debug_visualization(frames, prediction, result, video_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
|
75 |
+
return f"{result} ({confidence:.1f}% confidence)", prediction
|
76 |
+
|
77 |
+
except Exception as e:
|
78 |
+
return f"Prediction error: {str(e)}", None
|
79 |
+
|
80 |
+
def _debug_visualization(self, frames, score, result, video_path):
|
81 |
+
logger.info(f"Prediction Score: {score:.4f}")
|
82 |
+
logger.info(f"Decision: {result}")
|
83 |
+
plt.figure(figsize=(15, 5))
|
84 |
+
for i in range(min(10, len(frames))):
|
85 |
+
plt.subplot(2, 5, i + 1)
|
86 |
+
plt.imshow(frames[i])
|
87 |
+
plt.title(f"Frame {i}\nMean: {frames[i].mean():.2f}")
|
88 |
+
plt.axis('off')
|
89 |
+
plt.suptitle(f"Prediction: {result} (Score: {score:.4f})")
|
90 |
+
plt.tight_layout()
|
91 |
+
|
92 |
+
base_name = os.path.splitext(os.path.basename(video_path))[0]
|
93 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
94 |
+
save_path = os.path.join(RESULT_PATH, f"{base_name}_result_{timestamp}.png")
|
95 |
+
plt.savefig(save_path)
|
96 |
+
plt.close()
|
97 |
+
logger.info(f"Visualization saved to: {save_path}")
|
98 |
+
|
99 |
+
# ----------------- External Interface ----------------- #
|
100 |
+
def fight_detec(video_path: str, debug: bool = True):
|
101 |
detector = FightDetector()
|
102 |
if detector.model is None:
|
103 |
return "Error: Model loading failed", None
|
104 |
return detector.predict(video_path)
|
105 |
|
106 |
+
# ----------------- Optional CLI Entry ----------------- #
|
107 |
+
if __name__ == "__main__":
|
108 |
+
path0 = input("Enter the local path to the video file to detect fight: ")
|
109 |
+
path = path0.strip('"') # Remove extra quotes if copied from Windows
|
110 |
+
logger.info(f"Loading video: {path}")
|
111 |
+
result, score = fight_detec(path)
|
112 |
+
logger.info(f"Result: {result}")
|