# Hugging Face file-view header captured with this file (page chrome, not code):
# AI-RESEARCHER-2024's picture
# Update app.py
# 002641d verified
# raw
# history blame
# 5.37 kB
import gradio as gr
import cv2
import numpy as np
import tensorflow as tf
import os
# Load both trained Keras classifiers once, at import time. The .h5 file
# names are resolved relative to the current working directory; the
# resulting models are reused by every call to process_video().
cnn_model, qcnn_model = (
    tf.keras.models.load_model(model_file)
    for model_file in ('cnn_model.h5', 'qcnn_model.h5')
)
# Example videos are grouped into one sub-directory of `examples_dir` per
# category (the category name is also the directory name).
examples_dir = 'examples'
original_dir, deepfake_roop_dir, deepfake_web_dir = (
    os.path.join(examples_dir, category)
    for category in ('Original', 'DeepfakeRoop', 'DeepfakeWeb')
)
# Function to get video paths from a directory
def get_video_paths(directory, label):
videos = []
for vid in os.listdir(directory):
if vid.endswith('.mp4'):
videos.append({'path': os.path.join(directory, vid), 'label': label})
return videos
# One list of {'path', 'label'} entries per example category.
original_videos = get_video_paths(original_dir, 'Original')
deepfake_roop_videos = get_video_paths(deepfake_roop_dir, 'DeepfakeRoop')
deepfake_web_videos = get_video_paths(deepfake_web_dir, 'DeepfakeWeb')

# All example entries in display order: originals first, then each deepfake source.
examples = [*original_videos, *deepfake_roop_videos, *deepfake_web_videos]

# Reverse lookup used by predict(): example file path -> category label.
example_videos_dict = {entry['path']: entry['label'] for entry in examples}
def _sample_frames(video_path):
    """Decode *video_path* and return its sampled, preprocessed frames.

    Keeps about 30 frames per second of video: every frame for videos up to
    roughly 30 fps, every round(fps / 30)-th frame otherwise. When the
    container reports an FPS of 0 or NaN, 30 is assumed. Each kept frame is
    converted BGR -> grayscale, resized to 28x28 and scaled to [0, 1].

    Returns:
        float32 array of shape (n_frames, 28, 28, 1). n_frames is 0 when
        the video cannot be opened or no frame can be decoded.
    """
    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps == 0 or np.isnan(fps):
            fps = 30
        # Step between kept frames so that roughly 30 frames/s are analysed.
        frame_interval = max(int(round(fps / 30)), 1)
        frame_index = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            if frame_index % frame_interval == 0:
                gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                frames.append(cv2.resize(gray_frame, (28, 28)) / 255.0)
            frame_index += 1
    finally:
        # Release the decoder even if converting a frame raises.
        cap.release()
    # reshape(-1, ...) also turns an empty list into shape (0, 28, 28, 1).
    return np.asarray(frames, dtype=np.float32).reshape(-1, 28, 28, 1)


def _predict_classes(model, frames):
    """Return the predicted class index of every frame, using one batched call.

    Each frame's output is flattened before the argmax. This gives the same
    index as calling ``np.argmax(model.predict(frame))`` on a single frame,
    without the per-call overhead and progress logging.
    """
    predictions = np.asarray(model.predict(frames, verbose=0))
    return predictions.reshape(len(frames), -1).argmax(axis=1)


