File size: 8,722 Bytes
e06363f
da07a7d
e06363f
 
da07a7d
 
 
 
e06363f
 
f8bfcd7
da07a7d
 
e06363f
adc4b03
 
 
e06363f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da07a7d
 
e06363f
da07a7d
e06363f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188

import tensorflow as tf
from frame_slicer import extract_video_frames
import cv2
import os
import numpy as np
import matplotlib.pyplot as plt

# --- Configuration -----------------------------------------------------------
import os

# Directory containing this script; the model and results paths resolve from it.
_SCRIPT_DIR = os.path.dirname(__file__)

# Keras model file loaded for inference.
MODEL_PATH = os.path.join(_SCRIPT_DIR, "final_model_2.h5")
# Number of frames requested from the frame extractor for each clip.
N_FRAMES = 30
# Per-frame spatial size passed to the extractor and used when zero-padding.
IMG_SIZE = (96, 96)
# Output directory for debug frames and visualizations.
RESULT_PATH = os.path.join(_SCRIPT_DIR, "results")

def fight_detec(video_path: str, debug: bool = True):
    """Detect a fight in a video and return the result string and raw prediction score.

    Loads the Keras model at ``MODEL_PATH``, asks ``extract_video_frames`` for
    ``N_FRAMES`` frames of size ``IMG_SIZE``, pads or truncates them to exactly
    ``N_FRAMES`` and runs a single model prediction on the clip.

    Args:
        video_path: Path to the video file to classify.
        debug: When True, print diagnostics and write a sample frame plus a
            visualization of the first frames into ``RESULT_PATH``.

    Returns:
        A ``(result_string, score)`` tuple. On success ``result_string`` is
        ``"FIGHT (<x>% confidence)"`` or ``"NORMAL (<x>% confidence)"`` and
        ``score`` is the raw model output as a float. On failure
        ``result_string`` starts with ``"Error:"`` or ``"Prediction error:"``
        and ``score`` is ``None``; failures are reported through the return
        value rather than raised.
    """

    class FightDetector:
        """Wraps the loaded model for one call; reads ``debug`` from the enclosing call."""

        def __init__(self):
            # None signals that loading failed; the caller checks for it below.
            self.model = self._load_model()

        def _load_model(self):
            """Load the model from MODEL_PATH, returning None if it is missing or fails to load."""
            if not os.path.exists(MODEL_PATH):
                print(f"Error: Model file not found at {MODEL_PATH}")
                return None
            try:
                # compile=False: optimizer/loss state is not needed for inference.
                model = tf.keras.models.load_model(MODEL_PATH, compile=False)
                if debug:
                    print("\nModel loaded successfully. Input shape:", model.input_shape)
                return model
            except Exception as e:
                print(f"Model loading failed: {e}")
                return None

        @staticmethod
        def _to_uint8(frame):
            """Convert a frame to uint8 for saving/plotting.

            A frame whose maximum is <= 1.0 is treated as normalized to [0, 1]
            and scaled by 255; any other frame is cast directly.
            """
            if np.max(frame) <= 1.0:
                return (frame * 255).astype(np.uint8)
            return frame.astype(np.uint8)

        def _save_debug_frame(self, frames):
            """Report blank frames and save the first non-blank frame to RESULT_PATH."""
            # Assumes frames is (count, height, width, channels) -- reduces over the last three axes.
            blank_frames = np.all(frames == 0, axis=(1, 2, 3)).sum()
            if blank_frames > 0:
                print(f"Warning: {blank_frames} blank frames detected")

            if frames.shape[0] == 0 or np.all(frames[0] == 0):
                print("Skipping debug frame save (first frame blank or no frames).")
                return

            sample_frame = self._to_uint8(frames[0])
            try:
                os.makedirs(RESULT_PATH, exist_ok=True)
                debug_frame_path = os.path.join(RESULT_PATH, 'debug_frame.jpg')
                # Frames are treated as RGB; cv2.imwrite expects BGR channel order.
                cv2.imwrite(debug_frame_path, cv2.cvtColor(sample_frame, cv2.COLOR_RGB2BGR))
                print(f"Debug frame saved to {debug_frame_path}")
            except Exception as e:
                print(f"Failed to save debug frame: {e}")

        def _extract_frames(self, video_path):
            """Return the extracted frames as an array, or None if extraction failed."""
            frames = extract_video_frames(video_path, N_FRAMES, IMG_SIZE)
            if frames is None:
                print(f"Frame extraction returned None for {video_path}")
                return None
            # The rest of the pipeline relies on ndarray attributes such as .shape.
            frames = np.asarray(frames)
            if debug:
                self._save_debug_frame(frames)
            return frames

        def _fit_frame_count(self, frames):
            """Return ``frames`` padded or truncated to exactly N_FRAMES along axis 0.

            Short clips are padded by repeating the last frame (or with zeros if
            there are no frames). Long clips keep their first N_FRAMES frames;
            without this branch the padding count would be negative and
            ``np.repeat`` would raise.
            """
            count = frames.shape[0]
            if count == N_FRAMES:
                return frames

            if count > N_FRAMES:
                print(f"Warning: Expected {N_FRAMES} frames, got {count}. Truncating...")
                frames = frames[:N_FRAMES]
                print(f"Frames truncated to shape: {frames.shape}")
                return frames

            print(f"Warning: Expected {N_FRAMES} frames, got {count}. Padding...")
            if count == 0:
                frames = np.zeros((N_FRAMES, *IMG_SIZE, 3), dtype=np.float32)
            else:
                padding = np.repeat(frames[-1:], N_FRAMES - count, axis=0)
                frames = np.concatenate((frames, padding), axis=0)
            print(f"Frames padded to shape: {frames.shape}")
            return frames

        def predict(self, video_path):
            """Classify ``video_path``; return ``(result_string, raw_score)`` or ``(error, None)``."""
            if not os.path.exists(video_path):
                print(f"Error: Video not found at {video_path}")
                return "Error: Video not found", None

            try:
                frames = self._extract_frames(video_path)
                if frames is None:
                    return "Error: Frame extraction failed", None

                frames = self._fit_frame_count(frames)

                # A blank video, or zero-padding of an empty one, leaves nothing to classify.
                if np.all(frames == 0):
                    print("Error: All frames are blank after processing/padding.")
                    return "Error: All frames are blank", None

                # Add a batch axis; the model is assumed to output one score per clip.
                prediction = self.model.predict(frames[np.newaxis, ...], verbose=0)[0][0]
                threshold = 0.61  # Example threshold
                is_fight = prediction >= threshold
                result = "FIGHT" if is_fight else "NORMAL"

                # Confidence grows with distance from the threshold: margin * 150 + 50,
                # clamped to [0, 100], so a score exactly at the threshold reads as 50%.
                margin = (prediction - threshold) if is_fight else (threshold - prediction)
                confidence = min(max(margin * 150 + 50, 0), 100)

                result_string = f"{result} ({confidence:.1f}% confidence)"

                if debug:
                    print(f"Raw Prediction Score: {prediction:.4f}")
                    self._debug_visualization(frames, prediction, result_string, video_path)

                return result_string, float(prediction)

            except Exception as e:
                print(f"Prediction error: {str(e)}")
                return f"Prediction error: {str(e)}", None

        def _debug_visualization(self, frames, score, result, video_path):
            """Print the decision and save a grid of up to 10 frames to RESULT_PATH."""
            print("\n--- Debug Visualization ---")
            print(f"Prediction Score: {score:.4f}")
            print(f"Decision: {result}")

            fig = None
            try:
                import matplotlib.pyplot as plt
                fig = plt.figure(figsize=(15, 5))
                for i, frame in enumerate(frames[:10]):
                    plt.subplot(2, 5, i + 1)
                    plt.imshow(self._to_uint8(frame))
                    # Mean is taken from the unconverted frame.
                    plt.title(f"Frame {i}\nMean: {frame.mean():.2f}")
                    plt.axis('off')
                plt.suptitle(f"Video: {os.path.basename(video_path)}\nPrediction: {result} (Raw Score: {score:.4f})")
                plt.tight_layout(rect=[0, 0.03, 1, 0.95])

                os.makedirs(RESULT_PATH, exist_ok=True)
                base_name = os.path.splitext(os.path.basename(video_path))[0]
                save_path = os.path.join(RESULT_PATH, f"{base_name}_prediction_result.png")
                plt.savefig(save_path)
                print(f"Debug visualization saved to: {save_path}")
            except ImportError:
                print("Matplotlib not found. Skipping debug visualization plot.")
            except Exception as e:
                print(f"Error during debug visualization: {e}")
            finally:
                # Close the figure on every path so failed plots do not leak memory.
                if fig is not None:
                    plt.close(fig)
            print("--- End Debug Visualization ---")

    # --- Main function logic ---
    detector = FightDetector()
    if detector.model is None:
        return "Error: Model loading failed", None
    return detector.predict(video_path)


# # Example usage (commented out for library use)
# if __name__ == "__main__":
#     # Example of how to call the function
#     test_video = input("Enter the local path to the video file: ").strip('"')
#     if os.path.exists(test_video):
#         print(f"[INFO] Processing video: {test_video}")
#         result, score = fight_detec(test_video, debug=True) # Enable debug for local testing
#         print(f"\nFinal Result: {result}")
#         if score is not None:
#             print(f"Raw Score: {score:.4f}")
#     else:
#         print(f"Error: File not found - {test_video}")