"""Gradio app that classifies an uploaded image as AI-generated or human-made.

Loads a SigLIP image-classification model once at startup and serves it via a
``gr.Interface`` with a ``gr.Label`` output.
"""
import os
import warnings

import gradio as gr
import torch
from PIL import Image as PILImage
from transformers import AutoImageProcessor, SiglipForImageClassification

# --- Configuration ---
MODEL_IDENTIFIER = r"Ateeqq/ai-vs-human-image-detector"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# File extensions (lower-case) accepted as example images.
EXAMPLE_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".webp")

# --- Suppress specific warnings ---
# Silence the "Possibly corrupt EXIF data." UserWarning (emitted by PIL when an
# uploaded image carries malformed EXIF metadata); it is noise for this app.
warnings.filterwarnings("ignore", message="Possibly corrupt EXIF data.")
# Silence the transformers "default legacy behaviour" notice.
# NOTE(review): transformers may emit this through its logger rather than
# warnings.warn, in which case this filter has no effect -- confirm.
warnings.filterwarnings("ignore", message=".*You are using the default legacy behaviour.*")

# --- Load Model and Processor (Load once at startup) ---
print(f"Using device: {DEVICE}")
print(f"Loading processor from: {MODEL_IDENTIFIER}")
try:
    processor = AutoImageProcessor.from_pretrained(MODEL_IDENTIFIER)
    print(f"Loading model from: {MODEL_IDENTIFIER}")
    model = SiglipForImageClassification.from_pretrained(MODEL_IDENTIFIER)
    model.to(DEVICE)
    model.eval()  # inference mode: disables dropout etc.
    print("Model and processor loaded successfully.")
except Exception as e:
    print(f"FATAL: Error loading model or processor: {e}")
    # Nothing can be served without the model, so abort startup. gr.Error is
    # meant for errors raised inside event handlers (shown in the UI); at
    # import time a plain RuntimeError is the appropriate signal.
    raise RuntimeError(f"Failed to load the model: {e}. Cannot start the application.") from e


# --- Prediction Function ---
def classify_image(image_pil):
    """Classify an image as AI-generated or human-made.

    Args:
        image_pil (PIL.Image.Image | None): Image from the ``gr.Image(type="pil")``
            input; ``None`` when the user clears the input.

    Returns:
        dict[str, float]: Mapping from each label in ``model.config.id2label``
        to its softmax probability. An empty dict when ``image_pil`` is None
        (``gr.Label`` renders that as "no result").

    Raises:
        gr.Error: If preprocessing or inference fails. Gradio shows the message
            in the UI. (Returning a string value in the dict would break
            ``gr.Label``, which expects numeric confidences.)
    """
    if image_pil is None:
        print("Warning: No image provided.")
        return {}

    print("Processing image...")
    try:
        # The processor expects 3-channel input; RGBA/L/P images are converted.
        image = image_pil.convert("RGB")
        inputs = processor(images=image, return_tensors="pt").to(DEVICE)

        print("Running inference...")
        with torch.no_grad():
            logits = model(**inputs).logits

        # Single image in the batch: softmax over the label dimension, then take
        # row 0 and convert to plain Python floats.
        probabilities = torch.softmax(logits, dim=-1)[0].tolist()
        results = {model.config.id2label[i]: prob for i, prob in enumerate(probabilities)}
    except Exception as e:
        print(f"Error during prediction: {e}")
        raise gr.Error(f"Error processing image: {e}") from e

    print(f"Prediction results: {results}")
    return results


# --- Gradio Interface Definition ---
def _find_example_images(directory):
    """Return sorted paths of image files directly inside ``directory``."""
    return sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.lower().endswith(EXAMPLE_IMAGE_EXTENSIONS)
    )


# Example images are optional: put them in an 'examples' folder next to this file.
example_dir = "examples"
example_images = []
if os.path.isdir(example_dir):
    # Sorted so the example gallery order is stable across runs/platforms.
    example_images = _find_example_images(example_dir)
    print(f"Found examples: {example_images}")
else:
    print("No 'examples' directory found. Examples will not be shown.")

ARTICLE = (
    "<div style='text-align: center;'>"
    "<p>This tool uses a SigLIP model fine-tuned for distinguishing between "
    "AI-generated and human-made images.</p>"
    f"<p>Model Card: <a href='https://huggingface.co/{MODEL_IDENTIFIER}' target='_blank'>"
    f"{MODEL_IDENTIFIER}</a></p>"
    "<p>Fine tuning code available at "
    "<a href='https://exnrt.com/blog/ai/fine-tuning-siglip2/' target='_blank'>"
    "https://exnrt.com/blog/ai/fine-tuning-siglip2/</a></p>"
    "</div>"
)

iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil", label="Upload Image", sources=["upload", "webcam", "clipboard"]),
    # Show every class the model knows about (2 for this ai/human model).
    outputs=gr.Label(num_top_classes=len(model.config.id2label), label="Prediction Results"),
    title="AI vs Human Image Detector",
    description=(
        "Upload an image to classify if it was likely generated by AI or created by a human. "
        f"Uses the `{MODEL_IDENTIFIER}` model on Hugging Face. Running on **{str(DEVICE).upper()}**."
    ),
    article=ARTICLE,
    examples=example_images or None,  # Only add examples if found
    # Caching runs classify_image on each example at startup.
    cache_examples=bool(example_images),
    allow_flagging="never",  # Or "auto" if you want users to flag issues
)

# --- Launch the App ---
if __name__ == "__main__":
    print("Launching Gradio interface...")
    # When run as a script, launch() blocks until the server shuts down, so the
    # line after it only executes at shutdown.
    iface.launch()
    print("Gradio interface stopped.")