Ateeqq commited on
Commit
27b4f7a
Β·
verified Β·
1 Parent(s): 727b540

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -21
app.py CHANGED
@@ -10,7 +10,9 @@ MODEL_IDENTIFIER = r"Ateeqq/ai-vs-human-image-detector"
10
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
 
12
  # --- Suppress specific warnings ---
 
13
  warnings.filterwarnings("ignore", message="Possibly corrupt EXIF data.")
 
14
  warnings.filterwarnings("ignore", message=".*You are using the default legacy behaviour.*")
15
 
16
 
@@ -26,9 +28,10 @@ try:
26
  print("Model and processor loaded successfully.")
27
  except Exception as e:
28
  print(f"FATAL: Error loading model or processor: {e}")
 
29
  raise gr.Error(f"Failed to load the model: {e}. Cannot start the application.") from e
30
 
31
- # --- Prediction Function (No changes needed) ---
32
  def classify_image(image_pil):
33
  """
34
  Classifies an image as AI-generated or Human-made.
@@ -41,21 +44,29 @@ def classify_image(image_pil):
41
  confidence scores. Returns an empty dict if input is None.
42
  """
43
  if image_pil is None:
 
44
  print("Warning: No image provided.")
45
- return {}
46
 
47
  print("Processing image...")
48
  try:
 
49
  image = image_pil.convert("RGB")
 
 
50
  inputs = processor(images=image, return_tensors="pt").to(DEVICE)
51
 
 
52
  print("Running inference...")
53
  with torch.no_grad():
54
  outputs = model(**inputs)
55
  logits = outputs.logits
56
 
57
- probabilities = torch.softmax(logits, dim=-1)[0]
 
 
58
 
 
59
  results = {}
60
  for i, prob in enumerate(probabilities):
61
  label = model.config.id2label[i]
@@ -67,22 +78,25 @@ def classify_image(image_pil):
67
  except Exception as e:
68
  print(f"Error during prediction: {e}")
69
  # Return error in the format expected by gr.Label
70
- return {"Error": f"Processing failed"}
 
71
 
72
  # --- Define Example Images ---
73
  example_dir = "examples"
74
  example_images = []
75
- if os.path.exists(example_dir):
76
  for img_name in os.listdir(example_dir):
77
  if img_name.lower().endswith(('.png', '.jpg', '.jpeg', '.webp')):
78
  example_images.append(os.path.join(example_dir, img_name))
79
- print(f"Found examples: {example_images}")
 
 
 
80
  else:
81
  print("No 'examples' directory found or it's empty. Examples will not be shown.")
82
 
83
 
84
  # --- Custom CSS ---
85
- # You can experiment with different CSS here
86
  css = """
87
  body { font-family: 'Inter', sans-serif; } /* Use a clean sans-serif font */
88
 
@@ -116,6 +130,14 @@ body { font-family: 'Inter', sans-serif; } /* Use a clean sans-serif font */
116
  #prediction-label .label-name { font-weight: bold; font-size: 1.1em; }
117
  #prediction-label .confidence { font-size: 1em; }
118
 
 
 
 
 
 
 
 
 
119
  /* Style the examples section */
120
  .gradio-container .examples-container { padding-top: 15px; }
121
  .gradio-container .examples-header { font-size: 1.1em; font-weight: bold; margin-bottom: 10px; color: #34495e; }
@@ -143,8 +165,7 @@ body { font-family: 'Inter', sans-serif; } /* Use a clean sans-serif font */
143
  """
144
 
145
  # --- Gradio Interface using Blocks and Theme ---
