"""Gradio app that classifies an image as AI-generated or human-made.

Loads a fine-tuned SigLIP image-classification checkpoint from the Hugging Face
Hub once at startup and serves it through a Gradio Blocks interface.
"""

import os
import warnings
from typing import Dict, List, Optional

import gradio as gr
import torch
from PIL import Image as PILImage
from transformers import AutoImageProcessor, SiglipForImageClassification

# --- Configuration ---
MODEL_IDENTIFIER = r"Ateeqq/ai-vs-human-image-detector"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# File extensions (compared case-insensitively) accepted as example images.
EXAMPLE_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".webp")

# --- Suppress specific warnings ---
# Silence the "Possibly corrupt EXIF data." warning raised for images whose
# EXIF metadata cannot be parsed cleanly.
warnings.filterwarnings("ignore", message="Possibly corrupt EXIF data.")
# Silence the "You are using the default legacy behaviour" notice.
warnings.filterwarnings("ignore", message=".*You are using the default legacy behaviour.*")

# --- Load Model and Processor (Load once at startup) ---
print(f"Using device: {DEVICE}")
print(f"Loading processor from: {MODEL_IDENTIFIER}")
try:
    processor = AutoImageProcessor.from_pretrained(MODEL_IDENTIFIER)
    print(f"Loading model from: {MODEL_IDENTIFIER}")
    model = SiglipForImageClassification.from_pretrained(MODEL_IDENTIFIER)
    model.to(DEVICE)
    model.eval()  # inference mode for layers such as dropout
    print("Model and processor loaded successfully.")
except Exception as e:
    print(f"FATAL: Error loading model or processor: {e}")
    # gr.Error is meant for surfacing messages from event handlers; at import
    # time a plain RuntimeError is the appropriate way to abort startup.
    raise RuntimeError(f"Failed to load the model: {e}. Cannot start the application.") from e


# --- Prediction Function ---
def classify_image(image_pil: Optional[PILImage.Image]) -> Dict[str, float]:
    """Classify an image as AI-generated or human-made.

    Args:
        image_pil: Input image in PIL format, or None when the user clears the
            image input.

    Returns:
        A dict mapping every label in ``model.config.id2label`` to its softmax
        probability, rounded to 4 decimals. An empty dict if ``image_pil`` is None.

    Raises:
        gr.Error: If preprocessing or inference fails. Gradio shows the message
            to the user instead of passing a malformed value to ``gr.Label``.
    """
    if image_pil is None:
        # Handle case where the user clears the image input
        print("Warning: No image provided.")
        return {}  # Return empty dict, Gradio Label handles this

    print("Processing image...")
    try:
        # Convert to RGB so images with alpha/palette/grayscale modes are
        # handled uniformly by the processor.
        image = image_pil.convert("RGB")
        inputs = processor(images=image, return_tensors="pt").to(DEVICE)

        print("Running inference...")
        with torch.no_grad():
            logits = model(**inputs).logits

        # A single image is processed, so take row 0 after the softmax over
        # the label dimension.
        probabilities = torch.softmax(logits, dim=-1)[0].tolist()
    except Exception as e:
        print(f"Error during prediction: {e}")
        # gr.Label cannot render a string as a confidence value, so returning
        # an {"Error": "..."} dict would itself fail; raise gr.Error instead so
        # the user sees a friendly message.
        raise gr.Error("Processing failed. Please try again or use a different image.") from e

    results = {
        model.config.id2label[i]: round(prob, 4)  # Round for cleaner display
        for i, prob in enumerate(probabilities)
    }
    print(f"Prediction results: {results}")
    return results


# --- Define Example Images ---
example_dir = "examples"
# isdir (not exists) so a stray file named "examples" doesn't crash listdir.
example_entries = sorted(os.listdir(example_dir)) if os.path.isdir(example_dir) else []
example_images: List[str] = [
    os.path.join(example_dir, name)
    for name in example_entries
    if name.lower().endswith(EXAMPLE_IMAGE_EXTENSIONS)
    and os.path.isfile(os.path.join(example_dir, name))
]
if not example_entries:
    print("No 'examples' directory found or it's empty. Examples will not be shown.")
elif example_images:
    print(f"Found examples: {example_images}")
else:
    print("No valid image files found in 'examples' directory.")

