Ateeqq commited on
Commit
727b540
·
verified ·
1 Parent(s): b0dcf3f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +146 -44
app.py CHANGED
@@ -10,9 +10,7 @@ MODEL_IDENTIFIER = r"Ateeqq/ai-vs-human-image-detector"
10
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
 
12
  # --- Suppress specific warnings ---
13
- # Suppress the specific PIL warning about potential decompression bombs
14
  warnings.filterwarnings("ignore", message="Possibly corrupt EXIF data.")
15
- # Suppress transformers warning about loading weights without specifying revision
16
  warnings.filterwarnings("ignore", message=".*You are using the default legacy behaviour.*")
17
 
18
 
@@ -28,10 +26,9 @@ try:
28
  print("Model and processor loaded successfully.")
29
  except Exception as e:
30
  print(f"FATAL: Error loading model or processor: {e}")
31
- # If the model fails to load, we raise an exception to stop the app
32
  raise gr.Error(f"Failed to load the model: {e}. Cannot start the application.") from e
33
 
34
- # --- Prediction Function ---
35
  def classify_image(image_pil):
36
  """
37
  Classifies an image as AI-generated or Human-made.
@@ -44,47 +41,35 @@ def classify_image(image_pil):
44
  confidence scores. Returns an empty dict if input is None.
45
  """
46
  if image_pil is None:
47
- # Handle case where the user clears the image input
48
  print("Warning: No image provided.")
49
- return {} # Return empty dict, Gradio Label handles this
50
 
51
  print("Processing image...")
52
  try:
53
- # Ensure image is RGB
54
  image = image_pil.convert("RGB")
55
-
56
- # Preprocess using the loaded processor
57
  inputs = processor(images=image, return_tensors="pt").to(DEVICE)
58
 
59
- # Perform inference
60
  print("Running inference...")
61
  with torch.no_grad():
62
  outputs = model(**inputs)
63
  logits = outputs.logits
64
 
65
- # Get probabilities using softmax
66
- # outputs.logits is shape [1, num_labels], softmax over the last dim
67
- probabilities = torch.softmax(logits, dim=-1)[0] # Get probabilities for the first (and only) image
68
 
69
- # Create a dictionary of label -> score
70
  results = {}
71
  for i, prob in enumerate(probabilities):
72
  label = model.config.id2label[i]
73
- results[label] = prob.item() # Use .item() to get Python float
74
 
75
  print(f"Prediction results: {results}")
76
  return results
77
 
78
  except Exception as e:
79
  print(f"Error during prediction: {e}")
80
- # Optionally raise a Gradio error to show it in the UI
81
- # raise gr.Error(f"Error processing image: {e}")
82
- return {"Error": f"Processing failed: {e}"} # Or return an error message
83
-
84
- # --- Gradio Interface Definition ---
85
 
86
- # Define Example Images (Optional, but recommended)
87
- # Create an 'examples' folder in your Space repo and put images there
88
  example_dir = "examples"
89
  example_images = []
90
  if os.path.exists(example_dir):
@@ -93,33 +78,150 @@ if os.path.exists(example_dir):
93
  example_images.append(os.path.join(example_dir, img_name))
94
  print(f"Found examples: {example_images}")
95
  else:
