Upload 9 files
- Fight_detec_func.py +103 -0
- README.md +187 -12
- app.py +35 -0
- frame_slicer.py +58 -0
- full_project.py +22 -0
- model_summary.py +10 -0
- objec_detect_yolo.py +121 -0
- requirements.txt +6 -0
- trainig.py +248 -0
Fight_detec_func.py
ADDED
@@ -0,0 +1,103 @@
```python
import tensorflow as tf
from frame_slicer import extract_video_frames
import cv2
import os
import numpy as np
import matplotlib.pyplot as plt

# Configuration
MODEL_PATH = os.path.join(os.path.dirname(__file__), "trainnig_output", "final_model_2.h5")
N_FRAMES = 30
IMG_SIZE = (96, 96)
RESULT_PATH = os.path.join(os.path.dirname(__file__), "results")  # Will be created if it doesn't exist

def fight_detec(video_path: str, debug: bool = True):
    """Detects fight in a video and returns the result and confidence score."""

    class FightDetector:
        def __init__(self):
            self.model = self._load_model()

        def _load_model(self):
            try:
                model = tf.keras.models.load_model(MODEL_PATH, compile=False)
                if debug:
                    print("\nModel loaded successfully. Input shape:", model.input_shape)
                return model
            except Exception as e:
                print(f"Model loading failed: {e}")
                return None

        def _extract_frames(self, video_path):
            frames = extract_video_frames(video_path, N_FRAMES, IMG_SIZE)
            if frames is None:
                return None

            if debug:
                blank_frames = np.all(frames == 0, axis=(1, 2, 3)).sum()
                if blank_frames > 0:
                    print(f"Warning: {blank_frames} blank frames detected")
                sample_frame = (frames[0] * 255).astype(np.uint8)
                os.makedirs(RESULT_PATH, exist_ok=True)
                cv2.imwrite(os.path.join(RESULT_PATH, 'debug_frame.jpg'),
                            cv2.cvtColor(sample_frame, cv2.COLOR_RGB2BGR))

            return frames

        def predict(self, video_path):
            if not os.path.exists(video_path):
                return "Error: Video not found", None

            try:
                frames = self._extract_frames(video_path)
                if frames is None:
                    return "Error: Frame extraction failed", None

                if frames.shape[0] != N_FRAMES:
                    return f"Error: Expected {N_FRAMES} frames, got {frames.shape[0]}", None

                if np.all(frames == 0):
                    return "Error: All frames are blank", None

                prediction = self.model.predict(frames[np.newaxis, ...], verbose=0)[0][0]
                result = "FIGHT" if prediction >= 0.61 else "NORMAL"
                confidence = min(max(abs(prediction - 0.61) * 150 + 50, 0), 100)

                if debug:
                    self._debug_visualization(frames, prediction, result, video_path)

                return f"{result} ({confidence:.1f}% confidence)", prediction

            except Exception as e:
                return f"Prediction error: {str(e)}", None

        def _debug_visualization(self, frames, score, result, video_path):
            print(f"\nPrediction Score: {score:.4f}")
            print(f"Decision: {result}")
            plt.figure(figsize=(15, 5))
            for i in range(min(10, len(frames))):
                plt.subplot(2, 5, i + 1)
                plt.imshow(frames[i])
                plt.title(f"Frame {i}\nMean: {frames[i].mean():.2f}")
                plt.axis('off')
            plt.suptitle(f"Prediction: {result} (Score: {score:.4f})")
            plt.tight_layout()

            # Save the visualization
            base_name = os.path.splitext(os.path.basename(video_path))[0]
            save_path = os.path.join(RESULT_PATH, f"{base_name}_prediction_result.png")
            plt.savefig(save_path)
            plt.close()
            print(f"Visualization saved to: {save_path}")

    detector = FightDetector()
    if detector.model is None:
        return "Error: Model loading failed", None
    return detector.predict(video_path)

# # Entry point
# path0 = input("Enter the local path to the video file to detect fight: ")
# path = path0.strip('"')  # Remove extra quotes if copied from Windows
# print(f"[INFO] Loading video: {path}")
# fight_detec(path)
```
README.md
CHANGED
@@ -1,12 +1,187 @@
# Video Analysis Project: Fight and Object Detection

## 1. Overview

This project analyzes video files to perform two main tasks:

* **Fight Detection:** Classifies video segments as containing a "FIGHT" or being "NORMAL" using a custom-trained 3D Convolutional Neural Network (CNN).
* **Object Detection:** Identifies and tracks specific objects within the video using a pre-trained YOLOv8 model.

The system processes an input video and outputs the fight classification result along with an annotated version of the video highlighting detected objects.

## 2. Features

* Dual analysis: combines action recognition (fight detection) and object detection.
* Custom-trained fight detection model tailored to the project's data.
* Utilizes state-of-the-art YOLOv8 for object detection.
* Generates an annotated output video showing detected objects and their tracks.
* Provides confidence scores for fight detection results.
* Includes scripts for both running inference (`full_project.py`) and training the fight detection model (`trainig.py`).

## 3. Project Structure

```
ComV/
├── [Project Directory]/            # e.g., AI_made
│   ├── full_project.py             # Main script for running inference
│   ├── Fight_detec_func.py         # Fight detection logic and model loading
│   ├── objec_detect_yolo.py        # Object detection logic using YOLOv8
│   ├── frame_slicer.py             # Utility for extracting frames for fight detection
│   ├── trainig.py                  # Script for training the fight detection model
│   ├── README.md                   # This documentation file
│   ├── trainnig_output/            # Directory for training artifacts
│   │   ├── final_model_2.h5        # Trained fight detection model (relative path)
│   │   ├── checkpoint/             # Checkpoints saved during training (relative path)
│   │   └── training_log.csv        # Log file for training history (relative path)
│   └── yolo/                       # (Assumed location)
│       └── yolo/
│           └── best.pt             # Pre-trained YOLOv8 model weights (relative path)
├── train/
│   ├── Fighting/                   # Directory containing fight video examples (relative path)
│   └── Normal/                     # Directory containing normal video examples (relative path)
└── try/
    ├── result/                     # Directory where output videos are saved (relative path)
    └── ...                         # Input video files (example location)
```

*(Note: Model paths and data directories might be hardcoded in the scripts. Consider making these configurable or using relative paths.)*

## 4. Setup and Installation

**Python Version:**

* This project was developed and tested using Python 3.10.

**Dependencies:**

Based on the code imports and `pip freeze` output, the following libraries and versions were used:

* `opencv-python==4.11.0.86` (cv2)
* `numpy==1.26.4`
* `tensorflow==2.19.0` (tf)
* `ultralytics==8.3.108` (for YOLOv8)
* `matplotlib==3.10.1` (for debug visualizations)
* `scikit-learn==1.6.1` (sklearn, used in `trainig.py`)

*(Note: Other versions might also work, but these are the ones confirmed in the development environment.)*

**Installation (using pip):**

```bash
pip install opencv-python numpy tensorflow ultralytics matplotlib scikit-learn
```

**Models:**

1. **Fight Detection Model:** Ensure the trained model (`final_model_2.h5`) is present in the `trainnig_output` subdirectory relative to the script location.
2. **YOLOv8 Model:** Ensure the YOLO model (`best.pt`) is present in the `yolo/yolo` subdirectory relative to the script location.

*(Note: Absolute paths might be hardcoded in the scripts and may need adjustment depending on the deployment environment.)*

## 5. Usage

To run the analysis on a video file:

1. Navigate to the `d:/K_REPO/ComV/AI_made/` directory in your terminal (or ensure Python's working directory is `d:/K_REPO`).
2. Run the main script:
   ```bash
   python full_project.py
   ```
3. The script will prompt you to enter the path to the video file:
   ```
   Enter the local path : <your_video_path.mp4>
   ```
   *(Ensure you provide the full path, removing extra quotes if copying from Windows Explorer.)*

**Output:**

* The console will print the fight detection result (e.g., "FIGHT (85.3% confidence)") and information about the object detection process.
* An annotated video file will be saved in the `D:\K_REPO\ComV\try\result` directory. The filename will include the original video name and the unique detected object labels (e.g., `input_video_label1_label2_output.mp4`).
* If debug mode is enabled in `Fight_detec_func.py`, additional debug images might be saved in the result directory.

## 6. Module Descriptions

* **`full_project.py`:** Orchestrates the process by taking user input and calling the fight detection and object detection functions.
* **`Fight_detec_func.py`:**
    * Contains the `fight_detec` function and `FightDetector` class.
    * Loads the Keras model (`final_model_2.h5`).
    * Uses `frame_slicer` to prepare input for the model.
    * Performs prediction and calculates confidence.
    * Handles debug visualizations.
* **`objec_detect_yolo.py`:**
    * Contains the `detection` function.
    * Loads the YOLOv8 model (`best.pt`).
    * Iterates through video frames, performing object detection and tracking.
    * Generates and saves the annotated output video.
    * Returns the detected object labels.
* **`frame_slicer.py`:**
    * Contains the `extract_video_frames` utility function.
    * Extracts a fixed number of frames, resizes and normalizes them, and handles potential errors during extraction.
* **`trainig.py`:**
    * Script for training the fight detection model.
    * Includes `VideoDataGenerator` for loading/processing video data.
    * Defines the 3D CNN model architecture.
    * Handles data loading, splitting, training loops, checkpointing, and saving the final model.
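
The two entry points described above can also be called directly from Python, outside the interactive prompt in `full_project.py`. The snippet below is a minimal usage sketch: the video path is a placeholder, and it assumes the models sit in the locations listed in Section 4. The summary string returned by `fight_detec` embeds a confidence value derived from the raw sigmoid score (see the clipping formula in `Fight_detec_func.py`).

```python
# Minimal programmatic usage sketch (paths are placeholders).
from Fight_detec_func import fight_detec
from objec_detect_yolo import detection

video_path = "some_video.mp4"  # hypothetical input video

# fight_detec returns (summary_string, raw_score)
summary, raw_score = fight_detec(video_path, debug=False)
print(summary)         # e.g. "FIGHT (85.3% confidence)"
print(raw_score)       # raw sigmoid output of the 3D CNN

# detection returns (unique_labels, annotated_video_path)
labels, annotated_path = detection(video_path)
print(labels)          # e.g. {'Gun', 'knife'}
print(annotated_path)  # path to the annotated output video
```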

## 7. Training Data

### Dataset Composition

| Category      | Count     | Percentage | Formats     | Avg Duration |
|---------------|-----------|------------|-------------|--------------|
| Fight Videos  | 2,340     | 61.9%      | .mp4, .mpeg | 5.2 sec      |
| Normal Videos | 1,441     | 38.1%      | .mp4, .mpeg | 6.1 sec      |
| **Total**     | **3,781** | **100%**   |             |              |

### Technical Specifications
- **Resolution:** 64×64 pixels
- **Color Space:** RGB
- **Frame Rate:** 30 FPS (average)
- **Frame Sampling:** 50 frames per video
- **Input Shape:** (30, 96, 96, 3) - the model resizes its input

### Data Sources
- Fighting videos: collected from public surveillance datasets
- Normal videos: sampled from public CCTV footage
- Manually verified and labeled by domain experts

### Preprocessing
1. Frame extraction at 50 frames/video
2. Resizing to 96×96 pixels
3. Normalization (pixel values scaled to [0, 1])
4. Temporal sampling to 30 frames for model input
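
As a concrete illustration of steps 2-4, the sketch below resizes, normalizes, and evenly subsamples a list of decoded BGR frames down to 30. It mirrors the `np.linspace`-style index selection used in `trainig.py` and `frame_slicer.py`, but it is an illustrative sketch rather than the exact project pipeline.

```python
import cv2
import numpy as np

def preprocess_frames(frames, n_out=30, size=(96, 96)):
    """Resize, scale to [0, 1], and evenly subsample decoded BGR frames (illustrative sketch)."""
    # Step 4: temporal sampling - evenly spaced indices across the extracted frames
    idx = np.linspace(0, len(frames) - 1, n_out, dtype=np.int32)
    out = []
    for i in idx:
        frame = cv2.resize(frames[i], size)             # step 2: resize to 96x96
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # convert to the model's RGB layout
        out.append(frame.astype(np.float32) / 255.0)    # step 3: normalize to [0, 1]
    return np.stack(out)                                # shape: (n_out, 96, 96, 3)
```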

## 8. Models Used

* **Fight Detection:** A custom 3D CNN trained using `trainig.py`. Located at `D:\K_REPO\ComV\AI_made\trainnig_output\final_model_2.h5`. The model expects input of shape `(30, 96, 96, 3)`.
* **Object Detection:** YOLOv8 model. Weights are located at `D:\K_REPO\ComV\yolo\yolo\best.pt`. This model is trained to detect the following classes: `['Fire', 'Gun', 'License_Plate', 'Smoke', 'knife']`.
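
To confirm which classes a particular `best.pt` actually contains, the weights can be loaded on their own with `ultralytics` and the class-name mapping printed. This is only a quick verification snippet, separate from the pipeline; adjust the path to wherever the weights live in your checkout.

```python
from ultralytics import YOLO

# Load the weights and print the id -> class-name mapping they were trained with.
model = YOLO(r"D:\K_REPO\ComV\yolo\yolo\best.pt")  # adjust to your local path
print(model.names)  # expected to list Fire, Gun, License_Plate, Smoke, knife
```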

## 9. Fight Detection Model Performance

The following metrics represent the performance achieved while training `final_model_2.h5`:

* **Best Training Accuracy:** 0.8583 (Epoch 7)
* **Best Validation Accuracy:** 0.9167 (Epoch 10)
* **Lowest Training Loss:** 0.3636 (Epoch 7)
* **Lowest Validation Loss:** 0.2805 (Epoch 8)

*(Note: These metrics are based on the training run that produced the saved model. Performance may vary slightly on different datasets or during retraining.)*

## 10. Configuration

Key parameters and paths are mostly hardcoded within the scripts:

* `Fight_detec_func.py`: `MODEL_PATH`, `N_FRAMES`, `IMG_SIZE`, `RESULT_PATH`.
* `objec_detect_yolo.py`: YOLO model path, output directory path (`output_dir`), confidence threshold (`conf=0.7`).
* `trainig.py`: `DATA_DIR`, `N_FRAMES`, `IMG_SIZE`, `EPOCHS`, `BATCH_SIZE`, `CHECKPOINT_DIR`, `OUTPUT_PATH`.

*Recommendation: Refactor these hardcoded values into a separate configuration file (e.g., YAML or JSON) or use command-line arguments for better flexibility.*
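
One possible shape for that refactor is a small JSON file plus a loader module; the file name, keys, and defaults below are only a suggestion and are not read by any of the current scripts.

```python
# config.py - hypothetical configuration loader; not used by the current scripts.
import json
import os

DEFAULTS = {
    "model_path": os.path.join("trainnig_output", "final_model_2.h5"),
    "yolo_weights": os.path.join("yolo", "yolo", "best.pt"),
    "n_frames": 30,
    "img_size": [96, 96],
    "result_dir": "results",
    "yolo_conf": 0.7,
}

def load_config(path="config.json"):
    """Return DEFAULTS, overridden by values from an optional JSON file."""
    cfg = dict(DEFAULTS)
    if os.path.exists(path):
        with open(path, "r", encoding="utf-8") as f:
            cfg.update(json.load(f))
    return cfg
```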

## 11. Training the Fight Detection Model

To retrain or train the fight detection model:

1. **Prepare Data:** Place training videos into the `D:\K_REPO\ComV\train\Fighting` and `D:\K_REPO\ComV\train\Normal` subdirectories.
2. **Run Training Script:** Execute `trainig.py`:
   ```bash
   python trainig.py
   ```
3. The script will load the data, build the model (or resume from a checkpoint if `RESUME_TRAINING=1` and a checkpoint exists), train it, and save the final model to `D:\K_REPO\ComV\AI_made\trainnig_output\final_model_2.h5`. Checkpoints and logs are saved in the `trainnig_output` directory.
app.py
ADDED
@@ -0,0 +1,35 @@
```python
import gradio as gr
import os
import tempfile
from Fight_detec_func import fight_detec
from objec_detect_yolo import detection

def analyze_video(video_file):
    # Gradio's Video component normally passes a filepath string; fall back to
    # copying the upload into a temp directory if a file-like object is given.
    temp_dir = None
    if isinstance(video_file, str):
        video_path = video_file
    else:
        temp_dir = tempfile.mkdtemp()
        video_path = os.path.join(temp_dir, video_file.name)
        with open(video_path, 'wb') as f:
            f.write(video_file.read())

    # Run both detection functions
    fight_result = fight_detec(video_path, debug=False)
    yolo_result = detection(video_path)

    # Clean up the temporary copy, if one was made
    if temp_dir is not None:
        os.remove(video_path)
        os.rmdir(temp_dir)

    return {
        "Fight Detection": fight_result[0],
        "YOLO Object Detection": yolo_result
    }

iface = gr.Interface(
    fn=analyze_video,
    inputs=gr.Video(label="Upload Video"),
    outputs=gr.JSON(label="Detection Results"),
    title="Fight and Object Detection System",
    description="Upload a video to detect fights and objects using our AI models"
)

iface.launch()
```
frame_slicer.py
ADDED
@@ -0,0 +1,58 @@
```python
import cv2
import numpy as np
import random

def extract_video_frames(video_path, n_frames=30, frame_size=(96, 96)):
    """
    Simplified robust frame extractor for short videos (2-10 sec)
    - Automatically handles varying video lengths
    - Ensures consistent output shape
    - Optimized for MP4/MPEG
    """
    # Open video
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error: Could not open video {video_path}")
        return None

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)

    # Basic video validation
    if total_frames < 1 or fps < 1:
        print(f"Error: Invalid video (frames:{total_frames}, fps:{fps})")
        cap.release()
        return None

    # Calculate how many frames to skip (adaptive based on video length)
    video_length = total_frames / fps
    frame_step = max(1, int(total_frames / n_frames))

    frames = []
    last_good_frame = None

    for i in range(n_frames):
        # Calculate position to read (spread evenly across video)
        pos = min(int(i * (total_frames / n_frames)), total_frames - 1)
        cap.set(cv2.CAP_PROP_POS_FRAMES, pos)

        ret, frame = cap.read()

        # Fallback strategies if read fails
        if not ret or frame is None:
            if last_good_frame is not None:
                frame = last_good_frame.copy()
            else:
                # Generate placeholder frame (light gray)
                frame = np.full((*frame_size[::-1], 3), 0.8, dtype=np.float32)
        else:
            # Process valid frame
            frame = cv2.resize(frame, frame_size)
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = frame.astype(np.float32) / 255.0
            last_good_frame = frame

        frames.append(frame)

    cap.release()
    return np.array(frames)
```
full_project.py
ADDED
@@ -0,0 +1,22 @@
```python
import cv2
import numpy as np
import os
from ultralytics import YOLO
import time
import tensorflow as tf
from frame_slicer import extract_video_frames
import matplotlib.pyplot as plt

from Fight_detec_func import fight_detec
from objec_detect_yolo import detection


# Entry point
path0 = input("Enter the local path : ")
path = path0.strip('"')  # Remove extra quotes if copied from Windows
print(f"[INFO] Loading video: {path}")

fight_detec(path)
detection(path)
```
model_summary.py
ADDED
@@ -0,0 +1,10 @@
```python
from tensorflow.keras.models import load_model

model = load_model(r"D:\K_REPO\ComV\AI_made\trainnig_output\final_model_2.h5")
model.summary()


from tensorflow.python.client import device_lib
print("[INFO] Devices available:")
print(device_lib.list_local_devices())
```
objec_detect_yolo.py
ADDED
@@ -0,0 +1,121 @@
```python
import cv2
import numpy as np
import os
from ultralytics import YOLO
import time
from typing import Tuple, Set

def detection(path: str) -> Tuple[Set[str], str]:
    """
    Detects and tracks objects in a video using a YOLOv8 model, saving an annotated output video.

    Args:
        path (str): Path to the input video file. Supports common video formats (mp4, avi, etc.)

    Returns:
        Tuple[Set[str], str]:
            - Set of unique detected object labels (e.g., {'Gun', 'Knife'})
            - Path to the output annotated video with detection boxes and tracking IDs

    Raises:
        FileNotFoundError: If the input video doesn't exist
        ValueError: If the video cannot be opened/processed
    """

    # Validate input file exists
    if not os.path.exists(path):
        raise FileNotFoundError(f"Video file not found: {path}")

    # Initialize YOLOv8 model with pretrained weights
    # Model is trained to detect: ['Fire', 'Gun', 'License_Plate', 'Smoke', 'knife']
    model = YOLO(os.path.join(os.path.dirname(__file__), "yolo", "best.pt"))
    class_names = model.names  # Get class label mappings

    # Set up output paths:
    # 1. Temporary output during processing
    # 2. Final output with detected objects in filename
    input_video_name = os.path.basename(path)
    base_name = os.path.splitext(input_video_name)[0]
    temp_output_name = f"{base_name}_output_temp.mp4"
    output_dir = "results"
    os.makedirs(output_dir, exist_ok=True)  # Create output dir if needed
    if not os.path.exists(output_dir):
        raise ValueError(f"Failed to create output directory: {output_dir}")
    temp_output_path = os.path.join(output_dir, temp_output_name)

    # Video processing setup:
    # - Open input video stream
    # - Initialize output writer with MP4 codec
    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
        raise ValueError(f"Failed to open video file: {path}")

    # Process all frames at 640x640 resolution for consistency
    frame_width, frame_height = 640, 640
    out = cv2.VideoWriter(
        temp_output_path,
        cv2.VideoWriter_fourcc(*'mp4v'),  # MP4 codec
        30.0,  # Output FPS
        (frame_width, frame_height)
    )

    # Main processing loop:
    # 1. Read each frame
    # 2. Run object detection + tracking
    # 3. Annotate frame with boxes and IDs
    # 4. Collect detected classes
    crimes = []  # Track all detected objects
    start = time.time()
    print(f"[INFO] Processing started at {start:.2f} seconds")

    while True:
        ret, frame = cap.read()
        if not ret:  # End of video
            break

        # Resize and run detection + tracking
        frame = cv2.resize(frame, (frame_width, frame_height))
        results = model.track(
            source=frame,
            conf=0.7,      # Minimum confidence threshold
            persist=True   # Enable tracking across frames
        )

        # Annotate frame with boxes and tracking IDs
        annotated_frame = results[0].plot()

        # Record detected classes
        for box in results[0].boxes:
            cls = int(box.cls)
            crimes.append(class_names[cls])

        out.write(annotated_frame)

    # Clean up video resources
    end = time.time()
    print(f"[INFO] Processing finished at {end:.2f} seconds")
    print(f"[INFO] Total execution time: {end - start:.2f} seconds")
    cap.release()
    out.release()

    # Generate final output filename containing detected object labels
    # Format: {original_name}_{detected_objects}_output.mp4
    unique_crimes = set(crimes)
    crimes_str = "_".join(sorted(unique_crimes)).replace(" ", "_")[:50]  # truncate if needed
    final_output_name = f"{base_name}_{crimes_str}_output.mp4"
    final_output_path = os.path.join(output_dir, final_output_name)

    # Rename the video file
    os.rename(temp_output_path, final_output_path)

    print(f"[INFO] Detected crimes: {unique_crimes}")
    print(f"[INFO] Annotated video saved at: {final_output_path}")

    return unique_crimes, final_output_path


# # Entry point
# path0 = input("Enter the local path to the video file to detect objects: ")
# path = path0.strip('"')  # Remove extra quotes if copied from Windows
# print(f"[INFO] Loading video: {path}")
# detection(path)
```
requirements.txt
ADDED
@@ -0,0 +1,6 @@
```
gradio>=3.0
tensorflow>=2.10
opencv-python>=4.6
ultralytics>=8.0
numpy>=1.22
matplotlib>=3.6
```
trainig.py
ADDED
@@ -0,0 +1,248 @@
```python
import os
import numpy as np
import cv2
import traceback
from collections import Counter
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import Sequence
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import Input, Conv3D, MaxPooling3D, Flatten, Dense, Dropout, BatchNormalization
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, CSVLogger
import tensorflow as tf

# === CONFIG ===
DATA_DIR = "D:\\K_REPO\\ComV\\train"
N_FRAMES = 30
IMG_SIZE = (96, 96)
EPOCHS = 10
BATCH_SIZE = 14
CHECKPOINT_DIR = r"D:\K_REPO\ComV\AI_made\trainnig_output\checkpoint"
RESUME_TRAINING = 1
MIN_REQUIRED_FRAMES = 10
OUTPUT_PATH = r"D:\K_REPO\ComV\AI_made\trainnig_output\final_model_2.h5"
# Optimize OpenCV
cv2.setUseOptimized(True)
cv2.setNumThreads(8)

# === VIDEO DATA GENERATOR ===
class VideoDataGenerator(Sequence):
    def __init__(self, video_paths, labels, batch_size, n_frames, img_size):
        self.video_paths, self.labels = self._filter_invalid_videos(video_paths, labels)
        self.batch_size = batch_size
        self.n_frames = n_frames
        self.img_size = img_size
        self.indices = np.arange(len(self.video_paths))
        print(f"[INFO] Final dataset size: {len(self.video_paths)} videos")

    def _filter_invalid_videos(self, paths, labels):
        valid_paths = []
        valid_labels = []

        for path, label in zip(paths, labels):
            cap = cv2.VideoCapture(path)
            if not cap.isOpened():
                print(f"[WARNING] Could not open video: {path}")
                continue

            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            cap.release()

            if total_frames < MIN_REQUIRED_FRAMES:
                print(f"[WARNING] Skipping {path} - only {total_frames} frames (needs at least {MIN_REQUIRED_FRAMES})")
                continue

            valid_paths.append(path)
            valid_labels.append(label)

        return valid_paths, valid_labels

    def __len__(self):
        return int(np.ceil(len(self.video_paths) / self.batch_size))

    def __getitem__(self, index):
        batch_indices = self.indices[index*self.batch_size:(index+1)*self.batch_size]
        X, y = [], []

        for i in batch_indices:
            path = self.video_paths[i]
            label = self.labels[i]
            try:
                frames = self._load_video_frames(path)
                X.append(frames)
                y.append(label)
            except Exception as e:
                print(f"[WARNING] Error processing {path} - {str(e)}")
                X.append(np.zeros((self.n_frames, *self.img_size, 3)))
                y.append(label)

        return np.array(X), np.array(y)

    def _load_video_frames(self, path):
        cap = cv2.VideoCapture(path)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        if total_frames < self.n_frames:
            frame_indices = np.linspace(0, total_frames - 1, min(total_frames, self.n_frames), dtype=np.int32)
        else:
            frame_indices = np.linspace(0, total_frames - 1, self.n_frames, dtype=np.int32)

        frames = []
        for idx in frame_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            if not ret:
                frame = np.zeros((*self.img_size, 3), dtype=np.uint8)
            else:
                frame = cv2.resize(frame, self.img_size)
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(frame)

        cap.release()

        while len(frames) < self.n_frames:
            frames.append(frames[-1] if frames else np.zeros((*self.img_size, 3), dtype=np.uint8))

        return np.array(frames) / 255.0

    def on_epoch_end(self):
        np.random.shuffle(self.indices)

def create_model():
    model = Sequential([
        Input(shape=(N_FRAMES, *IMG_SIZE, 3)),
        Conv3D(32, kernel_size=(3, 3, 3), activation='relu', padding='same'),
        MaxPooling3D(pool_size=(1, 2, 2)),
        BatchNormalization(),

        Conv3D(64, kernel_size=(3, 3, 3), activation='relu', padding='same'),
        MaxPooling3D(pool_size=(1, 2, 2)),
        BatchNormalization(),

        Conv3D(128, kernel_size=(3, 3, 3), activation='relu', padding='same'),
        MaxPooling3D(pool_size=(2, 2, 2)),
        BatchNormalization(),

        Flatten(),
        Dense(256, activation='relu'),
        Dropout(0.5),
        Dense(1, activation='sigmoid')
    ])

    model.compile(optimizer='adam',
                  loss='binary_crossentropy',
                  metrics=['accuracy'])

    return model

def load_data():
    video_paths, labels = [], []
    for label_name in ["Fighting", "Normal"]:
        label_dir = os.path.join(DATA_DIR, label_name)
        if not os.path.isdir(label_dir):
            raise FileNotFoundError(f"Directory not found: {label_dir}")

        label = 1 if label_name.lower() == "fighting" else 0

        for file in os.listdir(label_dir):
            if file.lower().endswith((".mp4", ".mpeg", ".avi", ".mov")):
                full_path = os.path.join(label_dir, file)
                video_paths.append(full_path)
                labels.append(label)

    if not video_paths:
        raise ValueError(f"No videos found in {DATA_DIR}")

    print(f"[INFO] Total videos: {len(video_paths)} (Fighting: {labels.count(1)}, Normal: {labels.count(0)})")

    if len(set(labels)) > 1:
        return train_test_split(video_paths, labels, test_size=0.2, stratify=labels, random_state=42)
    else:
        print("[WARNING] Only one class found. Splitting without stratification.")
        return train_test_split(video_paths, labels, test_size=0.2, random_state=42)

def get_latest_checkpoint():
    if not os.path.exists(CHECKPOINT_DIR):
        os.makedirs(CHECKPOINT_DIR)
        return None

    checkpoints = [f for f in os.listdir(CHECKPOINT_DIR)
                   if f.startswith('ckpt_') and f.endswith('.h5')]
    if not checkpoints:
        return None

    checkpoints.sort(key=lambda x: int(x.split('_')[1].split('.')[0]))
    return os.path.join(CHECKPOINT_DIR, checkpoints[-1])

def main():
    # Load and split data
    try:
        train_paths, val_paths, train_labels, val_labels = load_data()
    except Exception as e:
        print(f"[ERROR] Failed to load data: {str(e)}")
        return

    # Create data generators
    try:
        train_gen = VideoDataGenerator(train_paths, train_labels, BATCH_SIZE, N_FRAMES, IMG_SIZE)
        val_gen = VideoDataGenerator(val_paths, val_labels, BATCH_SIZE, N_FRAMES, IMG_SIZE)
    except Exception as e:
        print(f"[ERROR] Failed to create data generators: {str(e)}")
        return

    # Callbacks
    callbacks = [
        ModelCheckpoint(
            os.path.join(CHECKPOINT_DIR, 'ckpt_{epoch}.h5'),
            save_best_only=False,
            save_weights_only=False
        ),
        CSVLogger('training_log.csv', append=True),
        EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
    ]

    # Handle resume training
    initial_epoch = 0
    try:
        if RESUME_TRAINING:
            ckpt = get_latest_checkpoint()
            if ckpt:
                print(f"[INFO] Resuming training from checkpoint: {ckpt}")
                model = load_model(ckpt)
                # Parse the epoch number from the checkpoint filename (not the full path)
                initial_epoch = int(os.path.basename(ckpt).split('_')[1].split('.')[0])
            else:
                print("[INFO] No checkpoint found, starting new training")
                model = create_model()
        else:
            model = create_model()
    except Exception as e:
        print(f"[ERROR] Failed to initialize model: {str(e)}")
        return

    # Display model summary
    model.summary()

    # Train model
    try:
        print("[INFO] Starting training...")
        history = model.fit(
            train_gen,
            validation_data=val_gen,
            epochs=EPOCHS,
            initial_epoch=initial_epoch,
            callbacks=callbacks,
            verbose=1
        )
    except Exception as e:
        print(f"[ERROR] Training failed: {str(e)}")
        traceback.print_exc()
    finally:
        model.save(OUTPUT_PATH)
        print("[INFO] Training completed. Model saved to final_model_2.h5")

if __name__ == "__main__":
    print("[INFO] Starting script...")
    main()
    print("[INFO] Script execution completed.")
```