|
import os
|
|
import numpy as np
|
|
import cv2
|
|
import traceback
|
|
from collections import Counter
|
|
from sklearn.model_selection import train_test_split
|
|
from tensorflow.keras.utils import Sequence
|
|
from tensorflow.keras.models import Sequential, load_model
|
|
from tensorflow.keras.layers import Input, Conv3D, MaxPooling3D, Flatten, Dense, Dropout, BatchNormalization
|
|
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, CSVLogger
|
|
import tensorflow as tf
|
|
|
|
|
|
# --- Filesystem locations -------------------------------------------------
# Root folder scanned by load_data(); it must contain the "Fighting" and
# "Normal" sub-folders.
DATA_DIR = "D:\\K_REPO\\ComV\\train"
# Folder holding the per-epoch "ckpt_{epoch}.h5" files.
CHECKPOINT_DIR = r"D:\K_REPO\ComV\AI_made\trainnig_output\checkpoint"
# Where main() saves the model once fit() returns or fails.
OUTPUT_PATH = r"D:\K_REPO\ComV\AI_made\trainnig_output\final_model_2.h5"

# --- Clip / training hyper-parameters -------------------------------------
# Number of frames sampled from every video (also the model's first input dim).
N_FRAMES = 30
# Spatial size every frame is resized to.
IMG_SIZE = (96, 96)
# Videos reporting fewer frames than this are dropped by the data generator.
MIN_REQUIRED_FRAMES = 10
# Passed to model.fit() as `epochs`.
EPOCHS = 10
# Number of videos per batch produced by the data generator.
BATCH_SIZE = 14
# Truthy -> main() tries to continue from the newest checkpoint.
RESUME_TRAINING = 1

# --- OpenCV runtime configuration -----------------------------------------
cv2.setUseOptimized(True)
cv2.setNumThreads(8)
|
|
|
|
|
|
class VideoDataGenerator(Sequence):
    """Keras ``Sequence`` yielding batches of fixed-length video clips.

    Each item is a tuple ``(X, y)`` where ``X`` has shape
    ``(batch, n_frames, img_size[0], img_size[1], 3)`` with float32 values in
    ``[0, 1]`` (RGB channel order) and ``y`` holds the matching labels.

    Args:
        video_paths: Paths of the candidate video files.
        labels: One label per entry of ``video_paths``.
        batch_size: Maximum number of videos per batch.
        n_frames: Number of frames sampled from every video.
        img_size: ``(height, width)`` of the frames in the output array.
    """

    def __init__(self, video_paths, labels, batch_size, n_frames, img_size):
        # Let the Keras base class initialise its own state; recent Keras
        # versions warn when a subclass skips this call.
        super().__init__()
        self.video_paths, self.labels = self._filter_invalid_videos(video_paths, labels)
        self.batch_size = batch_size
        self.n_frames = n_frames
        self.img_size = img_size
        self.indices = np.arange(len(self.video_paths))
        print(f"[INFO] Final dataset size: {len(self.video_paths)} videos")

    def _filter_invalid_videos(self, paths, labels):
        """Return ``(paths, labels)`` restricted to usable videos.

        A video is dropped (with a warning) when OpenCV cannot open it or when
        it reports fewer than ``MIN_REQUIRED_FRAMES`` frames.
        """
        valid_paths = []
        valid_labels = []

        for path, label in zip(paths, labels):
            cap = cv2.VideoCapture(path)
            try:
                if not cap.isOpened():
                    print(f"[WARNING] Could not open video: {path}")
                    continue
                total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            finally:
                # Always free the capture handle, including for unopened files.
                cap.release()

            if total_frames < MIN_REQUIRED_FRAMES:
                print(f"[WARNING] Skipping {path} - only {total_frames} frames (needs at least {MIN_REQUIRED_FRAMES})")
                continue

            valid_paths.append(path)
            valid_labels.append(label)

        return valid_paths, valid_labels

    def __len__(self):
        """Return the number of batches per epoch (last batch may be short)."""
        return int(np.ceil(len(self.video_paths) / self.batch_size))

    def __getitem__(self, index):
        """Build batch number ``index``.

        A video that fails to load is replaced by an all-zero clip carrying
        the video's original label, so one bad file never aborts training.
        """
        start = index * self.batch_size
        batch_indices = self.indices[start:start + self.batch_size]
        X, y = [], []

        for i in batch_indices:
            path = self.video_paths[i]
            label = self.labels[i]
            try:
                frames = self._load_video_frames(path)
            except Exception as e:
                print(f"[WARNING] Error processing {path} - {str(e)}")
                # Same dtype as real clips so the batch is not upcast to float64.
                frames = np.zeros((self.n_frames, *self.img_size, 3), dtype=np.float32)
            X.append(frames)
            y.append(label)

        return np.array(X), np.array(y)

    def _load_video_frames(self, path):
        """Sample ``n_frames`` evenly spaced RGB frames from ``path``.

        Frames that cannot be read are replaced by black frames; videos with
        fewer than ``n_frames`` frames are padded by repeating the last frame.

        Returns:
            float32 array of shape ``(n_frames, height, width, 3)`` scaled to
            ``[0, 1]``.
        """
        height, width = self.img_size
        blank = np.zeros((height, width, 3), dtype=np.uint8)
        frames = []

        cap = cv2.VideoCapture(path)
        try:
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            # Clamp at zero: some backends report a negative/unknown count,
            # which would otherwise make np.linspace raise.
            n_samples = max(0, min(total_frames, self.n_frames))
            frame_indices = np.linspace(0, total_frames - 1, n_samples, dtype=np.int32)

            for idx in frame_indices:
                cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
                ret, frame = cap.read()
                if not ret:
                    frame = blank
                else:
                    # cv2.resize expects (width, height), the reverse of the
                    # (rows, cols) order used for array shapes.
                    frame = cv2.resize(frame, (width, height))
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(frame)
        finally:
            # Release the handle even if decoding raised.
            cap.release()

        # Pad short clips by repeating the last frame (or black if none).
        filler = frames[-1] if frames else blank
        frames.extend([filler] * (self.n_frames - len(frames)))

        # float32 halves the memory of the default float64 division result.
        return np.asarray(frames, dtype=np.float32) / 255.0

    def on_epoch_end(self):
        """Shuffle the sample order in place after every epoch."""
        np.random.shuffle(self.indices)
