KillD00zer commited on
Commit
e06363f
·
verified ·
1 Parent(s): 8de568b

Update Fight_detec_func.py

Browse files
Files changed (1) hide show
  1. Fight_detec_func.py +176 -99
Fight_detec_func.py CHANGED
@@ -1,112 +1,189 @@
 
 
1
  import tensorflow as tf
 
 
2
  import os
3
  import numpy as np
4
- import cv2
5
- import logging
6
- from datetime import datetime
7
  import matplotlib.pyplot as plt
8
 
9
- from frame_slicer import extract_video_frames # Make sure this module is in PYTHONPATH
10
-
11
- # ----------------- Configuration ----------------- #
12
- MODEL_PATH = os.path.join(os.path.dirname(__file__), "training_output", "final_model_2.h5")
13
  N_FRAMES = 30
14
  IMG_SIZE = (96, 96)
 
15
  RESULT_PATH = os.path.join(os.path.dirname(__file__), "results")
16
- FIGHT_THRESHOLD = 0.61
17
-
18
- os.makedirs(RESULT_PATH, exist_ok=True)
19
-
20
- # ----------------- Logging Setup ----------------- #
21
- logging.basicConfig(level=logging.INFO, format='[%(levelname)s] %(message)s')
22
- logger = logging.getLogger(__name__)
23
-
24
- # ----------------- Main Detector Class ----------------- #
25
- class FightDetector:
26
- def __init__(self):
27
- self.model = self._load_model()
28
-
29
- def _load_model(self):
30
- try:
31
- model = tf.keras.models.load_model(MODEL_PATH, compile=False)
32
- logger.info(f"Model loaded successfully. Input shape: {model.input_shape}")
33
- return model
34
- except Exception as e:
35
- logger.error(f"Model loading failed: {e}")
36
- return None
37
-
38
- def _extract_frames(self, video_path):
39
- frames = extract_video_frames(video_path, N_FRAMES, IMG_SIZE)
40
- if frames is None:
41
- return None
42
-
43
- blank_frames = np.all(frames == 0, axis=(1, 2, 3)).sum()
44
- if blank_frames > 0:
45
- logger.warning(f"{blank_frames} blank frames detected.")
46
-
47
- sample_frame = (frames[0] * 255).astype(np.uint8)
48
- debug_frame_path = os.path.join(RESULT_PATH, 'debug_frame.jpg')
49
- cv2.imwrite(debug_frame_path, cv2.cvtColor(sample_frame, cv2.COLOR_RGB2BGR))
50
- logger.info(f"Debug frame saved to: {debug_frame_path}")
51
-
52
- return frames
53
-
54
- def predict(self, video_path):
55
- if not os.path.exists(video_path):
56
- return "Error: Video not found", None
57
-
58
- try:
59
- frames = self._extract_frames(video_path)
60
- if frames is None:
61
- return "Error: Frame extraction failed", None
62
-
63
- if frames.shape[0] != N_FRAMES:
64
- return f"Error: Expected {N_FRAMES} frames, got {frames.shape[0]}", None
65
-
66
- if np.all(frames == 0):
67
- return "Error: All frames are blank", None
68
 
69
- prediction = self.model.predict(frames[np.newaxis, ...], verbose=0)[0][0]
70
- result = "FIGHT" if prediction >= FIGHT_THRESHOLD else "NORMAL"
71
- confidence = min(max(abs(prediction - FIGHT_THRESHOLD) * 150 + 50, 0), 100)
72
-
73
- self._debug_visualization(frames, prediction, result, video_path)
74
-
75
- return f"{result} ({confidence:.1f}% confidence)", prediction
76
-
77
- except Exception as e:
78
- return f"Prediction error: {str(e)}", None
79
-
80
- def _debug_visualization(self, frames, score, result, video_path):
81
- logger.info(f"Prediction Score: {score:.4f}")
82
- logger.info(f"Decision: {result}")
83
- plt.figure(figsize=(15, 5))
84
- for i in range(min(10, len(frames))):
85
- plt.subplot(2, 5, i + 1)
86
- plt.imshow(frames[i])
87
- plt.title(f"Frame {i}\nMean: {frames[i].mean():.2f}")
88
- plt.axis('off')
89
- plt.suptitle(f"Prediction: {result} (Score: {score:.4f})")
90
- plt.tight_layout()
91
-
92
- base_name = os.path.splitext(os.path.basename(video_path))[0]
93
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
94
- save_path = os.path.join(RESULT_PATH, f"{base_name}_result_{timestamp}.png")
95
- plt.savefig(save_path)
96
- plt.close()
97
- logger.info(f"Visualization saved to: {save_path}")
98
-
99
- # ----------------- External Interface ----------------- #
100
  def fight_detec(video_path: str, debug: bool = True):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  detector = FightDetector()
