File size: 4,300 Bytes
da07a7d
 
 
adc4b03
 
 
da07a7d
 
adc4b03
 
 
 
da07a7d
 
adc4b03
 
da07a7d
adc4b03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da07a7d
adc4b03
 
 
 
 
 
 
 
 
 
 
 
 
da07a7d
adc4b03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da07a7d
 
 
 
 
adc4b03
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import tensorflow as tf
import os
import numpy as np
import cv2
import logging
from datetime import datetime
import matplotlib.pyplot as plt

from frame_slicer import extract_video_frames  # Make sure this module is in PYTHONPATH

# ----------------- Configuration ----------------- #
# All artefact paths are resolved relative to the directory containing this file.
_SCRIPT_DIR = os.path.dirname(__file__)

MODEL_PATH = os.path.join(_SCRIPT_DIR, "training_output", "final_model_2.h5")  # Keras model file
N_FRAMES = 30            # number of frames requested per video and expected by predict()
IMG_SIZE = (96, 96)      # frame size passed to extract_video_frames
RESULT_PATH = os.path.join(_SCRIPT_DIR, "results")  # debug frame / visualization output
FIGHT_THRESHOLD = 0.61   # model scores >= this value are labelled FIGHT

# Create the output folder up front so later writes into it don't fail.
os.makedirs(RESULT_PATH, exist_ok=True)