# --- Custom CSS ---
css = """
body { font-family: 'Inter', sans-serif; } /* Use a clean sans-serif font */

/* Style the main title */
#app-title {
    text-align: center;
    font-weight: bold;
    font-size: 2.5em; /* Larger title */
    margin-bottom: 5px; /* Reduced space below title */
    color: #2c3e50; /* Darker color */
}

/* Style the description */
#app-description {
    text-align: center;
    font-size: 1.1em;
    margin-bottom: 25px; /* More space below description */
    color: #576574; /* Subdued color */
}
#app-description code { /* Style model name */
    font-weight: bold;
    background-color: #f1f2f6;
    padding: 2px 5px;
    border-radius: 4px;
}
#app-description strong { /* Style device name */
    color: #1abc9c; /* Highlight color for device */
}

/* Style the results area */
#prediction-label .label-name {
    font-weight: bold;
    font-size: 1.1em;
}
#prediction-label .confidence {
    font-size: 1em;
}

/* Style the results heading */
#results-heading {
    text-align: center;
    font-size: 1.2em; /* Slightly larger heading for results */
    margin-bottom: 10px; /* Space below heading */
    color: #34495e; /* Match other heading colors */
}

/* Style the examples section */
.gradio-container .examples-container {
    padding-top: 15px;
}
.gradio-container .examples-header {
    font-size: 1.1em;
    font-weight: bold;
    margin-bottom: 10px;
    color: #34495e;
}

/* Add a subtle border/shadow to input/output columns for definition */
#input-column, #output-column {
    border: 1px solid #e0e0e0;
    border-radius: 12px; /* More rounded corners */
    padding: 20px;
    box-shadow: 0 2px 8px rgba(0, 0, 0, 0.05); /* Subtle shadow */
    background-color: #ffffff; /* Ensure white background */
}

/* Footer styling */
#app-footer {
    margin-top: 40px;
    padding-top: 20px;
    border-top: 1px solid #dfe6e9;
    text-align: center;
    font-size: 0.9em;
    color: #8395a7;
}
#app-footer a {
    color: #3498db;
    text-decoration: none;
}
#app-footer a:hover {
    text-decoration: underline;
}
"""

# --- Footer Markdown ---
# Paragraphs are joined with blank lines so Markdown renders them separately;
# the leading "---" renders as a horizontal rule.
FOOTER_MARKDOWN = "\n\n".join(
    [
        "---",
        "This application uses a fine-tuned "
        "[SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip) "
        "vision model specifically trained to differentiate between images generated "
        "by Artificial Intelligence and those created by humans.",
        f"You can find the model card here: "
        f"[{MODEL_IDENTIFIER}](https://huggingface.co/{MODEL_IDENTIFIER})",
        "Fine tuning code available at "
        "[https://exnrt.com/blog/ai/fine-tuning-siglip2/]"
        "(https://exnrt.com/blog/ai/fine-tuning-siglip2/).",
    ]
)

# --- Gradio Interface using Blocks and Theme ---
# Choose a theme: gr.themes.Soft(), gr.themes.Monochrome(), gr.themes.Glass(), etc.
theme = gr.themes.Soft(
    primary_hue="emerald",  # Color scheme based on emerald green
    secondary_hue="blue",
    neutral_hue="slate",
    radius_size=gr.themes.sizes.radius_lg,  # Larger corner radius
    spacing_size=gr.themes.sizes.spacing_lg,  # More spacing
).set(
    # Further fine-tuning
    body_background_fill="#f8f9fa",  # Very light grey background
    block_radius="12px",
)

with gr.Blocks(theme=theme, css=css) as iface:
    # Title and Description using Markdown for better formatting
    gr.Markdown("# AI vs Human Image Detector", elem_id="app-title")
    gr.Markdown(
        f"Upload an image to classify if it was likely generated by AI or created by a human. "
        f"Uses the `{MODEL_IDENTIFIER}` model. Running on **{str(DEVICE).upper()}**.",
        elem_id="app-description",
    )

    # Main layout with Input and Output side-by-side
    with gr.Row(variant="panel"):  # 'panel' adds a light border/background
        with gr.Column(scale=1, min_width=300, elem_id="input-column"):
            image_input = gr.Image(
                type="pil",
                label="🖼️ Upload Your Image",
                sources=["upload", "webcam", "clipboard"],
                height=400,  # Adjust height as needed
            )
            submit_button = gr.Button("🔍 Classify Image", variant="primary")  # Make button prominent

        with gr.Column(scale=1, min_width=300, elem_id="output-column"):
            # Use elem_id and target with CSS for styling
            gr.Markdown("📊 **Prediction Results**", elem_id="results-heading")
            result_output = gr.Label(
                num_top_classes=2,
                label="Classification",
                elem_id="prediction-label",
            )

    # Examples Section (only when at least one example image was found)
    if example_images:
        gr.Examples(
            examples=example_images,
            inputs=image_input,
            outputs=result_output,
            fn=classify_image,
            cache_examples=True,  # Caching is good for static examples
            label="✨ Click an Example to Try!",
        )

    # Footer / Article section
    gr.Markdown(FOOTER_MARKDOWN, elem_id="app-footer")

    # Connect the button click or image change to the prediction function
    # Use api_name for potential API usage later
    submit_button.click(
        fn=classify_image,
        inputs=image_input,
        outputs=result_output,
        api_name="classify_image_button",
    )
    image_input.change(
        fn=classify_image,
        inputs=image_input,
        outputs=result_output,
        api_name="classify_image_change",
    )

# --- Launch the App ---
if __name__ == "__main__":
    print("Launching Gradio interface...")
    # Add share=True for a temporary public link if needed: iface.launch(share=True)
    # When run as a script, launch() blocks until the server is shut down, so
    # the message below is printed only afterwards.
    iface.launch()
    print("Gradio interface closed.")