# Page header captured together with this file (not code):
# author: Ateeqq — commit e2a8667 "Update app.py" (verified).
import gradio as gr
import torch
from PIL import Image as PILImage
from transformers import AutoImageProcessor, SiglipForImageClassification
import os
import warnings
# --- Configuration ---
# Hugging Face Hub repo id of the fine-tuned SigLIP classifier.
MODEL_IDENTIFIER = "Ateeqq/ai-vs-human-image-detector"
# Use CUDA when torch can see a GPU, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Suppress specific warnings ---
# Silence PIL's "Possibly corrupt EXIF data" warning, raised for images whose
# EXIF metadata is malformed. (This does NOT silence PIL's
# DecompressionBombWarning for very large images.)
warnings.filterwarnings("ignore", message="Possibly corrupt EXIF data.")
# Silence the transformers notice beginning "You are using the default legacy
# behaviour" (a tokenizer message).
# NOTE(review): presumably never emitted by this image-only pipeline; confirm
# whether this filter is still needed.
warnings.filterwarnings("ignore", message=".*You are using the default legacy behaviour.*")
# --- Load Model and Processor (Load once at startup) ---
def _load_model_and_processor(identifier, device):
    """Load the image processor and classifier for ``identifier`` onto ``device``.

    Returns a ``(processor, model)`` tuple; the model is moved to ``device``
    and switched to eval mode. Any failure is logged and re-raised as
    ``gr.Error`` (chained to the original exception) so the app does not
    start without a model.
    """
    try:
        loaded_processor = AutoImageProcessor.from_pretrained(identifier)
        print(f"Loading model from: {identifier}")
        loaded_model = SiglipForImageClassification.from_pretrained(identifier)
        loaded_model.to(device)
        loaded_model.eval()
        print("Model and processor loaded successfully.")
    except Exception as load_error:
        print(f"FATAL: Error loading model or processor: {load_error}")
        # Abort startup: the app is useless without the model.
        raise gr.Error(f"Failed to load the model: {load_error}. Cannot start the application.") from load_error
    return loaded_processor, loaded_model


print(f"Using device: {DEVICE}")
print(f"Loading processor from: {MODEL_IDENTIFIER}")
processor, model = _load_model_and_processor(MODEL_IDENTIFIER, DEVICE)
# --- Prediction Function ---
def classify_image(image_pil):
    """
    Classifies an image as AI-generated or Human-made.

    Args:
        image_pil (PIL.Image.Image | None): Input image in PIL format, or
            None when the user clears the image input.

    Returns:
        dict: Maps every label in ``model.config.id2label`` to its softmax
        probability as a Python float. Returns an empty dict if input is None.

    Raises:
        gr.Error: If preprocessing or inference fails; Gradio displays the
            message in the UI.
    """
    if image_pil is None:
        # Handle case where the user clears the image input
        print("Warning: No image provided.")
        return {}  # Empty dict -> gr.Label renders nothing
    print("Processing image...")
    try:
        # Ensure image is RGB before handing it to the processor
        image = image_pil.convert("RGB")
        inputs = processor(images=image, return_tensors="pt").to(DEVICE)
        print("Running inference...")
        with torch.no_grad():
            logits = model(**inputs).logits
        # logits has shape [1, num_labels]: softmax over labels, keep the only image
        probabilities = torch.softmax(logits, dim=-1)[0].tolist()
        results = {model.config.id2label[i]: prob for i, prob in enumerate(probabilities)}
        print(f"Prediction results: {results}")
        return results
    except Exception as e:
        print(f"Error during prediction: {e}")
        # gr.Label expects numeric confidences, so returning
        # {"Error": "<message>"} broke the output component itself.
        # Surface the failure through Gradio's error display instead.
        raise gr.Error(f"Error processing image: {e}") from e
# --- Gradio Interface Definition ---
# Define Example Images (Optional, but recommended)
# Create an 'examples' folder in your Space repo and put images there
example_dir = "examples"
example_images = []
if os.path.exists(example_dir):
for img_name in os.listdir(example_dir):
if img_name.lower().endswith(('.png', '.jpg', '.jpeg', '.webp')):
example_images.append(os.path.join(example_dir, img_name))
print(f"Found examples: {example_images}")
else:
print("No 'examples' directory found. Examples will not be shown.")
# Define the Gradio interface
# Markdown shown under the title; reports the model id and the device in use.
_description = (
    "Upload an image to classify if it was likely generated by AI or created by a human. "
    f"Uses the `{MODEL_IDENTIFIER}` model on Hugging Face. Running on **{str(DEVICE).upper()}**."
)
# HTML shown below the interface, linking the model card and fine-tuning write-up.
_article = (
    "<div>"
    "<p>This tool uses a SigLIP model fine-tuned for distinguishing between AI-generated and human-made images.</p>"
    f"<p>Model Card: <a href='https://huggingface.co/{MODEL_IDENTIFIER}' target='_blank'>{MODEL_IDENTIFIER}</a></p>"
    "<p>Fine tuning code available at <a href='https://exnrt.com/blog/ai/fine-tuning-siglip2/' target='_blank'>https://exnrt.com/blog/ai/fine-tuning-siglip2/</a></p>"
    "</div>"
)
iface = gr.Interface(
    fn=classify_image,
    # PIL input so classify_image can call .convert("RGB") directly
    inputs=gr.Image(type="pil", label="Upload Image", sources=["upload", "webcam", "clipboard"]),
    # Label output renders the label -> confidence dict from classify_image
    outputs=gr.Label(num_top_classes=2, label="Prediction Results"),
    title="AI vs Human Image Detector",
    description=_description,
    article=_article,
    # Examples (and their cached predictions) only when example files were found
    examples=example_images or None,
    cache_examples=bool(example_images),
    allow_flagging="never",  # Or "auto" if you want users to flag issues
)
# --- Launch the App ---
if __name__ == "__main__":
    print("Launching Gradio interface...")
    # When run as a script (not in a notebook), launch() blocks the main thread
    # while the server is running, so the next line only executes after shutdown.
    iface.launch()
    print("Gradio interface has shut down.")