# Source: KillD00zer — "Upload 9 files" (commit da07a7d, verified).
import os
import numpy as np
import cv2
import traceback
from collections import Counter
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import Sequence
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import Input, Conv3D, MaxPooling3D, Flatten, Dense, Dropout, BatchNormalization
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, CSVLogger
import tensorflow as tf
# === CONFIG ===
# --- Filesystem locations (Windows paths; raw strings keep backslashes literal) ---
DATA_DIR = r"D:\K_REPO\ComV\train"
CHECKPOINT_DIR = r"D:\K_REPO\ComV\AI_made\trainnig_output\checkpoint"
OUTPUT_PATH = r"D:\K_REPO\ComV\AI_made\trainnig_output\final_model_2.h5"

# --- Clip sampling ---
N_FRAMES = 30             # frames sampled from each video
IMG_SIZE = (96, 96)       # per-frame spatial size fed to the model
MIN_REQUIRED_FRAMES = 10  # videos with fewer frames are skipped by the data generator

# --- Training ---
EPOCHS = 10
BATCH_SIZE = 14
RESUME_TRAINING = 1       # truthy: continue from the newest checkpoint in CHECKPOINT_DIR
# Optimize OpenCV: request 8 threads for its parallel regions and make sure
# its optimized code paths are switched on (the two settings are independent).
cv2.setNumThreads(8)
cv2.setUseOptimized(True)
# === VIDEO DATA GENERATOR ===
class VideoDataGenerator(Sequence):
    """Keras ``Sequence`` yielding batches of fixed-length clips read from video files.

    Each clip is ``n_frames`` frames sampled evenly across the video, resized to
    ``img_size`` -- interpreted as ``(height, width)``, the same layout used for the
    ``(n_frames, *img_size, 3)`` fallback clips -- converted BGR -> RGB and scaled to
    ``[0, 1]`` as float32.

    Args:
        video_paths: Paths to video files.
        labels: One label per path, in the same order as ``video_paths``.
        batch_size: Clips per batch; the last batch may be smaller.
        n_frames: Frames per clip.
        img_size: ``(height, width)`` of every frame.
        shuffle: Reshuffle sample order at the end of every epoch (default True).
        **kwargs: Forwarded to the ``Sequence`` base class (e.g. ``workers`` /
            ``use_multiprocessing`` on Keras 3). Nothing is forwarded by default.

    Raises:
        ValueError: if ``video_paths`` and ``labels`` differ in length.
    """

    def __init__(self, video_paths, labels, batch_size, n_frames, img_size, shuffle=True, **kwargs):
        # Keras 3's PyDataset warns (and falls back to defaults) when the base
        # initializer is skipped; with no kwargs this is harmless on Keras 2 too.
        super().__init__(**kwargs)
        if len(video_paths) != len(labels):
            # zip() in the filter would otherwise silently drop the surplus entries.
            raise ValueError(
                f"video_paths and labels differ in length: {len(video_paths)} != {len(labels)}"
            )
        self.video_paths, self.labels = self._filter_invalid_videos(video_paths, labels)
        self.batch_size = batch_size
        self.n_frames = n_frames
        self.img_size = img_size
        self._shuffle = shuffle
        self.indices = np.arange(len(self.video_paths))
        print(f"[INFO] Final dataset size: {len(self.video_paths)} videos")

    def _filter_invalid_videos(self, paths, labels):
        """Drop videos that cannot be opened or report fewer than MIN_REQUIRED_FRAMES frames.

        Returns:
            ``(valid_paths, valid_labels)`` lists with the original order preserved.
        """
        valid_paths, valid_labels = [], []
        for path, label in zip(paths, labels):
            cap = cv2.VideoCapture(path)
            try:
                if not cap.isOpened():
                    print(f"[WARNING] Could not open video: {path}")
                    continue
                total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            finally:
                # Release on every path, including the failed-open one.
                cap.release()
            if total_frames < MIN_REQUIRED_FRAMES:
                print(f"[WARNING] Skipping {path} - only {total_frames} frames (needs at least {MIN_REQUIRED_FRAMES})")
                continue
            valid_paths.append(path)
            valid_labels.append(label)
        return valid_paths, valid_labels

    def __len__(self):
        """Number of batches per epoch (the last one may be partial)."""
        return int(np.ceil(len(self.video_paths) / self.batch_size))

    def __getitem__(self, index):
        """Return batch ``index`` as ``(clips, labels)`` numpy arrays.

        A video that fails to load is replaced by an all-zero clip (with a printed
        warning) so a single bad file does not abort the epoch.
        """
        start = index * self.batch_size
        batch_indices = self.indices[start:start + self.batch_size]
        X, y = [], []
        for i in batch_indices:
            path = self.video_paths[i]
            try:
                frames = self._load_video_frames(path)
            except Exception as e:
                print(f"[WARNING] Error processing {path} - {str(e)}")
                frames = np.zeros((self.n_frames, *self.img_size, 3), dtype=np.float32)
            X.append(frames)
            y.append(self.labels[i])
        return np.array(X), np.array(y)

    def _load_video_frames(self, path):
        """Read ``n_frames`` evenly spaced frames from ``path``.

        Frames that cannot be read become black frames; videos with fewer than
        ``n_frames`` frames are padded by repeating the last sampled frame.

        Returns:
            float32 array of shape ``(n_frames, height, width, 3)`` with values in ``[0, 1]``.
        """
        height, width = self.img_size
        blank = np.zeros((height, width, 3), dtype=np.uint8)
        frames = []
        cap = cv2.VideoCapture(path)
        try:
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            # Never request more frames than the video has; the shortfall is padded below.
            frame_indices = np.linspace(0, total_frames - 1, min(total_frames, self.n_frames), dtype=np.int32)
            for idx in frame_indices:
                cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
                ret, frame = cap.read()
                if not ret:
                    frames.append(blank)
                    continue
                # cv2.resize expects dsize as (width, height), not (height, width).
                frame = cv2.resize(frame, (width, height))
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        finally:
            # Release even if decoding/resizing raises, so handles don't leak.
            cap.release()
        pad = frames[-1] if frames else blank
        frames.extend([pad] * (self.n_frames - len(frames)))
        # float32 halves batch memory versus float64; Keras consumes float32 anyway.
        return np.asarray(frames, dtype=np.float32) / 255.0

    def on_epoch_end(self):
        """Reshuffle the sample order between epochs when shuffling is enabled."""
        if self._shuffle:
            np.random.shuffle(self.indices)
