Ateeqq committed on
Commit 43655b0 · verified · 1 Parent(s): 47e4735

Create app.py

Files changed (1)
  1. app.py +125 -0
app.py ADDED
@@ -0,0 +1,125 @@
import os
import warnings

import gradio as gr
import torch
from PIL import Image as PILImage
from transformers import AutoImageProcessor, SiglipForImageClassification

# --- Configuration ---
MODEL_IDENTIFIER = "Ateeqq/ai-vs-human-image-detector"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Suppress specific warnings ---
# Suppress the PIL warning about possibly corrupt EXIF data
warnings.filterwarnings("ignore", message="Possibly corrupt EXIF data.")
# Suppress the transformers warning about default legacy processing behaviour
warnings.filterwarnings("ignore", message=".*You are using the default legacy behaviour.*")

# --- Load Model and Processor (load once at startup) ---
print(f"Using device: {DEVICE}")
print(f"Loading processor from: {MODEL_IDENTIFIER}")
try:
    processor = AutoImageProcessor.from_pretrained(MODEL_IDENTIFIER)
    print(f"Loading model from: {MODEL_IDENTIFIER}")
    model = SiglipForImageClassification.from_pretrained(MODEL_IDENTIFIER)
    model.to(DEVICE)
    model.eval()
    print("Model and processor loaded successfully.")
except Exception as e:
    print(f"FATAL: Error loading model or processor: {e}")
    # If the model fails to load, raise to stop the app
    raise gr.Error(f"Failed to load the model: {e}. Cannot start the application.") from e
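
# Sanity check: print the checkpoint's label mapping instead of hard-coding it.
# For this model the labels are expected to be 'ai' and 'human'.
print(f"Label mapping: {model.config.id2label}")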

# --- Prediction Function ---
def classify_image(image_pil: PILImage.Image):
    """
    Classifies an image as AI-generated or human-made.

    Args:
        image_pil (PIL.Image.Image): Input image in PIL format.

    Returns:
        dict: A dictionary mapping class labels ('ai', 'human') to their
              confidence scores. Returns an empty dict if the input is None.
    """
    if image_pil is None:
        # Handle the case where the user clears the image input
        print("Warning: No image provided.")
        return {}  # gr.Label treats an empty dict as "no prediction"

    print("Processing image...")
    try:
        # Ensure the image is RGB (the processor expects 3 channels)
        image = image_pil.convert("RGB")

        # Preprocess using the loaded processor
        inputs = processor(images=image, return_tensors="pt").to(DEVICE)

        # Perform inference
        print("Running inference...")
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits

        # logits has shape [1, num_labels]; softmax over the last dim,
        # then take the scores for the first (and only) image
        probabilities = torch.softmax(logits, dim=-1)[0]

        # Map each label to its score
        results = {}
        for i, prob in enumerate(probabilities):
            label = model.config.id2label[i]
            results[label] = prob.item()  # .item() yields a plain Python float

        print(f"Prediction results: {results}")
        return results

    except Exception as e:
        print(f"Error during prediction: {e}")
        # Optionally raise a Gradio error to show it in the UI:
        # raise gr.Error(f"Error processing image: {e}")
        return {"Error": f"Processing failed: {e}"}
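
# A minimal local smoke test (commented out; 'test.jpg' is an illustrative
# filename, not a file shipped with this Space):
# print(classify_image(PILImage.open("test.jpg")))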

# --- Gradio Interface Definition ---

# Define example images (optional, but recommended)
# Create an 'examples' folder in your Space repo and put images there
example_dir = "examples"
example_images = []
if os.path.exists(example_dir):
    for img_name in os.listdir(example_dir):
        if img_name.lower().endswith((".png", ".jpg", ".jpeg", ".webp")):
            example_images.append(os.path.join(example_dir, img_name))
    print(f"Found examples: {example_images}")
else:
    print("No 'examples' directory found. Examples will not be shown.")

# Define the Gradio interface
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil", label="Upload Image", sources=["upload", "webcam", "clipboard"]),  # Deliver the image to the function as a PIL object
    outputs=gr.Label(num_top_classes=2, label="Prediction Results"),  # gr.Label is designed for classification output
    title="AI vs Human Image Detector",
    description=(
        "Upload an image to classify whether it was likely generated by AI or created by a human. "
        f"Uses the `{MODEL_IDENTIFIER}` model on Hugging Face. Running on **{str(DEVICE).upper()}**."
    ),
    article=(
        "<div>"
        "<p>This tool uses a SigLIP model fine-tuned to distinguish AI-generated from human-made images.</p>"
        f"<p>Model Card: <a href='https://huggingface.co/{MODEL_IDENTIFIER}' target='_blank'>{MODEL_IDENTIFIER}</a></p>"
        "<p style='text-align: center;'>App created using Gradio and Hugging Face Transformers.</p>"
        "</div>"
    ),
    examples=example_images if example_images else None,  # Only add examples if any were found
    cache_examples=bool(example_images),  # Cache example results only when examples exist
    allow_flagging="never"  # Or "auto" to let users flag issues
)
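
# Note: `allow_flagging` is deprecated in favour of `flagging_mode` on newer
# Gradio releases (5.x); the argument above assumes a 4.x-era runtime.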

# --- Launch the App ---
if __name__ == "__main__":
    print("Launching Gradio interface...")
    iface.launch()  # Blocks until the server is shut down