Spaces:
Running
Running
File size: 5,132 Bytes
43655b0 43a9d75 43655b0 43a9d75 43655b0 43a9d75 43655b0 43a9d75 43655b0 27b4f7a 43655b0 43a9d75 43655b0 43a9d75 43655b0 43a9d75 43655b0 43a9d75 43655b0 43a9d75 43655b0 43a9d75 43655b0 43a9d75 43655b0 43a9d75 43655b0 43a9d75 43655b0 43a9d75 43655b0 43a9d75 43655b0 43a9d75 43655b0 43a9d75 43655b0 43a9d75 43655b0 43a9d75 e2a8667 43a9d75 727b540 43655b0 2b933d2 43655b0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import gradio as gr
import torch
from PIL import Image as PILImage
from transformers import AutoImageProcessor, SiglipForImageClassification
import os
import warnings
# --- Configuration ---
# Hugging Face Hub repository id used for both the processor and the model.
MODEL_IDENTIFIER = "Ateeqq/ai-vs-human-image-detector"
# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# --- Suppress specific warnings ---
# Each pattern is a regular expression matched against the start of a
# warning's message text:
#   * "Possibly corrupt EXIF data." -- an EXIF-related warning (the original
#     note attributes it to PIL).
#   * "...default legacy behaviour..." -- presumably emitted by transformers
#     while loading; NOTE(review): confirm which component raises it.
for _ignored_message in (
    "Possibly corrupt EXIF data.",
    ".*You are using the default legacy behaviour.*",
):
    warnings.filterwarnings("ignore", message=_ignored_message)
# --- Load Model and Processor (Load once at startup) ---
# Loading happens exactly once, at import time, so every request served by
# classify_image reuses the same processor and model objects.
print(f"Using device: {DEVICE}")
print(f"Loading processor from: {MODEL_IDENTIFIER}")
try:
    processor = AutoImageProcessor.from_pretrained(MODEL_IDENTIFIER)
    print(f"Loading model from: {MODEL_IDENTIFIER}")
    # Move the weights to the selected device and switch to evaluation mode.
    model = SiglipForImageClassification.from_pretrained(MODEL_IDENTIFIER).to(DEVICE).eval()
    print("Model and processor loaded successfully.")
except Exception as load_error:
    print(f"FATAL: Error loading model or processor: {load_error}")
    # Without a model the app is useless, so abort startup by re-raising.
    raise gr.Error(
        f"Failed to load the model: {load_error}. Cannot start the application."
    ) from load_error
# --- Prediction Function ---
def classify_image(image_pil):
    """Classify an image with the module-level model and processor.

    Args:
        image_pil (PIL.Image.Image | None): Input image in PIL format, or
            ``None`` when the user has cleared the image input.

    Returns:
        dict: Maps every label in ``model.config.id2label`` to its softmax
            probability (a Python float). Empty when ``image_pil`` is
            ``None``.

    Raises:
        gr.Error: If conversion, preprocessing or inference fails, so the
            failure is reported in the UI instead of being returned as data.
    """
    if image_pil is None:
        # The user cleared the image input; there is nothing to classify.
        print("Warning: No image provided.")
        return {}

    print("Processing image...")
    try:
        # Normalise the colour mode to RGB before preprocessing.
        image = image_pil.convert("RGB")
        inputs = processor(images=image, return_tensors="pt").to(DEVICE)

        print("Running inference...")
        with torch.no_grad():
            logits = model(**inputs).logits

        # Softmax over the last dimension; index 0 selects the single image
        # that was submitted.
        probabilities = torch.softmax(logits, dim=-1)[0]
        results = {
            model.config.id2label[index]: probability.item()
            for index, probability in enumerate(probabilities)
        }
    except Exception as e:
        print(f"Error during prediction: {e}")
        # The output component is a gr.Label, which expects numeric
        # confidences; returning {"Error": "<text>"} would hand it a string
        # where a score belongs. Raise a Gradio error so the message is
        # surfaced to the user instead.
        raise gr.Error(f"Error processing image: {e}") from e

    print(f"Prediction results: {results}")
    return results
# --- Gradio Interface Definition ---
# Define Example Images (Optional, but recommended)
# Create an 'examples' folder in your Space repo and put images there
example_dir = "examples"
example_images = []
# isdir (rather than exists) avoids calling os.listdir on a regular file that
# happens to be named "examples".
if os.path.isdir(example_dir):
    # os.listdir returns entries in arbitrary order; sorting keeps the example
    # gallery stable between restarts.
    example_images = sorted(
        os.path.join(example_dir, img_name)
        for img_name in os.listdir(example_dir)
        if img_name.lower().endswith(('.png', '.jpg', '.jpeg', '.webp'))
    )
    print(f"Found examples: {example_images}")
else:
    print("No 'examples' directory found. Examples will not be shown.")
# Define the Gradio interface
# Input component: hands the callback a PIL image from upload, webcam or clipboard.
_image_input = gr.Image(
    type="pil",
    label="Upload Image",
    sources=["upload", "webcam", "clipboard"],
)
# Output component: shows the label -> confidence mapping returned by the callback.
_label_output = gr.Label(num_top_classes=2, label="Prediction Results")

# Text shown under the title; includes the model id and the active device.
_description = (
    "Upload an image to classify if it was likely generated by AI or created by a human. "
    f"Uses the `{MODEL_IDENTIFIER}` model on Hugging Face. Running on **{str(DEVICE).upper()}**."
)
# HTML footer linking to the model card and the fine-tuning write-up.
_article = (
    "<div>"
    "<p>This tool uses a SigLIP model fine-tuned for distinguishing between AI-generated and human-made images.</p>"
    f"<p>Model Card: <a href='https://huggingface.co/{MODEL_IDENTIFIER}' target='_blank'>{MODEL_IDENTIFIER}</a></p>"
    "<p>Fine tuning code available at <a href='https://exnrt.com/blog/ai/fine-tuning-siglip2/' target='_blank'>https://exnrt.com/blog/ai/fine-tuning-siglip2/</a></p>"
    "</div>"
)

iface = gr.Interface(
    fn=classify_image,
    inputs=_image_input,
    outputs=_label_output,
    title="AI vs Human Image Detector",
    description=_description,
    article=_article,
    # Examples (and their cached results) are only enabled when files were found.
    examples=example_images or None,
    cache_examples=bool(example_images),
    allow_flagging="never",  # Or "auto" if you want users to flag issues
)
# --- Launch the App ---
# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    print("Launching Gradio interface...")
    iface.launch()
    # NOTE(review): launch() is documented to block the main thread by default,
    # so this line presumably runs only once the server stops -- confirm.
    print("Gradio interface launched.")