|
|
|
|
def create_model():
    """Build and compile the 3D-CNN binary video classifier.

    The network takes clips of shape ``(N_FRAMES, *IMG_SIZE, 3)``, applies
    three Conv3D -> MaxPooling3D -> BatchNormalization stages, then a dense
    head ending in a single sigmoid unit.

    Returns:
        A compiled ``Sequential`` model (Adam, binary cross-entropy, accuracy).
    """
    # (number of filters, pooling window) for each convolutional stage.
    conv_stages = (
        (32, (1, 2, 2)),
        (64, (1, 2, 2)),
        (128, (2, 2, 2)),
    )

    layer_stack = [Input(shape=(N_FRAMES, *IMG_SIZE, 3))]
    for n_filters, pool_window in conv_stages:
        layer_stack.append(
            Conv3D(n_filters, kernel_size=(3, 3, 3), activation='relu', padding='same')
        )
        layer_stack.append(MaxPooling3D(pool_size=pool_window))
        layer_stack.append(BatchNormalization())

    # Classification head.
    layer_stack.append(Flatten())
    layer_stack.append(Dense(256, activation='relu'))
    layer_stack.append(Dropout(0.5))
    layer_stack.append(Dense(1, activation='sigmoid'))

    network = Sequential(layer_stack)
    network.compile(
        optimizer='adam',
        loss='binary_crossentropy',
        metrics=['accuracy'],
    )
    return network
|
|
|
|
def load_data():
    """Collect labelled video paths under ``DATA_DIR`` and split them 80/20.

    Videos are read from the ``Fighting`` (label 1) and ``Normal`` (label 0)
    sub-folders; only regular files ending in .mp4/.mpeg/.avi/.mov are used.

    Returns:
        ``(train_paths, val_paths, train_labels, val_labels)`` as returned by
        ``train_test_split``.

    Raises:
        FileNotFoundError: If one of the class folders does not exist.
        ValueError: If no video file is found at all.
    """
    video_extensions = (".mp4", ".mpeg", ".avi", ".mov")
    video_paths, labels = [], []

    for label_name in ["Fighting", "Normal"]:
        label_dir = os.path.join(DATA_DIR, label_name)
        if not os.path.isdir(label_dir):
            raise FileNotFoundError(f"Directory not found: {label_dir}")

        label = 1 if label_name.lower() == "fighting" else 0

        # os.listdir() returns entries in arbitrary order; sorting makes the
        # fixed-seed split below reproducible across runs and machines.
        for filename in sorted(os.listdir(label_dir)):
            if not filename.lower().endswith(video_extensions):
                continue
            full_path = os.path.join(label_dir, filename)
            # Skip directories that merely look like videos by name.
            if not os.path.isfile(full_path):
                continue
            video_paths.append(full_path)
            labels.append(label)

    if not video_paths:
        raise ValueError(f"No videos found in {DATA_DIR}")

    print(f"[INFO] Total videos: {len(video_paths)} (Fighting: {labels.count(1)}, Normal: {labels.count(0)})")

    # Stratify only when both classes are present.
    stratify = labels if len(set(labels)) > 1 else None
    if stratify is None:
        print("[WARNING] Only one class found. Splitting without stratification.")

    return train_test_split(
        video_paths, labels, test_size=0.2, stratify=stratify, random_state=42
    )