# ----------------- Logging Setup ----------------- #
logging.basicConfig(format='[%(levelname)s] %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)

# ----------------- Main Detector Class ----------------- #
class FightDetector:
    """Classify a video clip as FIGHT or NORMAL with a pre-trained Keras model.

    The model is loaded once, at construction, from ``MODEL_PATH``. If loading
    fails, ``self.model`` is ``None`` and :meth:`predict` returns an error
    message instead of raising.
    """

    def __init__(self):
        self.model = self._load_model()

    def _load_model(self):
        """Load the model from ``MODEL_PATH`` without compiling it.

        Returns:
            The loaded Keras model, or ``None`` if loading raised (the error
            is logged).
        """
        try:
            model = tf.keras.models.load_model(MODEL_PATH, compile=False)
            logger.info(f"Model loaded successfully. Input shape: {model.input_shape}")
            return model
        except Exception as e:
            logger.error(f"Model loading failed: {e}")
            return None

    def _extract_frames(self, video_path, debug=True):
        """Extract ``N_FRAMES`` frames of size ``IMG_SIZE`` from *video_path*.

        Logs how many frames are entirely zero. When *debug* is true, also
        writes the first frame to ``RESULT_PATH/debug_frame.jpg``.

        Args:
            video_path: Path to the video file.
            debug: Whether to save the debug frame (default ``True``).

        Returns:
            The array returned by ``extract_video_frames``, or ``None`` if
            that helper returned ``None``. An empty array is returned as-is so
            the caller can report the frame-count mismatch.
        """
        frames = extract_video_frames(video_path, N_FRAMES, IMG_SIZE)
        # Guard before indexing: frames[0] on an empty array would raise
        # IndexError and hide the real problem (no frames extracted).
        if frames is None or len(frames) == 0:
            return frames

        blank_frames = np.all(frames == 0, axis=(1, 2, 3)).sum()
        if blank_frames > 0:
            logger.warning(f"{blank_frames} blank frames detected.")

        if debug:
            self._save_debug_frame(frames[0])

        return frames

    def _save_debug_frame(self, frame):
        """Write *frame* to ``RESULT_PATH/debug_frame.jpg`` for manual inspection."""
        # Assumes float RGB pixel values in [0, 1] (hence the *255 scaling) —
        # TODO confirm against extract_video_frames. Clipping stops slightly
        # out-of-range values from wrapping around in the uint8 cast.
        sample_frame = (np.clip(frame, 0.0, 1.0) * 255).astype(np.uint8)
        debug_frame_path = os.path.join(RESULT_PATH, 'debug_frame.jpg')
        # cv2.imwrite reports failure through its return value, not an exception.
        if cv2.imwrite(debug_frame_path, cv2.cvtColor(sample_frame, cv2.COLOR_RGB2BGR)):
            logger.info(f"Debug frame saved to: {debug_frame_path}")
        else:
            logger.warning(f"Failed to save debug frame to: {debug_frame_path}")

    def predict(self, video_path, debug=True):
        """Classify the video at *video_path*.

        Args:
            video_path: Path to the video file.
            debug: When true (default), save the debug frame and the
                frame-grid visualization under ``RESULT_PATH``.

        Returns:
            A ``(message, score)`` tuple. On success, *message* is
            ``"FIGHT (xx.x% confidence)"`` or ``"NORMAL (xx.x% confidence)"``
            and *score* is the raw model output. On failure, *message*
            describes the error and *score* is ``None``.
        """
        if not os.path.exists(video_path):
            return "Error: Video not found", None

        # Without this check, a failed model load surfaces later as an opaque
        # "'NoneType' object has no attribute 'predict'" message.
        if self.model is None:
            return "Error: Model loading failed", None

        try:
            frames = self._extract_frames(video_path, debug=debug)
            if frames is None:
                return "Error: Frame extraction failed", None

            if frames.shape[0] != N_FRAMES:
                return f"Error: Expected {N_FRAMES} frames, got {frames.shape[0]}", None

            if np.all(frames == 0):
                return "Error: All frames are blank", None

            # Add a batch axis of size 1; [0][0] takes the first output value
            # of that single batch item.
            prediction = self.model.predict(frames[np.newaxis, ...], verbose=0)[0][0]
            result = "FIGHT" if prediction >= FIGHT_THRESHOLD else "NORMAL"
            # Heuristic display value: 50% at the threshold, rising with the
            # distance from it, capped at 100%. Not a calibrated probability.
            confidence = min(max(abs(prediction - FIGHT_THRESHOLD) * 150 + 50, 0), 100)
        except Exception as e:
            logger.exception("Prediction failed")
            return f"Prediction error: {str(e)}", None

        logger.info(f"Prediction Score: {prediction:.4f}")
        logger.info(f"Decision: {result}")

        if debug:
            # Debug output is best-effort: a plotting or saving failure must
            # not throw away a valid prediction.
            try:
                self._debug_visualization(frames, prediction, result, video_path)
            except Exception:
                logger.exception("Debug visualization failed")

        return f"{result} ({confidence:.1f}% confidence)", prediction

    def _debug_visualization(self, frames, score, result, video_path):
        """Save a PNG grid of up to the first 10 frames, titled with the decision.

        The file is written to ``RESULT_PATH`` as
        ``<video-stem>_result_<YYYYmmdd_HHMMSS>.png``.
        """
        fig = plt.figure(figsize=(15, 5))
        try:
            for i in range(min(10, len(frames))):
                ax = fig.add_subplot(2, 5, i + 1)
                ax.imshow(frames[i])
                ax.set_title(f"Frame {i}\nMean: {frames[i].mean():.2f}")
                ax.axis('off')
            fig.suptitle(f"Prediction: {result} (Score: {score:.4f})")
            fig.tight_layout()

            base_name = os.path.splitext(os.path.basename(video_path))[0]
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            save_path = os.path.join(RESULT_PATH, f"{base_name}_result_{timestamp}.png")
            fig.savefig(save_path)
        finally:
            # Always release the figure, even if saving fails, so repeated
            # calls don't accumulate open figures.
            plt.close(fig)
        logger.info(f"Visualization saved to: {save_path}")

# ----------------- External Interface ----------------- #
def fight_detec(video_path: str, debug: bool = True):
    """Run fight detection on a single video file.

    A new ``FightDetector`` is created on every call, so the model is loaded
    from disk each time.

    Args:
        video_path: Path to the video file.
        debug: When true (default), write the debug frame and visualization
            PNG under ``RESULT_PATH``; when false, skip them.

    Returns:
        The ``(message, score)`` tuple from ``FightDetector.predict``, or
        ``("Error: Model loading failed", None)`` if the model could not be
        loaded.
    """
    detector = FightDetector()
    if detector.model is None:
        return "Error: Model loading failed", None
    return detector.predict(video_path, debug=debug)

# ----------------- Optional CLI Entry ----------------- #
def clean_input_path(raw_path: str) -> str:
    """Return *raw_path* with surrounding whitespace and quote characters removed.

    Windows Explorer's "Copy as path" wraps paths in double quotes, and pasted
    input often carries stray spaces. Either would make the path lookup fail.
    """
    return raw_path.strip().strip("\"'").strip()


if __name__ == "__main__":
    path = clean_input_path(input("Enter the local path to the video file to detect fight: "))
    if not path:
        # Skip the (expensive) model load when there is nothing to analyse.
        logger.error("No video path given.")
    else:
        logger.info(f"Loading video: {path}")
        result, score = fight_detec(path)
        logger.info(f"Result: {result}")
        if score is not None:
            logger.info(f"Raw score: {float(score):.4f}")