def process_video(video_path, true_label=None):
    """Classify every sampled frame of a video with both the CNN and the QCNN.

    Args:
        video_path: Path to a video file readable by OpenCV.
        true_label: Ground-truth class index for bundled example videos
            (predict() passes 0 for 'Original' and 1 otherwise), or None for
            user uploads.

    Returns:
        A human-readable report.
        - With *true_label*: for each model, the percentage of frames whose
          predicted class equals *true_label*.
        - Without it: the percentage of frames predicted as class 0 and as
          class 1; any non-zero class index counts as class 1.
        - When no frame can be read: a message saying so.
    """
    frames = _sample_frames(video_path)
    total_frames = len(frames)
    if total_frames == 0:
        # Report this explicitly instead of a misleading "0.00%" for every class.
        return "Could not read any frames from the video."

    cnn_labels = _predict_classes(cnn_model, frames)
    qcnn_labels = _predict_classes(qcnn_model, frames)

    if true_label is not None:
        # Example video: share of frames classified as the known label.
        cnn_accuracy = (np.count_nonzero(cnn_labels == true_label) / total_frames) * 100
        qcnn_accuracy = (np.count_nonzero(qcnn_labels == true_label) / total_frames) * 100
        result = f"CNN Model Accuracy: {cnn_accuracy:.2f}%\n"
        result += f"QCNN Model Accuracy: {qcnn_accuracy:.2f}%"
        return result

    # Uploaded video: share of frames assigned to each class.
    cnn_class0 = np.count_nonzero(cnn_labels == 0)
    qcnn_class0 = np.count_nonzero(qcnn_labels == 0)
    cnn_class0_percent = (cnn_class0 / total_frames) * 100
    cnn_class1_percent = ((total_frames - cnn_class0) / total_frames) * 100
    qcnn_class0_percent = (qcnn_class0 / total_frames) * 100
    qcnn_class1_percent = ((total_frames - qcnn_class0) / total_frames) * 100
    result = f"CNN Model Predictions:\nClass 0: {cnn_class0_percent:.2f}%\nClass 1: {cnn_class1_percent:.2f}%\n"
    result += f"QCNN Model Predictions:\nClass 0: {qcnn_class0_percent:.2f}%\nClass 1: {qcnn_class1_percent:.2f}%"
    return result
def _find_example_label(video_path):
    """Return the category label if *video_path* is a bundled example, else None.

    Tries three matches against example_videos_dict, in order:
    1. the exact key;
    2. the same absolute path;
    3. a file name that belongs to exactly one category.
    The file-name fallback covers the case where the UI hands the callback a
    copy of the example stored elsewhere. NOTE(review): confirm this
    against the installed Gradio version. As a trade-off, an uploaded file
    named like an example is treated as that example.
    """
    if video_path in example_videos_dict:
        return example_videos_dict[video_path]
    abs_path = os.path.abspath(video_path)
    for path, label in example_videos_dict.items():
        if os.path.abspath(path) == abs_path:
            return label
    file_name = os.path.basename(video_path)
    labels = {
        label
        for path, label in example_videos_dict.items()
        if os.path.basename(path) == file_name
    }
    # Only trust a file-name match when it identifies a single category.
    if len(labels) == 1:
        return labels.pop()
    return None


def predict(video_input):
    """Gradio callback that analyses the selected or uploaded video.

    Args:
        video_input: A file path string, a dict carrying the path under
            'name' (or 'path'), or None when nothing was provided.

    Returns:
        One of:
        - the report from process_video(), prefixed with
          "Example Video Detected (<label>)" when the video is a bundled
          example;
        - an error message for missing or invalid input.
    """
    if video_input is None:
        return "Please upload a video or select an example."
    if isinstance(video_input, dict):
        video_path = video_input.get('name') or video_input.get('path')
    elif isinstance(video_input, str):
        video_path = video_input
    else:
        video_path = None
    if not isinstance(video_path, str):
        return "Invalid video input."

    label = _find_example_label(video_path)
    if label is None:
        return process_video(video_path)
    # 'Original' examples map to class 0; every deepfake category maps to class 1.
    true_label = 0 if label == 'Original' else 1
    result = process_video(video_path, true_label=true_label)
    return f"Example Video Detected ({label})\n" + result
# Gradio UI:
# - left column: video input, example gallery and the Predict button;
# - right column: the text report produced by predict().
example_paths = [example['path'] for example in examples]

with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>Quanvolutional Neural Networks for Deepfake Detection</h1>")
    gr.Markdown("<h2 style='text-align: center;'>Steven Fernandes, Ph.D.</h2>")
    with gr.Row():
        with gr.Column():
            video_input = gr.Video(label="Upload Video or Select an Example", type="filepath")
            gr.Examples(examples=example_paths, inputs=video_input, label="Examples")
            predict_button = gr.Button("Predict")
        with gr.Column():
            output = gr.Textbox(label="Result")
    predict_button.click(fn=predict, inputs=video_input, outputs=output)

demo.launch()