|
|
|
|
def get_latest_checkpoint(checkpoint_dir=None):
    """Return the path of the highest-numbered ``ckpt_<epoch>.h5`` file.

    Args:
        checkpoint_dir: Folder to search. Defaults to the module-level
            ``CHECKPOINT_DIR``. The folder is created if it does not exist.

    Returns:
        Full path of the checkpoint with the largest epoch number, or ``None``
        when the folder was missing or holds no usable checkpoint.
    """
    if checkpoint_dir is None:
        checkpoint_dir = CHECKPOINT_DIR

    if not os.path.exists(checkpoint_dir):
        # exist_ok guards against another process creating it concurrently.
        os.makedirs(checkpoint_dir, exist_ok=True)
        return None

    def _epoch_of(filename):
        """Return the epoch encoded in ``ckpt_<epoch>.h5`` or None."""
        try:
            return int(filename.split('_')[1].split('.')[0])
        except (IndexError, ValueError):
            return None

    numbered = []
    for filename in os.listdir(checkpoint_dir):
        if not (filename.startswith('ckpt_') and filename.endswith('.h5')):
            continue
        epoch = _epoch_of(filename)
        # Ignore files such as "ckpt_best.h5" instead of crashing on int().
        if epoch is not None:
            numbered.append((epoch, filename))

    if not numbered:
        return None

    # Sort on the epoch only, mirroring a stable sort keyed by epoch number.
    numbered.sort(key=lambda pair: pair[0])
    return os.path.join(checkpoint_dir, numbered[-1][1])
|
|
|
|
def main():
    """Run the full training pipeline.

    Steps: load and split the data, build the two data generators, create the
    model (or resume from the newest checkpoint when ``RESUME_TRAINING`` is
    truthy), train, and finally save the model to ``OUTPUT_PATH``. Failures
    in the setup steps are reported and end the run early.
    """
    # 1. Data
    try:
        train_paths, val_paths, train_labels, val_labels = load_data()
    except Exception as e:
        print(f"[ERROR] Failed to load data: {str(e)}")
        return

    # 2. Generators
    try:
        train_gen = VideoDataGenerator(train_paths, train_labels, BATCH_SIZE, N_FRAMES, IMG_SIZE)
        val_gen = VideoDataGenerator(val_paths, val_labels, BATCH_SIZE, N_FRAMES, IMG_SIZE)
    except Exception as e:
        print(f"[ERROR] Failed to create data generators: {str(e)}")
        return

    # 3. Callbacks: a full-model checkpoint every epoch, CSV log, early stop.
    callbacks = [
        ModelCheckpoint(
            os.path.join(CHECKPOINT_DIR, 'ckpt_{epoch}.h5'),
            save_best_only=False,
            save_weights_only=False
        ),
        CSVLogger('training_log.csv', append=True),
        EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
    ]

    # 4. Model (fresh or resumed)
    initial_epoch = 0
    try:
        ckpt = get_latest_checkpoint() if RESUME_TRAINING else None
        if ckpt:
            print(f"[INFO] Resuming training from checkpoint: {ckpt}")
            model = load_model(ckpt)
            # Parse the epoch from the file name only: the directory part of
            # the path may itself contain '_' or '.', which would break the
            # split-based parsing and abort the resume.
            ckpt_name = os.path.basename(ckpt)
            initial_epoch = int(ckpt_name.split('_')[1].split('.')[0])
        else:
            if RESUME_TRAINING:
                print("[INFO] No checkpoint found, starting new training")
            model = create_model()
    except Exception as e:
        print(f"[ERROR] Failed to initialize model: {str(e)}")
        return

    model.summary()

    # 5. Training
    training_finished = False
    try:
        print("[INFO] Starting training...")
        model.fit(
            train_gen,
            validation_data=val_gen,
            epochs=EPOCHS,
            initial_epoch=initial_epoch,
            callbacks=callbacks,
            verbose=1
        )
        training_finished = True
    except Exception as e:
        print(f"[ERROR] Training failed: {str(e)}")
        traceback.print_exc()
    finally:
        # Save in every case so partial progress is not lost, but report
        # truthfully whether training actually ran to completion.
        model.save(OUTPUT_PATH)
        status = "Training completed" if training_finished else "Training did not complete"
        print(f"[INFO] {status}. Model saved to {OUTPUT_PATH}")
|
|
|
|
# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    print("[INFO] Starting script...")
    main()
    print("[INFO] Script execution completed.")
|
|
|
|
|
|
|
|
|