Ateeqq committed
Commit 43a9d75 · verified · 1 Parent(s): c892189

Update app.py

Files changed (1):
1. app.py (+58 -146)
app.py CHANGED
@@ -10,9 +10,12 @@ MODEL_IDENTIFIER = r"Ateeqq/ai-vs-human-image-detector"
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 # --- Suppress specific warnings ---
+# Suppress the specific PIL warning about potential decompression bombs
 warnings.filterwarnings("ignore", message="Possibly corrupt EXIF data.")
+# Suppress transformers warning about loading weights without specifying revision
 warnings.filterwarnings("ignore", message=".*You are using the default legacy behaviour.*")
 
+
 # --- Load Model and Processor (Load once at startup) ---
 print(f"Using device: {DEVICE}")
 print(f"Loading processor from: {MODEL_IDENTIFIER}")
@@ -25,184 +28,93 @@ try:
     print("Model and processor loaded successfully.")
 except Exception as e:
     print(f"FATAL: Error loading model or processor: {e}")
+    # If the model fails to load, we raise an exception to stop the app
     raise gr.Error(f"Failed to load the model: {e}. Cannot start the application.") from e
 
 # --- Prediction Function ---
 def classify_image(image_pil):
+    """
+    Classifies an image as AI-generated or Human-made.
+    Args:
+        image_pil (PIL.Image.Image): Input image in PIL format.
+    Returns:
+        dict: A dictionary mapping class labels ('ai', 'human') to their
+              confidence scores. Returns an empty dict if input is None.
+    """
     if image_pil is None:
+        # Handle case where the user clears the image input
         print("Warning: No image provided.")
-        return {}
+        return {}  # Return empty dict, Gradio Label handles this
 
     print("Processing image...")
     try:
+        # Ensure image is RGB
         image = image_pil.convert("RGB")
+
+        # Preprocess using the loaded processor
         inputs = processor(images=image, return_tensors="pt").to(DEVICE)
 
+        # Perform inference
         print("Running inference...")
         with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits
 
-        probabilities = torch.softmax(logits, dim=-1)[0]
+        # Get probabilities using softmax
+        # outputs.logits is shape [1, num_labels], softmax over the last dim
+        probabilities = torch.softmax(logits, dim=-1)[0]  # Probabilities for the first (and only) image
+
+        # Create a dictionary of label -> score
         results = {}
        for i, prob in enumerate(probabilities):
            label = model.config.id2label[i]
-            results[label] = round(prob.item(), 4)
+            results[label] = prob.item()  # Use .item() to get Python float
 
         print(f"Prediction results: {results}")
         return results
+
     except Exception as e:
         print(f"Error during prediction: {e}")
-        return {"Error": f"Processing failed. Please try again or use a different image."}
+        # Optionally raise a Gradio error to show it in the UI
+        # raise gr.Error(f"Error processing image: {e}")
+        return {"Error": f"Processing failed: {e}"}  # Or return an error message
 
-# --- Define Example Images ---
+# --- Gradio Interface Definition ---
+
+# Define Example Images (Optional, but recommended)
+# Create an 'examples' folder in your Space repo and put images there
 example_dir = "examples"
 example_images = []
-if os.path.exists(example_dir) and os.listdir(example_dir):
+if os.path.exists(example_dir):
     for img_name in os.listdir(example_dir):
         if img_name.lower().endswith(('.png', '.jpg', '.jpeg', '.webp')):
-            example_images.append(os.path.join(example_dir, img_name))
-    if example_images:
-        print(f"Found examples: {example_images}")
-    else:
-        print("No valid image files found in 'examples' directory.")
+            example_images.append(os.path.join(example_dir, img_name))
+    print(f"Found examples: {example_images}")
 else:
-    print("No 'examples' directory found or it's empty. Examples will not be shown.")
-
-# --- Custom CSS for Dark Theme Adjustments ---
-# Minimal CSS - let the dark theme handle most things
-css = """
-body { font-family: 'Inter', sans-serif; }
-
-/* Style the main title */
-#app-title {
-    text-align: center;
-    font-weight: bold;
-    font-size: 2.5em;
-    margin-bottom: 5px;
-    /* color removed - let theme handle */
-}
-
-/* Style the description */
-#app-description {
-    text-align: center;
-    font-size: 1.1em;
-    margin-bottom: 25px;
-    /* color removed - let theme handle */
-}
-#app-description code { /* Style model name - theme might handle this, but can force */
-    font-weight: bold;
-    background-color: rgba(255, 255, 255, 0.1); /* Slightly lighter background for code */
-    padding: 2px 5px;
-    border-radius: 4px;
-    color: #c5f7dc; /* Light green text for code block */
-}
-#app-description strong { /* Style device name */
-    color: #2dd4bf; /* Brighter teal/emerald for dark theme */
-    font-weight: bold;
-}
-
-/* Style the results heading */
-#results-heading {
-    text-align: center;
-    font-size: 1.2em;
-    margin-bottom: 10px;
-    /* color removed - let theme handle */
-}
-
-/* Add some definition to input/output columns if needed */
-#input-column, #output-column {
-    border: 1px solid #4b5563; /* Darker border for dark theme */
-    border-radius: 12px;
-    padding: 20px;
-    box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow, works on dark too */
-    /* background-color removed - let theme handle */
-}
-
-/* Ensure label text inside columns is readable */
-#prediction-label .label-name { font-weight: bold; font-size: 1.1em; }
-#prediction-label .confidence { font-size: 1em; }
-
-
-/* Footer styling */
-#app-footer {
-    margin-top: 40px;
-    padding-top: 20px;
-    border-top: 1px solid #374151; /* Darker border for footer */
-    text-align: center;
-    font-size: 0.9em;
-    /* color removed - let theme handle */
-}
-#app-footer a {
-    color: #60a5fa; /* Lighter blue for links */
-    text-decoration: none;
-}
-#app-footer a:hover {
-    text-decoration: underline;
-}
-"""
-
-# --- Gradio Interface using Blocks and Theme ---
-# Use the theme string identifier for the dark mode variant
-# Other options: "default/dark", "monochrome/dark", "glass/dark"
-with gr.Blocks(theme="soft/dark", css=css) as iface:  # <<< CHANGE IS HERE
-    # Title and Description
-    gr.Markdown("# AI vs Human Image Detector", elem_id="app-title")
-    gr.Markdown(
         f"Upload an image to classify if it was likely generated by AI or created by a human. "
-        f"Uses the `{MODEL_IDENTIFIER}` model. Running on **{str(DEVICE).upper()}**.",
-        elem_id="app-description"
-    )
-
-    # Main layout
-    with gr.Row(variant='panel'):
-        with gr.Column(scale=1, min_width=300, elem_id="input-column"):
-            image_input = gr.Image(
-                type="pil",
-                label="🖼️ Upload Your Image",
-                sources=["upload", "webcam", "clipboard"],
-                height=400,
-            )
-            submit_button = gr.Button("🔍 Classify Image", variant="primary")
-
-        with gr.Column(scale=1, min_width=300, elem_id="output-column"):
-            gr.Markdown("📊 **Prediction Results**", elem_id="results-heading")
-            result_output = gr.Label(
-                num_top_classes=2,
-                label="Classification",
-                elem_id="prediction-label"
-            )
-
-    # Examples Section
-    if example_images:
-        gr.Examples(
-            examples=example_images,
-            inputs=image_input,
-            outputs=result_output,
-            fn=classify_image,
-            cache_examples=True,
-            label="✨ Click an Example to Try!"
-        )
-
-    # Footer / Article section
-    gr.Markdown(f"""
-    ---
-    **How it Works:**
-    This application uses a fine-tuned [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip) vision model
-    specifically trained to differentiate between images generated by Artificial Intelligence and those created by humans.
-
-    **Model:**
-    * You can find the model card here: <a href='https://huggingface.co/{MODEL_IDENTIFIER}' target='_blank'>{MODEL_IDENTIFIER}</a>
-
-    **Training Code:**
-    Fine tuning code available at [https://exnrt.com/blog/ai/fine-tuning-siglip2/](https://exnrt.com/blog/ai/fine-tuning-siglip2/).
-    """,
-    elem_id="app-footer"
-    )
-
-    # Connect events
-    submit_button.click(fn=classify_image, inputs=image_input, outputs=result_output, api_name="classify_image_button")
-    image_input.change(fn=classify_image, inputs=image_input, outputs=result_output, api_name="classify_image_change")
+    print("No 'examples' directory found. Examples will not be shown.")
+
+
+# Define the Gradio interface
+iface = gr.Interface(
+    fn=classify_image,
+    inputs=gr.Image(type="pil", label="Upload Image", sources=["upload", "webcam", "clipboard"]),  # Use PIL format as input
+    outputs=gr.Label(num_top_classes=2, label="Prediction Results"),  # Use gr.Label for classification output
+    title="AI vs Human Image Detector",
+    description=(
+        f"Uses the `{MODEL_IDENTIFIER}` model on Hugging Face. Running on **{str(DEVICE).upper()}**."
+    ),
+    article=(
+        "<div>"
+        "<p>This tool uses a SigLIP model fine-tuned for distinguishing between AI-generated and human-made images.</p>"
+        f"<p>Model Card: <a href='https://huggingface.co/{MODEL_IDENTIFIER}' target='_blank'>{MODEL_IDENTIFIER}</a></p>"
+        "<p>Fine tuning code available at <a href='https://exnrt.com/blog/ai/fine-tuning-siglip2/' target='_blank'>https://exnrt.com/blog/ai/fine-tuning-siglip2/</a>.</p>"
+        "</div>"
+    ),
+    examples=example_images if example_images else None,  # Only add examples if found
+    cache_examples=True if example_images else False,  # Cache results for examples if they exist
+    allow_flagging="never"  # Or "auto" if you want users to flag issues
+)
 
 # --- Launch the App ---
 if __name__ == "__main__":
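
For anyone who wants to exercise the inference path that `classify_image` implements without launching the UI, below is a minimal standalone sketch. The `AutoImageProcessor`/`AutoModelForImageClassification` loader names are an assumption (the app's actual loading code sits above this hunk and is not part of the diff), and `example.jpg` is a hypothetical local file.

```python
# Minimal standalone sketch of the classify_image inference path.
# Assumptions: the checkpoint loads via the generic Auto* classes,
# and "example.jpg" is a hypothetical local image file.
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

MODEL_IDENTIFIER = "Ateeqq/ai-vs-human-image-detector"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

processor = AutoImageProcessor.from_pretrained(MODEL_IDENTIFIER)
model = AutoModelForImageClassification.from_pretrained(MODEL_IDENTIFIER).to(DEVICE)
model.eval()

image = Image.open("example.jpg").convert("RGB")
inputs = processor(images=image, return_tensors="pt").to(DEVICE)

with torch.no_grad():
    logits = model(**inputs).logits  # shape [1, num_labels]

# Softmax over the label dimension, then map class indices to label names,
# exactly as the loop in classify_image does.
probabilities = torch.softmax(logits, dim=-1)[0]
results = {model.config.id2label[i]: p.item() for i, p in enumerate(probabilities)}
print(results)  # e.g. {'ai': 0.97, 'human': 0.03}, label names per the docstring
```

One note on the commit itself: dropping `round(prob.item(), 4)` in favor of `prob.item()` hands full-precision floats to `gr.Label`, which formats the confidences itself, so the rounding was redundant.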
 
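The switch from `gr.Blocks` (with `submit_button.click` / `image_input.change` events) to a single `gr.Interface` also changes how the app is called programmatically: an Interface exposes its function under the default `/predict` endpoint. Below is a hedged sketch of querying the running app with `gradio_client`; the Space id (assumed to match the model id) and the local file name are hypothetical.

```python
# Hedged sketch: call the deployed Interface remotely via gradio_client.
# Assumptions: the Space id mirrors the model id, "example.jpg" exists
# locally, and "/predict" is gr.Interface's default api_name.
from gradio_client import Client, handle_file

client = Client("Ateeqq/ai-vs-human-image-detector")  # hypothetical Space id
result = client.predict(
    handle_file("example.jpg"),  # the image input; the client uploads the file
    api_name="/predict",
)
print(result)  # gr.Label output: {'label': ..., 'confidences': [...]}
```

Because `cache_examples` is set to `True` whenever example images are found, Gradio runs `classify_image` over the examples once at startup and serves the cached outputs afterwards.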