102
  if detector.model is None:
 
103
  return "Error: Model loading failed", None
104
- return detector.predict(video_path)
105
-
106
- # ----------------- Optional CLI Entry ----------------- #
107
- if __name__ == "__main__":
108
- path0 = input("Enter the local path to the video file to detect fight: ")
109
- path = path0.strip('"') # Remove extra quotes if copied from Windows
110
- logger.info(f"Loading video: {path}")
111
- result, score = fight_detec(path)
112
- logger.info(f"Result: {result}")
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --- START OF FILE Fight_detec_func.py ---
2
+
3
  import tensorflow as tf
4
+ from frame_slicer import extract_video_frames
5
+ import cv2
6
  import os
7
  import numpy as np
 
 
 
8
  import matplotlib.pyplot as plt
9
 
10
+ # Configuration
11
+ import os
12
+ MODEL_PATH = os.path.join(os.path.dirname(__file__), "trainnig_output", "final_model_2.h5")
 
13
  N_FRAMES = 30
14
  IMG_SIZE = (96, 96)
15
+ # Define RESULT_PATH relative to the script location
16
  RESULT_PATH = os.path.join(os.path.dirname(__file__), "results")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  def fight_detec(video_path: str, debug: bool = True):
19
+ """Detects fight in a video and returns the result string and raw prediction score."""
20
+
21
+ class FightDetector:
22
+ def __init__(self):
23
+ self.model = self._load_model()
24
+
25
+ def _load_model(self):
26
+ # Ensure the model path exists before loading
27
+ if not os.path.exists(MODEL_PATH):
28
+ print(f"Error: Model file not found at {MODEL_PATH}")
29
+ return None
30
+ try:
31
+ # Load model with compile=False if optimizer state isn't needed for inference
32
+ model = tf.keras.models.load_model(MODEL_PATH, compile=False)
33
+ if debug:
34
+ print("\nModel loaded successfully. Input shape:", model.input_shape)
35
+ return model
36
+ except Exception as e:
37
+ print(f"Model loading failed: {e}")
38
+ return None
39
+
40
+ def _extract_frames(self, video_path):
41
+ frames = extract_video_frames(video_path, N_FRAMES, IMG_SIZE)
42
+ if frames is None:
43
+ print(f"Frame extraction returned None for {video_path}")
44
+ return None
45
+
46
+ if debug:
47
+ blank_frames = np.all(frames == 0, axis=(1, 2, 3)).sum()
48
+ if blank_frames > 0:
49
+ print(f"Warning: {blank_frames} blank frames detected")
50
+ # Save a sample frame for debugging only if debug is True
51
+ if frames.shape[0] > 0 and not np.all(frames[0] == 0): # Avoid saving blank frame
52
+ sample_frame = (frames[0] * 255).astype(np.uint8)
53
+ try:
54
+ os.makedirs(RESULT_PATH, exist_ok=True) # Ensure result path exists
55
+ debug_frame_path = os.path.join(RESULT_PATH, 'debug_frame.jpg')
56
+ cv2.imwrite(debug_frame_path, cv2.cvtColor(sample_frame, cv2.COLOR_RGB2BGR))
57
+ print(f"Debug frame saved to {debug_frame_path}")
58
+ except Exception as e:
59
+ print(f"Failed to save debug frame: {e}")
60
+ else:
61
+ print("Skipping debug frame save (first frame blank or no frames).")
62
+
63
+
64
+ return frames
65
+
66
+ def predict(self, video_path):
67
+ if not os.path.exists(video_path):
68
+ print(f"Error: Video not found at {video_path}")
69
+ return "Error: Video not found", None
70
+
71
+ try:
72
+ frames = self._extract_frames(video_path)
73
+ if frames is None:
74
+ return "Error: Frame extraction failed", None
75
+
76
+ if frames.shape[0] != N_FRAMES:
77
+ # Pad with last frame or zeros if not enough frames were extracted
78
+ print(f"Warning: Expected {N_FRAMES} frames, got {frames.shape[0]}. Padding...")
79
+ if frames.shape[0] == 0: # No frames at all
80
+ frames = np.zeros((N_FRAMES, *IMG_SIZE, 3), dtype=np.float32)
81
+ else: # Pad with the last available frame
82
+ padding_needed = N_FRAMES - frames.shape[0]
83
+ last_frame = frames[-1][np.newaxis, ...]
84
+ padding = np.repeat(last_frame, padding_needed, axis=0)
85
+ frames = np.concatenate((frames, padding), axis=0)
86
+ print(f"Frames padded to shape: {frames.shape}")
87
+
88
+
89
+ if np.all(frames == 0):
90
+ # Check if all frames are actually blank (can happen with padding)
91
+ print("Error: All frames are blank after processing/padding.")
92
+ return "Error: All frames are blank", None
93
+
94
+ # Perform prediction
95
+ prediction = self.model.predict(frames[np.newaxis, ...], verbose=0)[0][0]
96
+ # Determine result based on threshold
97
+ threshold = 0.61 # Example threshold
98
+ is_fight = prediction >= threshold
99
+ result = "FIGHT" if is_fight else "NORMAL"
100
+
101
+ # Calculate confidence (simple distance from threshold, scaled)
102
+ # Adjust scaling factor (e.g., 150) and base (e.g., 50) as needed
103
+ # Ensure confidence reflects certainty (higher for values far from threshold)
104
+ if is_fight:
105
+ confidence = min(max((prediction - threshold) * 150 + 50, 0), 100)
106
+ else:
107
+ confidence = min(max((threshold - prediction) * 150 + 50, 0), 100)
108
+
109
+
110
+ result_string = f"{result} ({confidence:.1f}% confidence)"
111
+
112
+ if debug:
113
+ print(f"Raw Prediction Score: {prediction:.4f}")
114
+ self._debug_visualization(frames, prediction, result_string, video_path)
115
+
116
+ return result_string, float(prediction) # Return string and raw score
117
+
118
+ except Exception as e:
119
+ print(f"Prediction error: {str(e)}")
120
+ # Consider logging the full traceback here in a real application
121
+ # import traceback
122
+ # print(traceback.format_exc())
123
+ return f"Prediction error: {str(e)}", None
124
+
125
+ def _debug_visualization(self, frames, score, result, video_path):
126
+ # This function will only run if debug=True is passed to fight_detec
127
+ print(f"\n--- Debug Visualization ---")
128
+ print(f"Prediction Score: {score:.4f}")
129
+ print(f"Decision: {result}")
130
+
131
+ # Avoid plotting if matplotlib is not available or causes issues in deployment
132
+ try:
133
+ import matplotlib.pyplot as plt
134
+ plt.figure(figsize=(15, 5))
135
+ num_frames_to_show = min(10, len(frames))
136
+ for i in range(num_frames_to_show):
137
+ plt.subplot(2, 5, i+1)
138
+ # Ensure frame values are valid for imshow (0-1 or 0-255)
139
+ img_display = frames[i]
140
+ if np.max(img_display) <= 1.0: # Assuming normalized float [0,1]
141
+ img_display = (img_display * 255).astype(np.uint8)
142
+ else: # Assuming it might already be uint8 [0,255]
143
+ img_display = img_display.astype(np.uint8)
144
+
145
+ plt.imshow(img_display)
146
+ plt.title(f"Frame {i}\nMean: {frames[i].mean():.2f}") # Use original frame for mean
147
+ plt.axis('off')
148
+ plt.suptitle(f"Video: {os.path.basename(video_path)}\nPrediction: {result} (Raw Score: {score:.4f})")
149
+ plt.tight_layout(rect=[0, 0.03, 1, 0.95]) # Adjust layout
150
+
151
+ # Save the visualization
152
+ os.makedirs(RESULT_PATH, exist_ok=True) # Ensure result path exists again
153
+ base_name = os.path.splitext(os.path.basename(video_path))[0]
154
+ save_path = os.path.join(RESULT_PATH, f"{base_name}_prediction_result.png")
155
+ plt.savefig(save_path)
156
+ plt.close() # Close the plot to free memory
157
+ print(f"Debug visualization saved to: {save_path}")
158
+ except ImportError:
159
+ print("Matplotlib not found. Skipping debug visualization plot.")
160
+ except Exception as e:
161
+ print(f"Error during debug visualization: {e}")
162
+ print("--- End Debug Visualization ---")
163
+
164
+
165
+ # --- Main function logic ---
166
  detector = FightDetector()
167
  if detector.model is None:
168
+ # Model loading failed, return error
169
  return "Error: Model loading failed", None
170
+
171
+ # Call the predict method
172
+ result_str, prediction_score = detector.predict(video_path)
173
+ return result_str, prediction_score
174
+
175
+
176
+ # # Example usage (commented out for library use)
177
+ # if __name__ == "__main__":
178
+ # # Example of how to call the function
179
+ # test_video = input("Enter the local path to the video file: ").strip('"')
180
+ # if os.path.exists(test_video):
181
+ # print(f"[INFO] Processing video: {test_video}")
182
+ # result, score = fight_detec(test_video, debug=True) # Enable debug for local testing
183
+ # print(f"\nFinal Result: {result}")
184
+ # if score is not None:
185
+ # print(f"Raw Score: {score:.4f}")
186
+ # else:
187
+ # print(f"Error: File not found - {test_video}")
188
+
189
+ --- END OF FILE Fight_detec_func.py ---