def create_model():
    """Build and compile the 3D-CNN binary classifier.

    Three Conv3D -> MaxPooling3D -> BatchNormalization stages (32, 64 and 128
    filters, 3x3x3 kernels, 'same' padding) feed a 256-unit ReLU dense layer with
    50% dropout and a single sigmoid output. The input is ``(N_FRAMES, *IMG_SIZE, 3)``.
    The model is compiled with Adam and binary cross-entropy and tracks accuracy.
    """
    # (filters, pool size) per stage; the first two stages pool only spatially.
    stages = ((32, (1, 2, 2)), (64, (1, 2, 2)), (128, (2, 2, 2)))
    stack = [Input(shape=(N_FRAMES, *IMG_SIZE, 3))]
    for filters, pool in stages:
        stack.extend([
            Conv3D(filters, kernel_size=(3, 3, 3), activation='relu', padding='same'),
            MaxPooling3D(pool_size=pool),
            BatchNormalization(),
        ])
    stack.extend([
        Flatten(),
        Dense(256, activation='relu'),
        Dropout(0.5),
        Dense(1, activation='sigmoid'),
    ])
    net = Sequential(stack)
    net.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return net
def load_data(data_dir=None, *, test_size=0.2, random_state=42):
    """Collect labelled video paths and split them into train/validation sets.

    Videos are read from ``<data_dir>/Fighting`` (label 1) and
    ``<data_dir>/Normal`` (label 0). Only files ending in .mp4, .mpeg, .avi or
    .mov (case-insensitive) are used. File names are sorted so that the seeded
    split is reproducible across filesystems.

    Args:
        data_dir: Dataset root; defaults to the module-level ``DATA_DIR``.
        test_size: Fraction of videos held out for validation.
        random_state: Seed passed to ``train_test_split``.

    Returns:
        ``(train_paths, val_paths, train_labels, val_labels)``.

    Raises:
        FileNotFoundError: if a class directory is missing.
        ValueError: if no videos are found (or the split itself is impossible).
    """
    if data_dir is None:
        data_dir = DATA_DIR
    video_exts = (".mp4", ".mpeg", ".avi", ".mov")
    video_paths, labels = [], []
    for label_name in ["Fighting", "Normal"]:
        label_dir = os.path.join(data_dir, label_name)
        if not os.path.isdir(label_dir):
            raise FileNotFoundError(f"Directory not found: {label_dir}")
        label = 1 if label_name.lower() == "fighting" else 0
        # os.listdir order is filesystem-dependent; sort so random_state alone fixes the split.
        for name in sorted(os.listdir(label_dir)):
            if name.lower().endswith(video_exts):
                video_paths.append(os.path.join(label_dir, name))
                labels.append(label)
    if not video_paths:
        raise ValueError(f"No videos found in {data_dir}")
    counts = Counter(labels)
    print(f"[INFO] Total videos: {len(video_paths)} (Fighting: {counts[1]}, Normal: {counts[0]})")
    if len(counts) < 2:
        print("[WARNING] Only one class found. Splitting without stratification.")
        return train_test_split(video_paths, labels, test_size=test_size, random_state=random_state)
    try:
        return train_test_split(video_paths, labels, test_size=test_size,
                                stratify=labels, random_state=random_state)
    except ValueError as err:
        # Stratification is impossible for tiny classes (e.g. a class with one video).
        print(f"[WARNING] Stratified split failed ({err}). Splitting without stratification.")
        return train_test_split(video_paths, labels, test_size=test_size, random_state=random_state)