146
- # Choose a theme: gr.themes.Soft(), gr.themes.Monochrome(), gr.themes.Glass()
147
- # Or customize the default: gr.themes.Default().set(radius_size="sm", spacing_size="sm")
148
  theme = gr.themes.Soft(
149
  primary_hue="emerald", # Color scheme based on emerald green
150
  secondary_hue="blue",
@@ -153,7 +174,7 @@ theme = gr.themes.Soft(
153
  spacing_size=gr.themes.sizes.spacing_lg, # More spacing
154
  ).set(
155
  # Further fine-tuning
156
- body_background_fill="#f1f2f6", # Light grey background
157
  block_radius="12px",
158
  )
159
 
@@ -179,7 +200,8 @@ with gr.Blocks(theme=theme, css=css) as iface:
179
  submit_button = gr.Button("πŸ” Classify Image", variant="primary") # Make button prominent
180
 
181
  with gr.Column(scale=1, min_width=300, elem_id="output-column"):
182
- gr.Markdown("πŸ“Š **Prediction Results**", style={"text-align": "center"}) # Centered heading for results
 
183
  result_output = gr.Label(
184
  num_top_classes=2,
185
  label="Classification",
@@ -187,7 +209,7 @@ with gr.Blocks(theme=theme, css=css) as iface:
187
  )
188
 
189
  # Examples Section
190
- if example_images: # Only show examples if they exist
191
  gr.Examples(
192
  examples=example_images,
193
  inputs=image_input,
@@ -201,27 +223,24 @@ with gr.Blocks(theme=theme, css=css) as iface:
201
  gr.Markdown(
202
  """
203
  ---
204
- **How it Works:**
205
  This application uses a fine-tuned [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip) vision model
206
  specifically trained to differentiate between images generated by Artificial Intelligence and those created by humans.
207
 
208
- Fine tuning code available at [https://exnrt.com/blog/ai/fine-tuning-siglip2/](https://exnrt.com/blog/ai/fine-tuning-siglip2/).
209
-
210
- **Model:**
211
- * You can find the model card here: <a href='https://huggingface.co/{model_id}' target='_blank'>{model_id}</a>
212
 
 
213
  """.format(model_id=MODEL_IDENTIFIER),
214
  elem_id="app-footer"
215
  )
216
 
217
  # Connect the button click or image change to the prediction function
218
- submit_button.click(fn=classify_image, inputs=image_input, outputs=result_output, api_name="classify_button")
219
- # Allow prediction on image change/upload as well (optional, can be convenient)
220
- image_input.change(fn=classify_image, inputs=image_input, outputs=result_output, api_name="classify_change")
221
 
222
 
223
  # --- Launch the App ---
224
  if __name__ == "__main__":
225
  print("Launching Gradio interface...")
226
- iface.launch() # share=True to create a public link (useful for testing)
227
  print("Gradio interface launched.")
 
10
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
 
12
  # --- Suppress specific warnings ---
13
+ # Suppress the specific PIL warning about potential decompression bombs
14
  warnings.filterwarnings("ignore", message="Possibly corrupt EXIF data.")
15
+ # Suppress transformers warning about loading weights without specifying revision
16
  warnings.filterwarnings("ignore", message=".*You are using the default legacy behaviour.*")
17
 
18
 
 
28
  print("Model and processor loaded successfully.")
29
  except Exception as e:
30
  print(f"FATAL: Error loading model or processor: {e}")
31
+ # If the model fails to load, we raise an exception to stop the app
32
  raise gr.Error(f"Failed to load the model: {e}. Cannot start the application.") from e
33
 
34
+ # --- Prediction Function ---
35
  def classify_image(image_pil):
36
  """
37
  Classifies an image as AI-generated or Human-made.
 
44
  confidence scores. Returns an empty dict if input is None.
45
  """
46
  if image_pil is None:
47
+ # Handle case where the user clears the image input
48
  print("Warning: No image provided.")
49
+ return {} # Return empty dict, Gradio Label handles this
50
 
51
  print("Processing image...")
52
  try:
53
+ # Ensure image is RGB
54
  image = image_pil.convert("RGB")
55
+
56
+ # Preprocess using the loaded processor
57
  inputs = processor(images=image, return_tensors="pt").to(DEVICE)
58
 
59
+ # Perform inference
60
  print("Running inference...")
61
  with torch.no_grad():
62
  outputs = model(**inputs)
63
  logits = outputs.logits
64
 
65
+ # Get probabilities using softmax
66
+ # outputs.logits is shape [1, num_labels], softmax over the last dim
67
+ probabilities = torch.softmax(logits, dim=-1)[0] # Get probabilities for the first (and only) image
68
 
69
+ # Create a dictionary of label -> score
70
  results = {}
71
  for i, prob in enumerate(probabilities):
72
  label = model.config.id2label[i]
 
78
  except Exception as e:
79
  print(f"Error during prediction: {e}")
80
  # Return error in the format expected by gr.Label
81
+ # Provide a user-friendly error message in the output
82
+ return {"Error": f"Processing failed. Please try again or use a different image."}
83
 
84
  # --- Define Example Images ---
85
  example_dir = "examples"
86
  example_images = []
87
+ if os.path.exists(example_dir) and os.listdir(example_dir): # Check if dir exists AND is not empty
88
  for img_name in os.listdir(example_dir):
89
  if img_name.lower().endswith(('.png', '.jpg', '.jpeg', '.webp')):
90
  example_images.append(os.path.join(example_dir, img_name))
91
+ if example_images:
92
+ print(f"Found examples: {example_images}")
93
+ else:
94
+ print("No valid image files found in 'examples' directory.")
95
  else:
96
  print("No 'examples' directory found or it's empty. Examples will not be shown.")
97
 
98
 
99
  # --- Custom CSS ---
 
100
  css = """
101
  body { font-family: 'Inter', sans-serif; } /* Use a clean sans-serif font */
102
 
 
130
  #prediction-label .label-name { font-weight: bold; font-size: 1.1em; }
131
  #prediction-label .confidence { font-size: 1em; }
132
 
133
+ /* Style the results heading */
134
+ #results-heading {
135
+ text-align: center;
136
+ font-size: 1.2em; /* Slightly larger heading for results */
137
+ margin-bottom: 10px; /* Space below heading */
138
+ color: #34495e; /* Match other heading colors */
139
+ }
140
+
141
  /* Style the examples section */
142
  .gradio-container .examples-container { padding-top: 15px; }
143
  .gradio-container .examples-header { font-size: 1.1em; font-weight: bold; margin-bottom: 10px; color: #34495e; }
 
165
  """
166
 
167
  # --- Gradio Interface using Blocks and Theme ---
168
+ # Choose a theme: gr.themes.Soft(), gr.themes.Monochrome(), gr.themes.Glass(), etc.
 
169
  theme = gr.themes.Soft(
170
  primary_hue="emerald", # Color scheme based on emerald green
171
  secondary_hue="blue",
 
174
  spacing_size=gr.themes.sizes.spacing_lg, # More spacing
175
  ).set(
176
  # Further fine-tuning
177
+ body_background_fill="#f8f9fa", # Very light grey background
178
  block_radius="12px",
179
  )
180
 
 
200
  submit_button = gr.Button("πŸ” Classify Image", variant="primary") # Make button prominent
201
 
202
  with gr.Column(scale=1, min_width=300, elem_id="output-column"):
203
+ # Use elem_id and target with CSS for styling
204
+ gr.Markdown("πŸ“Š **Prediction Results**", elem_id="results-heading")
205
  result_output = gr.Label(
206
  num_top_classes=2,
207
  label="Classification",
 
209
  )
210
 
211
  # Examples Section
212
+ if example_images: # Only show examples if they exist and list is not empty
213
  gr.Examples(
214
  examples=example_images,
215
  inputs=image_input,
 
223
  gr.Markdown(
224
  """
225
  ---
 
226
  This application uses a fine-tuned [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip) vision model
227
  specifically trained to differentiate between images generated by Artificial Intelligence and those created by humans.
228
 
229
+ You can find the model card here: <a href='https://huggingface.co/{model_id}' target='_blank'>{model_id}</a>
 
 
 
230
 
231
+ Fine tuning code available at [https://exnrt.com/blog/ai/fine-tuning-siglip2/](https://exnrt.com/blog/ai/fine-tuning-siglip2/).
232
  """.format(model_id=MODEL_IDENTIFIER),
233
  elem_id="app-footer"
234
  )
235
 
236
  # Connect the button click or image change to the prediction function
237
+ # Use api_name for potential API usage later
238
+ submit_button.click(fn=classify_image, inputs=image_input, outputs=result_output, api_name="classify_image_button")
239
+ image_input.change(fn=classify_image, inputs=image_input, outputs=result_output, api_name="classify_image_change")
240
 
241
 
242
  # --- Launch the App ---
243
  if __name__ == "__main__":
244
  print("Launching Gradio interface...")
245
+ iface.launch() # Add share=True for temporary public link if needed: iface.launch(share=True)
246
  print("Gradio interface launched.")