import gradio as gr
from transformers import ViTForImageClassification, ViTImageProcessor
from PIL import Image
import torch
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the model and processor from Hugging Face
model = ViTForImageClassification.from_pretrained("prithivMLmods/Deep-Fake-Detector-Model")
processor = ViTImageProcessor.from_pretrained("prithivMLmods/Deep-Fake-Detector-Model")

# Log model configuration to verify label mapping
logger.info(f"Model label mapping: {model.config.id2label}")
def detect(image, confidence_threshold=0.5):
    """Detect deepfake content using prithivMLmods/Deep-Fake-Detector-Model"""
    if image is None:
        raise gr.Error("Please upload an image to analyze")
    try:
        # Convert Gradio image (filepath) to PIL Image
        pil_image = Image.open(image).convert("RGB")
        # Resize to match ViT input requirements (224x224)
        pil_image = pil_image.resize((224, 224), Image.Resampling.LANCZOS)
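        # (ViTImageProcessor also resizes and normalizes inputs by default, so this
        # explicit resize is kept mainly as a redundant safety step.)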
        # Preprocess the image
        inputs = processor(images=pil_image, return_tensors="pt")
        # Perform inference
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits
        probabilities = torch.softmax(logits, dim=1)[0]
        # Get confidence scores
        confidence_real = probabilities[0].item() * 100  # Assuming 0 is Real
        confidence_fake = probabilities[1].item() * 100  # Assuming 1 is Fake
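        # NOTE: the 0=Real / 1=Fake index order is an assumption; if results look
        # inverted, check the id2label mapping logged at startup and swap the indices.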
        # Verify label mapping from model config
        id2label = model.config.id2label
        predicted_class = torch.argmax(logits, dim=1).item()
        predicted_label = id2label[predicted_class]
        # Adjust prediction based on threshold and label
        threshold_predicted = "Fake" if confidence_fake / 100 >= confidence_threshold else "Real"
        confidence_score = max(confidence_real, confidence_fake)
        # Log detailed output
        logger.info(f"Logits: {logits.tolist()}")
        logger.info(f"Probabilities - Real: {confidence_real:.1f}%, Fake: {confidence_fake:.1f}%")
        logger.info(f"Predicted Class: {predicted_class}, Label: {predicted_label}")
        logger.info(f"Threshold ({confidence_threshold}): {threshold_predicted}")
        # Prepare output
        overall = f"{confidence_score:.1f}% Confidence ({threshold_predicted})"
        aigen = f"{confidence_fake:.1f}% (AI-Generated Content Likelihood)"
        deepfake = f"{confidence_fake:.1f}% (Face Manipulation Likelihood)"
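        # Both readouts reuse the same fake-class probability, since the model
        # only distinguishes two classes (Real vs. Fake).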
        return overall, aigen, deepfake
    except Exception as e:
        logger.error(f"Error during analysis: {str(e)}")
        raise gr.Error(f"Analysis error: {str(e)}")
# Custom CSS
custom_css = """
.container {
    max-width: 1200px;
    margin: 0 auto;
    padding: 20px;
    font-family: 'Arial', sans-serif;
}
.header {
    color: #2c3e50;
    border-bottom: 2px solid #3498db;
    padding-bottom: 10px;
}
.button-gradient {
    background: linear-gradient(45deg, #3498db, #2ecc71, #9b59b6);
    background-size: 400% 400%;
    border: none;
    padding: 12px 24px;
    font-size: 16px;
    font-weight: 600;
    color: white;
    border-radius: 8px;
    cursor: pointer;
    transition: all 0.3s ease;
    animation: gradientAnimation 3s ease infinite;
    box-shadow: 0 2px 8px rgba(52, 152, 219, 0.3);
}
.button-gradient:hover {
    transform: translateY(-2px);
    box-shadow: 0 4px 12px rgba(52, 152, 219, 0.5);
}
@keyframes gradientAnimation {
    0% { background-position: 0% 50%; }
    50% { background-position: 100% 50%; }
    100% { background-position: 0% 50%; }
}
"""
MARKDOWN0 = """
<div class="header">
    <h1>DeepFake Detection System</h1>
    <p>Advanced AI-powered analysis for identifying manipulated media<br>
    Powered by prithivMLmods/Deep-Fake-Detector-Model (Updated Jan 2025)<br>
    Adjust threshold to tune sensitivity; check logs for detailed output</p>
</div>
"""
# Create Gradio interface with threshold slider
with gr.Blocks(css=custom_css, theme=gr.themes.Default()) as demo:
    gr.Markdown(MARKDOWN0)
    with gr.Row(elem_classes="container"):
        with gr.Column(scale=1):
            image = gr.Image(type='filepath', height=400, label="Upload Image")
            threshold = gr.Slider(0, 1, value=0.5, step=0.01, label="Confidence Threshold (Fake)")
            detect_button = gr.Button("Analyze Image", elem_classes="button-gradient")
        with gr.Column(scale=2):
            overall = gr.Label(label="Confidence Score")
            aigen = gr.Label(label="AI-Generated Content")
            deepfake = gr.Label(label="Face Manipulation")
    detect_button.click(
        fn=detect,
        inputs=[image, threshold],
        outputs=[overall, aigen, deepfake]
    )
# Launch the application
demo.launch(
    debug=True
)
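# Expected dependencies (based on the imports above): gradio, transformers, torch, pillow.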