def _checkpoint_epoch(filename):
    """Return the epoch in a ``ckpt_<epoch>.h5`` file name, or None if the name doesn't match."""
    prefix, suffix = "ckpt_", ".h5"
    if not (filename.startswith(prefix) and filename.endswith(suffix)):
        return None
    stem = filename[len(prefix):-len(suffix)]
    # isdecimal() accepts exactly what int() can parse as digits (unlike isdigit()).
    return int(stem) if stem.isdecimal() else None


def get_latest_checkpoint(checkpoint_dir=None):
    """Return the path of the highest-epoch checkpoint, or None if there is none.

    Only files named ``ckpt_<epoch>.h5`` are considered. Other files, including
    ``ckpt_*.h5`` names with a non-numeric epoch, are ignored instead of crashing
    the epoch parse. The directory is created if it does not exist yet.

    Args:
        checkpoint_dir: Directory to search; defaults to the module-level ``CHECKPOINT_DIR``.
    """
    if checkpoint_dir is None:
        checkpoint_dir = CHECKPOINT_DIR
    os.makedirs(checkpoint_dir, exist_ok=True)
    candidates = []
    for name in os.listdir(checkpoint_dir):
        epoch = _checkpoint_epoch(name)
        if epoch is not None:
            candidates.append((epoch, name))
    if not candidates:
        return None
    # Tuple ordering breaks epoch ties by name, keeping the choice deterministic.
    _, latest = max(candidates)
    return os.path.join(checkpoint_dir, latest)
def main():
    """Train the video classifier end to end.

    Steps:
        1. Load and split the dataset (``load_data``).
        2. Build the train/validation generators.
        3. Create the model, or resume from the newest checkpoint when ``RESUME_TRAINING`` is truthy.
        4. Train with per-epoch checkpoints, CSV logging and early stopping.
        5. Save the model to ``OUTPUT_PATH``.

    Setup failures are printed and abort the run. Training failures are printed
    with a traceback, and the model is still saved on a best-effort basis.
    """
    # Load and split data
    try:
        train_paths, val_paths, train_labels, val_labels = load_data()
    except Exception as e:
        print(f"[ERROR] Failed to load data: {str(e)}")
        return

    # Create data generators
    try:
        train_gen = VideoDataGenerator(train_paths, train_labels, BATCH_SIZE, N_FRAMES, IMG_SIZE)
        val_gen = VideoDataGenerator(val_paths, val_labels, BATCH_SIZE, N_FRAMES, IMG_SIZE)
    except Exception as e:
        print(f"[ERROR] Failed to create data generators: {str(e)}")
        return

    # ModelCheckpoint writes into CHECKPOINT_DIR. Make sure it exists even when not
    # resuming, since get_latest_checkpoint() is otherwise the only code creating it.
    try:
        os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    except OSError as e:
        print(f"[ERROR] Could not create checkpoint directory {CHECKPOINT_DIR}: {str(e)}")
        return

    # Callbacks
    callbacks = [
        ModelCheckpoint(
            os.path.join(CHECKPOINT_DIR, 'ckpt_{epoch}.h5'),
            save_best_only=False,
            save_weights_only=False
        ),
        CSVLogger('training_log.csv', append=True),
        EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
    ]

    # Handle resume training
    initial_epoch = 0
    try:
        if RESUME_TRAINING:
            ckpt = get_latest_checkpoint()
            if ckpt:
                print(f"[INFO] Resuming training from checkpoint: {ckpt}")
                model = load_model(ckpt)
                # Parse the epoch from the file name only ("ckpt_<epoch>.h5"). The
                # directory part may itself contain '_' (e.g. "AI_made"). Parsing
                # the full path made every resume fail with ValueError.
                ckpt_name = os.path.basename(ckpt)
                initial_epoch = int(ckpt_name.split('_')[1].split('.')[0])
            else:
                print("[INFO] No checkpoint found, starting new training")
                model = create_model()
        else:
            model = create_model()
    except Exception as e:
        print(f"[ERROR] Failed to initialize model: {str(e)}")
        return

    # Display model summary
    model.summary()

    # Train model
    trained = False
    try:
        print("[INFO] Starting training...")
        model.fit(
            train_gen,
            validation_data=val_gen,
            epochs=EPOCHS,
            initial_epoch=initial_epoch,
            callbacks=callbacks,
            verbose=1
        )
        trained = True
    except Exception as e:
        print(f"[ERROR] Training failed: {str(e)}")
        traceback.print_exc()
    finally:
        # Also runs on failure or Ctrl+C so partial progress is kept. A failing
        # save is reported here instead of replacing the original exception.
        try:
            model.save(OUTPUT_PATH)
        except Exception as save_err:
            print(f"[ERROR] Failed to save model to {OUTPUT_PATH}: {str(save_err)}")
        else:
            status = "Training completed" if trained else "Training did not complete"
            print(f"[INFO] {status}. Model saved to {OUTPUT_PATH}")
if __name__ == "__main__":
    # Script entry point: wrap the training run in start/finish log lines.
    print("[INFO] Starting script...")
    main()
    print("[INFO] Script execution completed.")