96
- print("No 'examples' directory found. Examples will not be shown.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
 
99
- # Define the Gradio interface
100
- iface = gr.Interface(
101
- fn=classify_image,
102
- inputs=gr.Image(type="pil", label="Upload Image", sources=["upload", "webcam", "clipboard"]), # Use PIL format as input
103
- outputs=gr.Label(num_top_classes=2, label="Prediction Results"), # Use gr.Label for classification output
104
- title="AI vs Human Image Detector",
105
- description=(
106
  f"Upload an image to classify if it was likely generated by AI or created by a human. "
107
- f"Uses the `{MODEL_IDENTIFIER}` model on Hugging Face. Running on **{str(DEVICE).upper()}**."
108
- ),
109
- article=(
110
- "<div>"
111
- "<p>This tool uses a SigLIP model fine-tuned for distinguishing between AI-generated and human-made images.</p>"
112
- f"<p>Model Card: <a href='https://huggingface.co/{MODEL_IDENTIFIER}' target='_blank'>{MODEL_IDENTIFIER}</a></p>"
113
- "<p style='text-align: center;'>App created using Gradio and Hugging Face Transformers.</p>"
114
- "</div>"
115
- ),
116
- examples=example_images if example_images else None, # Only add examples if found
117
- cache_examples= True if example_images else False, # Cache results for examples if they exist
118
- allow_flagging="never" # Or "auto" if you want users to flag issues
119
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
  # --- Launch the App ---
122
  if __name__ == "__main__":
123
  print("Launching Gradio interface...")
124
- iface.launch()
125
  print("Gradio interface launched.")
 
10
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
 
12
  # --- Suppress specific warnings ---
 
13
  warnings.filterwarnings("ignore", message="Possibly corrupt EXIF data.")
 
14
  warnings.filterwarnings("ignore", message=".*You are using the default legacy behaviour.*")
15
 
16
 
 
26
  print("Model and processor loaded successfully.")
27
  except Exception as e:
28
  print(f"FATAL: Error loading model or processor: {e}")
 
29
  raise gr.Error(f"Failed to load the model: {e}. Cannot start the application.") from e
30
 
31
+ # --- Prediction Function (No changes needed) ---
32
  def classify_image(image_pil):
33
  """
34
  Classifies an image as AI-generated or Human-made.
 
41
  confidence scores. Returns an empty dict if input is None.
42
  """
43
  if image_pil is None:
 
44
  print("Warning: No image provided.")
45
+ return {}
46
 
47
  print("Processing image...")
48
  try:
 
49
  image = image_pil.convert("RGB")
 
 
50
  inputs = processor(images=image, return_tensors="pt").to(DEVICE)
51
 
 
52
  print("Running inference...")
53
  with torch.no_grad():
54
  outputs = model(**inputs)
55
  logits = outputs.logits
56
 
57
+ probabilities = torch.softmax(logits, dim=-1)[0]
 
 
58
 
 
59
  results = {}
60
  for i, prob in enumerate(probabilities):
61
  label = model.config.id2label[i]
62
+ results[label] = round(prob.item(), 4) # Round for cleaner display
63
 
64
  print(f"Prediction results: {results}")
65
  return results
66
 
67
  except Exception as e:
68
  print(f"Error during prediction: {e}")
69
+ # Return error in the format expected by gr.Label
70
+ return {"Error": f"Processing failed"}
 
 
 
71
 
72
+ # --- Define Example Images ---
 
73
  example_dir = "examples"
74
  example_images = []
75
  if os.path.exists(example_dir):
 
78
  example_images.append(os.path.join(example_dir, img_name))
79
  print(f"Found examples: {example_images}")
80
  else:
81
+ print("No 'examples' directory found or it's empty. Examples will not be shown.")
82
+
83
+
84
+ # --- Custom CSS ---
85
+ # You can experiment with different CSS here
86
+ css = """
87
+ body { font-family: 'Inter', sans-serif; } /* Use a clean sans-serif font */
88
+
89
+ /* Style the main title */
90
+ #app-title {
91
+ text-align: center;
92
+ font-weight: bold;
93
+ font-size: 2.5em; /* Larger title */
94
+ margin-bottom: 5px; /* Reduced space below title */
95
+ color: #2c3e50; /* Darker color */
96
+ }
97
+
98
+ /* Style the description */
99
+ #app-description {
100
+ text-align: center;
101
+ font-size: 1.1em;
102
+ margin-bottom: 25px; /* More space below description */
103
+ color: #576574; /* Subdued color */
104
+ }
105
+ #app-description code { /* Style model name */
106
+ font-weight: bold;
107
+ background-color: #f1f2f6;
108
+ padding: 2px 5px;
109
+ border-radius: 4px;
110
+ }
111
+ #app-description strong { /* Style device name */
112
+ color: #1abc9c; /* Highlight color for device */
113
+ }
114
+
115
+ /* Style the results area */
116
+ #prediction-label .label-name { font-weight: bold; font-size: 1.1em; }
117
+ #prediction-label .confidence { font-size: 1em; }
118
+
119
+ /* Style the examples section */
120
+ .gradio-container .examples-container { padding-top: 15px; }
121
+ .gradio-container .examples-header { font-size: 1.1em; font-weight: bold; margin-bottom: 10px; color: #34495e; }
122
+
123
+ /* Add a subtle border/shadow to input/output columns for definition */
124
+ #input-column, #output-column {
125
+ border: 1px solid #e0e0e0;
126
+ border-radius: 12px; /* More rounded corners */
127
+ padding: 20px;
128
+ box-shadow: 0 2px 8px rgba(0, 0, 0, 0.05); /* Subtle shadow */
129
+ background-color: #ffffff; /* Ensure white background */
130
+ }
131
+
132
+ /* Footer styling */
133
+ #app-footer {
134
+ margin-top: 40px;
135
+ padding-top: 20px;
136
+ border-top: 1px solid #dfe6e9;
137
+ text-align: center;
138
+ font-size: 0.9em;
139
+ color: #8395a7;
140
+ }
141
+ #app-footer a { color: #3498db; text-decoration: none; }
142
+ #app-footer a:hover { text-decoration: underline; }
143
+ """
144
+
145
+ # --- Gradio Interface using Blocks and Theme ---
146
+ # Choose a theme: gr.themes.Soft(), gr.themes.Monochrome(), gr.themes.Glass()
147
+ # Or customize the default: gr.themes.Default().set(radius_size="sm", spacing_size="sm")
148
+ theme = gr.themes.Soft(
149
+ primary_hue="emerald", # Color scheme based on emerald green
150
+ secondary_hue="blue",
151
+ neutral_hue="slate",
152
+ radius_size=gr.themes.sizes.radius_lg, # Larger corner radius
153
+ spacing_size=gr.themes.sizes.spacing_lg, # More spacing
154
+ ).set(
155
+ # Further fine-tuning
156
+ body_background_fill="#f1f2f6", # Light grey background
157
+ block_radius="12px",
158
+ )
159
 
160
 
161
+ with gr.Blocks(theme=theme, css=css) as iface:
162
+ # Title and Description using Markdown for better formatting
163
+ gr.Markdown("# AI vs Human Image Detector", elem_id="app-title")
164
+ gr.Markdown(
 
 
 
165
  f"Upload an image to classify if it was likely generated by AI or created by a human. "
166
+ f"Uses the `{MODEL_IDENTIFIER}` model. Running on **{str(DEVICE).upper()}**.",
167
+ elem_id="app-description"
168
+ )
169
+
170
+ # Main layout with Input and Output side-by-side
171
+ with gr.Row(variant='panel'): # 'panel' adds a light border/background
172
+ with gr.Column(scale=1, min_width=300, elem_id="input-column"):
173
+ image_input = gr.Image(
174
+ type="pil",
175
+ label="🖼️ Upload Your Image",
176
+ sources=["upload", "webcam", "clipboard"],
177
+ height=400, # Adjust height as needed
178
+ )
179
+ submit_button = gr.Button("🔍 Classify Image", variant="primary") # Make button prominent
180
+
181
+ with gr.Column(scale=1, min_width=300, elem_id="output-column"):
182
+ gr.Markdown("📊 **Prediction Results**", style={"text-align": "center"}) # Centered heading for results
183
+ result_output = gr.Label(
184
+ num_top_classes=2,
185
+ label="Classification",
186
+ elem_id="prediction-label"
187
+ )
188
+
189
+ # Examples Section
190
+ if example_images: # Only show examples if they exist
191
+ gr.Examples(
192
+ examples=example_images,
193
+ inputs=image_input,
194
+ outputs=result_output,
195
+ fn=classify_image,
196
+ cache_examples=True, # Caching is good for static examples
197
+ label="✨ Click an Example to Try!"
198
+ )
199
+
200
+ # Footer / Article section
201
+ gr.Markdown(
202
+ """
203
+ ---
204
+ **How it Works:**
205
+ This application uses a fine-tuned [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip) vision model
206
+ specifically trained to differentiate between images generated by Artificial Intelligence and those created by humans.
207
+
208
+ Fine tuning code available at [https://exnrt.com/blog/ai/fine-tuning-siglip2/](https://exnrt.com/blog/ai/fine-tuning-siglip2/).
209
+
210
+ **Model:**
211
+ * You can find the model card here: <a href='https://huggingface.co/{model_id}' target='_blank'>{model_id}</a>
212
+
213
+ """.format(model_id=MODEL_IDENTIFIER),
214
+ elem_id="app-footer"
215
+ )
216
+
217
+ # Connect the button click or image change to the prediction function
218
+ submit_button.click(fn=classify_image, inputs=image_input, outputs=result_output, api_name="classify_button")
219
+ # Allow prediction on image change/upload as well (optional, can be convenient)
220
+ image_input.change(fn=classify_image, inputs=image_input, outputs=result_output, api_name="classify_change")
221
+
222
 
223
  # --- Launch the App ---
224
  if __name__ == "__main__":
225
  print("Launching Gradio interface...")
226
+ iface.launch() # share=True to create a public link (useful for testing)
227
  print("Gradio interface launched.")