KillD00zer commited on
Commit
adc4b03
·
verified ·
1 Parent(s): 61bcdf4

Update Fight_detec_func.py

Browse files
Files changed (1) hide show
  1. Fight_detec_func.py +97 -88
Fight_detec_func.py CHANGED
@@ -1,103 +1,112 @@
1
  import tensorflow as tf
2
- from frame_slicer import extract_video_frames
3
- import cv2
4
  import os
5
  import numpy as np
 
 
 
6
  import matplotlib.pyplot as plt
7
 
8
- # Configuration
9
- import os
10
- MODEL_PATH = os.path.join(os.path.dirname(__file__), "trainnig_output", "final_model_2.h5")
 
11
  N_FRAMES = 30
12
  IMG_SIZE = (96, 96)
13
- RESULT_PATH = os.path.join(os.path.dirname(__file__), "results") # Will be created if doesn't exist
 
14
 
15
- def fight_detec(video_path: str, debug: bool = True):
16
- """Detects fight in a video and returns the result and confidence score."""
17
-
18
- class FightDetector:
19
- def __init__(self):
20
- self.model = self._load_model()
21
-
22
- def _load_model(self):
23
- try:
24
- model = tf.keras.models.load_model(MODEL_PATH, compile=False)
25
- if debug:
26
- print("\nModel loaded successfully. Input shape:", model.input_shape)
27
- return model
28
- except Exception as e:
29
- print(f"Model loading failed: {e}")
30
- return None
31
-
32
- def _extract_frames(self, video_path):
33
- frames = extract_video_frames(video_path, N_FRAMES, IMG_SIZE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  if frames is None:
35
- return None
36
-
37
- if debug:
38
- blank_frames = np.all(frames == 0, axis=(1, 2, 3)).sum()
39
- if blank_frames > 0:
40
- print(f"Warning: {blank_frames} blank frames detected")
41
- sample_frame = (frames[0] * 255).astype(np.uint8)
42
- os.makedirs(RESULT_PATH, exist_ok=True)
43
- cv2.imwrite(os.path.join(RESULT_PATH, 'debug_frame.jpg'),
44
- cv2.cvtColor(sample_frame, cv2.COLOR_RGB2BGR))
45
-
46
- return frames
47
-
48
- def predict(self, video_path):
49
- if not os.path.exists(video_path):
50
- return "Error: Video not found", None
51
-
52
- try:
53
- frames = self._extract_frames(video_path)
54
- if frames is None:
55
- return "Error: Frame extraction failed", None
56
-
57
- if frames.shape[0] != N_FRAMES:
58
- return f"Error: Expected {N_FRAMES} frames, got {frames.shape[0]}", None
59
-
60
- if np.all(frames == 0):
61
- return "Error: All frames are blank", None
62
-
63
- prediction = self.model.predict(frames[np.newaxis, ...], verbose=0)[0][0]
64
- result = "FIGHT" if prediction >= 0.61 else "NORMAL"
65
- confidence = min(max(abs(prediction - 0.61) * 150 + 50, 0), 100)
66
-
67
- if debug:
68
- self._debug_visualization(frames, prediction, result, video_path)
69
-
70
- return f"{result} ({confidence:.1f}% confidence)", prediction
71
-
72
- except Exception as e:
73
- return f"Prediction error: {str(e)}", None
74
-
75
- def _debug_visualization(self, frames, score, result, video_path):
76
- print(f"\nPrediction Score: {score:.4f}")
77
- print(f"Decision: {result}")
78
- plt.figure(figsize=(15, 5))
79
- for i in range(min(10, len(frames))):
80
- plt.subplot(2, 5, i+1)
81
- plt.imshow(frames[i])
82
- plt.title(f"Frame {i}\nMean: {frames[i].mean():.2f}")
83
- plt.axis('off')
84
- plt.suptitle(f"Prediction: {result} (Score: {score:.4f})")
85
- plt.tight_layout()
86
-
87
- # Save the visualization
88
- base_name = os.path.splitext(os.path.basename(video_path))[0]
89
- save_path = os.path.join(RESULT_PATH, f"{base_name}_prediction_result.png")
90
- plt.savefig(save_path)
91
- plt.close()
92
- print(f"Visualization saved to: {save_path}")
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  detector = FightDetector()
95
  if detector.model is None:
96
  return "Error: Model loading failed", None
97
  return detector.predict(video_path)
98
 
99
- # # Entry point
100
- # path0 = input("Enter the local path to the video file to detect fight: ")
101
- # path = path0.strip('"') # Remove extra quotes if copied from Windows
102
- # print(f"[INFO] Loading video: {path}")
103
- # fight_detec(path)
 
 
 
1
  import tensorflow as tf
 
 
2
  import os
3
  import numpy as np
4
+ import cv2
5
+ import logging
6
+ from datetime import datetime
7
  import matplotlib.pyplot as plt
8
 
9
+ from frame_slicer import extract_video_frames # Make sure this module is in PYTHONPATH
10
+
11
+ # ----------------- Configuration ----------------- #
12
+ MODEL_PATH = os.path.join(os.path.dirname(__file__), "training_output", "final_model_2.h5")
13
  N_FRAMES = 30
14
  IMG_SIZE = (96, 96)
15
+ RESULT_PATH = os.path.join(os.path.dirname(__file__), "results")
16
+ FIGHT_THRESHOLD = 0.61
17
 
18
+ os.makedirs(RESULT_PATH, exist_ok=True)
19
+
20
+ # ----------------- Logging Setup ----------------- #
21
+ logging.basicConfig(level=logging.INFO, format='[%(levelname)s] %(message)s')
22
+ logger = logging.getLogger(__name__)
23
+
24
+ # ----------------- Main Detector Class ----------------- #
25
+ class FightDetector:
26
+ def __init__(self):
27
+ self.model = self._load_model()
28
+
29
+ def _load_model(self):
30
+ try:
31
+ model = tf.keras.models.load_model(MODEL_PATH, compile=False)
32
+ logger.info(f"Model loaded successfully. Input shape: {model.input_shape}")
33
+ return model
34
+ except Exception as e:
35
+ logger.error(f"Model loading failed: {e}")
36
+ return None
37
+
38
+ def _extract_frames(self, video_path):
39
+ frames = extract_video_frames(video_path, N_FRAMES, IMG_SIZE)
40
+ if frames is None:
41
+ return None
42
+
43
+ blank_frames = np.all(frames == 0, axis=(1, 2, 3)).sum()
44
+ if blank_frames > 0:
45
+ logger.warning(f"{blank_frames} blank frames detected.")
46
+
47
+ sample_frame = (frames[0] * 255).astype(np.uint8)
48
+ debug_frame_path = os.path.join(RESULT_PATH, 'debug_frame.jpg')
49
+ cv2.imwrite(debug_frame_path, cv2.cvtColor(sample_frame, cv2.COLOR_RGB2BGR))
50
+ logger.info(f"Debug frame saved to: {debug_frame_path}")
51
+
52
+ return frames
53
+
54
+ def predict(self, video_path):
55
+ if not os.path.exists(video_path):
56
+ return "Error: Video not found", None
57
+
58
+ try:
59
+ frames = self._extract_frames(video_path)
60
  if frames is None:
61
+ return "Error: Frame extraction failed", None
62
+
63
+ if frames.shape[0] != N_FRAMES:
64
+ return f"Error: Expected {N_FRAMES} frames, got {frames.shape[0]}", None
65
+
66
+ if np.all(frames == 0):
67
+ return "Error: All frames are blank", None
68
+
69
+ prediction = self.model.predict(frames[np.newaxis, ...], verbose=0)[0][0]
70
+ result = "FIGHT" if prediction >= FIGHT_THRESHOLD else "NORMAL"
71
+ confidence = min(max(abs(prediction - FIGHT_THRESHOLD) * 150 + 50, 0), 100)
72
+
73
+ self._debug_visualization(frames, prediction, result, video_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
+ return f"{result} ({confidence:.1f}% confidence)", prediction
76
+
77
+ except Exception as e:
78
+ return f"Prediction error: {str(e)}", None
79
+
80
+ def _debug_visualization(self, frames, score, result, video_path):
81
+ logger.info(f"Prediction Score: {score:.4f}")
82
+ logger.info(f"Decision: {result}")
83
+ plt.figure(figsize=(15, 5))
84
+ for i in range(min(10, len(frames))):
85
+ plt.subplot(2, 5, i + 1)
86
+ plt.imshow(frames[i])
87
+ plt.title(f"Frame {i}\nMean: {frames[i].mean():.2f}")
88
+ plt.axis('off')
89
+ plt.suptitle(f"Prediction: {result} (Score: {score:.4f})")
90
+ plt.tight_layout()
91
+
92
+ base_name = os.path.splitext(os.path.basename(video_path))[0]
93
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
94
+ save_path = os.path.join(RESULT_PATH, f"{base_name}_result_{timestamp}.png")
95
+ plt.savefig(save_path)
96
+ plt.close()
97
+ logger.info(f"Visualization saved to: {save_path}")
98
+
99
+ # ----------------- External Interface ----------------- #
100
+ def fight_detec(video_path: str, debug: bool = True):
101
  detector = FightDetector()
102
  if detector.model is None:
103
  return "Error: Model loading failed", None
104
  return detector.predict(video_path)
105
 
106
+ # ----------------- Optional CLI Entry ----------------- #
107
+ if __name__ == "__main__":
108
+ path0 = input("Enter the local path to the video file to detect fight: ")
109
+ path = path0.strip('"') # Remove extra quotes if copied from Windows
110
+ logger.info(f"Loading video: {path}")
111
+ result, score = fight_detec(path)
112
+ logger.info(f"Result: {result}")