File size: 4,192 Bytes
da07a7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import tensorflow as tf
from frame_slicer import extract_video_frames
import cv2
import os
import numpy as np
import matplotlib.pyplot as plt

# Configuration -- paths are resolved relative to this file, not the current
# working directory, so the script can be launched from anywhere.
_BASE_DIR = os.path.dirname(__file__)
# NOTE(review): "trainnig_output" looks like a misspelling of "training_output",
# but it must match the directory name on disk -- confirm before renaming.
MODEL_PATH = os.path.join(_BASE_DIR, "trainnig_output", "final_model_2.h5")
N_FRAMES = 30  # number of frames requested from extract_video_frames per video
IMG_SIZE = (96, 96)  # frame size passed to extract_video_frames
RESULT_PATH = os.path.join(_BASE_DIR, "results")  # created on demand (debug mode) if it doesn't exist

def fight_detec(video_path: str, debug: bool = True, threshold: float = 0.61):
    """Detect whether a video shows a fight.

    Loads the Keras model stored at MODEL_PATH, asks ``extract_video_frames``
    for N_FRAMES frames of size IMG_SIZE, runs the model on them and compares
    the resulting score with ``threshold``.

    Args:
        video_path: Path of the video file to classify.
        debug: If True, print diagnostics, write the first sampled frame to
            ``RESULT_PATH/debug_frame.jpg`` and save a grid of (up to) the
            first ten frames to
            ``RESULT_PATH/<video name>_prediction_result.png``.
        threshold: Scores ``>= threshold`` are labelled "FIGHT", lower scores
            "NORMAL". Defaults to 0.61, the value previously hard-coded here.

    Returns:
        A ``(message, score)`` tuple. On success ``message`` is
        ``"FIGHT (<c>% confidence)"`` or ``"NORMAL (<c>% confidence)"`` and
        ``score`` is the raw model output as a Python float. On failure
        ``message`` describes the problem and ``score`` is None; errors are
        reported through the return value rather than raised.
    """

    class FightDetector:
        """Wraps the loaded model; reads ``debug``/``threshold`` from the enclosing call."""

        def __init__(self):
            self.model = self._load_model()

        def _load_model(self):
            """Load the model from MODEL_PATH; return None (after printing why) on failure."""
            try:
                # compile=False: only inference is performed, so the training
                # configuration stored with the model is not needed.
                model = tf.keras.models.load_model(MODEL_PATH, compile=False)
                if debug:
                    print("\nModel loaded successfully. Input shape:", model.input_shape)
                return model
            except Exception as e:
                print(f"Model loading failed: {e}")
                return None

        def _extract_frames(self, video_path):
            """Return the sampled frames as an ndarray, or None if extraction failed."""
            frames = extract_video_frames(video_path, N_FRAMES, IMG_SIZE)
            if frames is None:
                return None
            frames = np.asarray(frames)

            # Only run the debug dump on a non-empty 4-D stack (frame, H, W, C);
            # otherwise let predict() report the clearer frame-count error.
            if debug and frames.ndim == 4 and len(frames) > 0:
                blank_frames = np.all(frames == 0, axis=(1, 2, 3)).sum()
                if blank_frames > 0:
                    print(f"Warning: {blank_frames} blank frames detected")
                # assumes frames are RGB floats in [0, 1] -- TODO confirm against
                # extract_video_frames. Clip so out-of-range values don't wrap in uint8.
                sample_frame = np.clip(frames[0] * 255, 0, 255).astype(np.uint8)
                os.makedirs(RESULT_PATH, exist_ok=True)
                cv2.imwrite(os.path.join(RESULT_PATH, 'debug_frame.jpg'),
                            cv2.cvtColor(sample_frame, cv2.COLOR_RGB2BGR))

            return frames

        def predict(self, video_path):
            """Classify ``video_path``; see fight_detec for the return format."""
            if not os.path.exists(video_path):
                return "Error: Video not found", None

            try:
                frames = self._extract_frames(video_path)
                if frames is None:
                    return "Error: Frame extraction failed", None

                if frames.shape[0] != N_FRAMES:
                    return f"Error: Expected {N_FRAMES} frames, got {frames.shape[0]}", None

                if np.all(frames == 0):
                    return "Error: All frames are blank", None

                # Add a batch axis of 1. The model presumably has a single
                # sigmoid output unit, i.e. output shape (1, 1) -- TODO confirm.
                score = float(self.model.predict(frames[np.newaxis, ...], verbose=0)[0][0])
                result = "FIGHT" if score >= threshold else "NORMAL"
                # Heuristic: map distance from the threshold onto a 50-100% scale.
                confidence = min(abs(score - threshold) * 150 + 50, 100)
            except Exception as e:
                return f"Prediction error: {str(e)}", None

            if debug:
                # Best effort: a plotting/saving failure must not discard an
                # already computed, valid prediction.
                try:
                    self._debug_visualization(frames, score, result, video_path)
                except Exception as e:
                    print(f"Debug visualization failed: {e}")

            return f"{result} ({confidence:.1f}% confidence)", score

        def _debug_visualization(self, frames, score, result, video_path):
            """Print the score and save a grid of up to ten frames under RESULT_PATH."""
            print(f"\nPrediction Score: {score:.4f}")
            print(f"Decision: {result}")
            fig = plt.figure(figsize=(15, 5))
            try:
                for i in range(min(10, len(frames))):
                    plt.subplot(2, 5, i + 1)
                    plt.imshow(frames[i])
                    plt.title(f"Frame {i}\nMean: {frames[i].mean():.2f}")
                    plt.axis('off')
                plt.suptitle(f"Prediction: {result} (Score: {score:.4f})")
                plt.tight_layout()

                # Save the visualization next to the debug frame.
                os.makedirs(RESULT_PATH, exist_ok=True)
                base_name = os.path.splitext(os.path.basename(video_path))[0]
                save_path = os.path.join(RESULT_PATH, f"{base_name}_prediction_result.png")
                plt.savefig(save_path)
            finally:
                # Always release the figure, even if plotting or saving fails.
                plt.close(fig)
            print(f"Visualization saved to: {save_path}")

    detector = FightDetector()
    if detector.model is None:
        return "Error: Model loading failed", None
    return detector.predict(video_path)

# # Entry point
# path0 = input("Enter the local path to the video file to detect fight: ")
# path = path0.strip('"')  # Remove extra quotes if copied from Windows
# print(f"[INFO] Loading video: {path}")
# fight_detec(path)