import gradio as gr
import torch
from PIL import Image as PILImage
from transformers import AutoImageProcessor, SiglipForImageClassification
import os
import warnings

# --- Configuration ---
MODEL_IDENTIFIER = "Ateeqq/ai-vs-human-image-detector"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Suppress specific warnings ---
# Suppress the specific PIL warning about potential decompression bombs
warnings.filterwarnings("ignore", message="Possibly corrupt EXIF data.")
# Suppress the transformers "default legacy behaviour" warning
warnings.filterwarnings("ignore", message=".*You are using the default legacy behaviour.*")


# --- Load Model and Processor (Load once at startup) ---
print(f"Using device: {DEVICE}")
print(f"Loading processor from: {MODEL_IDENTIFIER}")
try:
    processor = AutoImageProcessor.from_pretrained(MODEL_IDENTIFIER)
    print(f"Loading model from: {MODEL_IDENTIFIER}")
    model = SiglipForImageClassification.from_pretrained(MODEL_IDENTIFIER)
    model.to(DEVICE)
    model.eval()
    print("Model and processor loaded successfully.")
except Exception as e:
    print(f"FATAL: Error loading model or processor: {e}")
    # If the model fails to load, raise to stop the app at startup.
    # gr.Error is meant for event handlers, so a plain RuntimeError is used here.
    raise RuntimeError(f"Failed to load the model: {e}. Cannot start the application.") from e

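# A minimal standalone sanity check of the loaded model, useful when debugging
# outside Gradio. Kept as comments so it does not run at startup; "test.jpg"
# is a hypothetical local file, not part of this repo:
#
#   img = PILImage.open("test.jpg").convert("RGB")
#   inputs = processor(images=img, return_tensors="pt").to(DEVICE)
#   with torch.no_grad():
#       probs = torch.softmax(model(**inputs).logits, dim=-1)[0]
#   print({model.config.id2label[i]: p.item() for i, p in enumerate(probs)})
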
# --- Prediction Function ---
def classify_image(image_pil):
    """
    Classifies an image as AI-generated or Human-made.
    Args:
        image_pil (PIL.Image.Image): Input image in PIL format.
    Returns:
        dict: A dictionary mapping class labels ('ai', 'human') to their
              confidence scores. Returns an empty dict if input is None.
    """
    if image_pil is None:
        # Handle case where the user clears the image input
        print("Warning: No image provided.")
        return {} # Return empty dict, Gradio Label handles this

    print("Processing image...")
    try:
        # Ensure image is RGB
        image = image_pil.convert("RGB")

        # Preprocess using the loaded processor
        inputs = processor(images=image, return_tensors="pt").to(DEVICE)

        # Perform inference
        print("Running inference...")
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits

        # Get probabilities using softmax
        # outputs.logits is shape [1, num_labels], softmax over the last dim
        probabilities = torch.softmax(logits, dim=-1)[0] # Get probabilities for the first (and only) image

        # Create a dictionary of label -> score
        results = {}
        for i, prob in enumerate(probabilities):
            label = model.config.id2label[i]
            results[label] = prob.item() # Use .item() to get Python float

        print(f"Prediction results: {results}")
        return results

    except Exception as e:
        print(f"Error during prediction: {e}")
        # gr.Label expects a {label: float} mapping, so returning an error
        # string would itself fail; surface the error in the UI instead.
        raise gr.Error(f"Error processing image: {e}") from e

# --- Gradio Interface Definition ---

# Define Example Images (Optional, but recommended)
# Create an 'examples' folder in your Space repo and put images there
example_dir = "examples"
example_images = []
if os.path.exists(example_dir):
    for img_name in os.listdir(example_dir):
        if img_name.lower().endswith(('.png', '.jpg', '.jpeg', '.webp')):
            example_images.append(os.path.join(example_dir, img_name))
    print(f"Found examples: {example_images}")
else:
    print("No 'examples' directory found. Examples will not be shown.")

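# Expected repo layout for the optional examples (filenames are hypothetical):
#
#   app.py
#   examples/
#       sample_ai.png
#       sample_human.jpg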

# Define the Gradio interface
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil", label="Upload Image", sources=["upload", "webcam", "clipboard"]), # Use PIL format as input
    outputs=gr.Label(num_top_classes=2, label="Prediction Results"), # Use gr.Label for classification output
    title="AI vs Human Image Detector",
    description=(
        f"Upload an image to classify if it was likely generated by AI or created by a human. "
        f"Uses the `{MODEL_IDENTIFIER}` model on Hugging Face. Running on **{str(DEVICE).upper()}**."
    ),
    article=(
    "<div>"
    "<p>This tool uses a SigLIP model fine-tuned for distinguishing between AI-generated and human-made images.</p>"
    f"<p>Model Card: <a href='https://huggingface.co/{MODEL_IDENTIFIER}' target='_blank'>{MODEL_IDENTIFIER}</a></p>"
    "<p>Fine tuning code available at <a href='https://exnrt.com/blog/ai/fine-tuning-siglip2/' target='_blank'>https://exnrt.com/blog/ai/fine-tuning-siglip2/</a></p>"
    "</div>"
    ),
    examples=example_images if example_images else None, # Only add examples if found
    cache_examples=bool(example_images), # Cache results for examples if they exist
    allow_flagging="never" # Or "auto" if you want users to flag issues
)
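# For concurrent users, the interface could be queued before launching; this
# is standard Gradio API but was not used in the original:
#
#   iface.queue(max_size=16)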

# --- Launch the App ---
if __name__ == "__main__":
    print("Launching Gradio interface...")
    # launch() blocks until the server is stopped, so avoid trailing code here
    iface.launch()