Spaces:
Running
Running
Upload main.py
Browse files
main.py
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#@title 3. Load Model from HF Directory and Launch Gradio Interface
|
2 |
+
|
3 |
+
# --- Imports ---
|
4 |
+
import torch
|
5 |
+
import gradio as gr
|
6 |
+
from PIL import Image
|
7 |
+
import os
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from transformers import AutoFeatureExtractor, ViTForImageClassification
|
10 |
+
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
11 |
+
from torch import device, cuda
|
12 |
+
import numpy as np
|
13 |
+
|
14 |
+
# --- Configuration ---
|
15 |
+
hf_model_directory = 'best-model-hf' # Corrected path (no leading dot)
|
16 |
+
model_checkpoint = "google/vit-base-patch16-224"
|
17 |
+
device_to_use = device('cuda' if cuda.is_available() else 'cpu')
|
18 |
+
print(f"Using device: {device_to_use}")
|
19 |
+
|
20 |
+
# --- Predictor Class ---
|
21 |
+
class ImagePredictor:
|
22 |
+
def __init__(self, model_dir, base_checkpoint, device):
|
23 |
+
self.model_dir = model_dir
|
24 |
+
self.base_checkpoint = base_checkpoint
|
25 |
+
self.device = device
|
26 |
+
self.model = None
|
27 |
+
self.feature_extractor = None
|
28 |
+
self.transforms = None
|
29 |
+
self.id2label = None
|
30 |
+
self.num_labels = 0
|
31 |
+
self._load_resources() # Load everything during initialization
|
32 |
+
|
33 |
+
def _load_resources(self):
|
34 |
+
print("--- Loading Predictor Resources ---")
|
35 |
+
# --- Load Feature Extractor (Needed for Preprocessing) ---
|
36 |
+
try:
|
37 |
+
print(f"Loading feature extractor for: {self.base_checkpoint}")
|
38 |
+
self.feature_extractor = AutoFeatureExtractor.from_pretrained(self.base_checkpoint)
|
39 |
+
print("Feature extractor loaded.")
|
40 |
+
|
41 |
+
# --- Define Image Transforms ---
|
42 |
+
normalize = Normalize(mean=self.feature_extractor.image_mean, std=self.feature_extractor.image_std)
|
43 |
+
if isinstance(self.feature_extractor.size, dict):
|
44 |
+
image_size = self.feature_extractor.size.get('shortest_edge', self.feature_extractor.size.get('height', 224))
|
45 |
+
else:
|
46 |
+
image_size = self.feature_extractor.size
|
47 |
+
print(f"Using image size: {image_size}")
|
48 |
+
|
49 |
+
self.transforms = Compose([
|
50 |
+
Resize(image_size),
|
51 |
+
CenterCrop(image_size),
|
52 |
+
ToTensor(),
|
53 |
+
normalize,
|
54 |
+
])
|
55 |
+
print("Inference transforms defined.")
|
56 |
+
|
57 |
+
except Exception as e:
|
58 |
+
print(f"FATAL: Error loading feature extractor or defining transforms: {e}")
|
59 |
+
# Re-raise to prevent using a partially initialized object
|
60 |
+
raise RuntimeError("Feature extractor/transforms loading failed.") from e
|
61 |
+
|
62 |
+
# --- Load the Fine-Tuned Model ---
|
63 |
+
if not os.path.isdir(self.model_dir):
|
64 |
+
print(f"FATAL: Model directory not found at '{self.model_dir}'.")
|
65 |
+
raise FileNotFoundError(f"Model directory not found: {self.model_dir}")
|
66 |
+
|
67 |
+
print(f"Attempting to load model from directory: {self.model_dir}")
|
68 |
+
try:
|
69 |
+
self.model = ViTForImageClassification.from_pretrained(self.model_dir)
|
70 |
+
self.model.to(self.device)
|
71 |
+
self.model.eval() # Set model to evaluation mode
|
72 |
+
print("Model loaded successfully from directory and moved to device.")
|
73 |
+
|
74 |
+
# --- Load Label Mapping ---
|
75 |
+
if hasattr(self.model, 'config') and hasattr(self.model.config, 'id2label'):
|
76 |
+
self.id2label = {int(k): v for k, v in self.model.config.id2label.items()}
|
77 |
+
self.num_labels = len(self.id2label)
|
78 |
+
print(f"Loaded id2label mapping from model config: {self.id2label}")
|
79 |
+
print(f"Number of labels: {self.num_labels}")
|
80 |
+
else:
|
81 |
+
print("WARNING: Could not find 'id2label' in the loaded model's config.")
|
82 |
+
# --- !! MANUALLY DEFINE FALLBACK IF NEEDED !! ---
|
83 |
+
self.id2label = {0: 'fake', 1: 'real'} # ENSURE THIS MATCHES TRAINING
|
84 |
+
self.num_labels = len(self.id2label)
|
85 |
+
print(f"Using manually defined id2label: {self.id2label}")
|
86 |
+
# ----------------------------------------------
|
87 |
+
|
88 |
+
if self.num_labels == 0:
|
89 |
+
raise ValueError("Number of labels is zero after loading.")
|
90 |
+
|
91 |
+
print("--- Predictor Resources Loaded Successfully ---")
|
92 |
+
|
93 |
+
except Exception as e:
|
94 |
+
print(f"FATAL: An unexpected error occurred loading the model: {e}")
|
95 |
+
# Reset model attribute to indicate failure clearly
|
96 |
+
self.model = None
|
97 |
+
# Re-raise to prevent using a partially initialized object
|
98 |
+
raise RuntimeError("Model loading failed.") from e
|
99 |
+
|
100 |
+
# --- Prediction Method ---
|
101 |
+
def predict(self, image: Image.Image):
|
102 |
+
"""
|
103 |
+
Takes a PIL image, preprocesses it, and returns label probabilities.
|
104 |
+
Uses the loaded instance attributes (self.model, self.transforms, etc.)
|
105 |
+
"""
|
106 |
+
# Check if initialization succeeded (should be caught by __init__ exceptions, but good practice)
|
107 |
+
if self.model is None or self.transforms is None or self.id2label is None:
|
108 |
+
return {"Error": "Predictor not initialized correctly. Check logs."}
|
109 |
+
if image is None:
|
110 |
+
return None # Gradio handles None input
|
111 |
+
|
112 |
+
try:
|
113 |
+
# Preprocess the image
|
114 |
+
image = image.convert("RGB") # Ensure 3 channels
|
115 |
+
pixel_values = self.transforms(image).unsqueeze(0).to(self.device)
|
116 |
+
|
117 |
+
# Perform inference
|
118 |
+
with torch.no_grad():
|
119 |
+
outputs = self.model(pixel_values=pixel_values)
|
120 |
+
logits = outputs.logits
|
121 |
+
|
122 |
+
# Get probabilities and format output
|
123 |
+
probabilities = F.softmax(logits, dim=-1)[0] # Get probabilities for the first image
|
124 |
+
confidences = {self.id2label[i]: float(prob) for i, prob in enumerate(probabilities)}
|
125 |
+
return confidences
|
126 |
+
|
127 |
+
except Exception as e:
|
128 |
+
print(f"Error during prediction: {e}")
|
129 |
+
return {"Error": f"Prediction failed: {str(e)}"}
|
130 |
+
|
131 |
+
|
132 |
+
# --- Main Execution Logic ---
|
133 |
+
predictor = None
|
134 |
+
try:
|
135 |
+
# Instantiate the predictor ONCE globally
|
136 |
+
# This loads the model, tokenizer, transforms, etc. immediately
|
137 |
+
predictor = ImagePredictor(
|
138 |
+
model_dir=hf_model_directory,
|
139 |
+
base_checkpoint=model_checkpoint,
|
140 |
+
device=device_to_use
|
141 |
+
)
|
142 |
+
except Exception as e:
|
143 |
+
print(f"Failed to initialize ImagePredictor: {e}")
|
144 |
+
# predictor remains None
|
145 |
+
|
146 |
+
|
147 |
+
# --- Create and Launch the Gradio Interface ---
|
148 |
+
if predictor and predictor.model: # Check if predictor initialized successfully
|
149 |
+
print("\nSetting up Gradio Interface...")
|
150 |
+
try:
|
151 |
+
iface = gr.Interface(
|
152 |
+
# Pass the INSTANCE METHOD to fn
|
153 |
+
fn=predictor.predict,
|
154 |
+
inputs=gr.Image(type="pil", label="Upload Face Image"),
|
155 |
+
outputs=gr.Label(num_top_classes=predictor.num_labels, label="Prediction (Real/Fake)"),
|
156 |
+
title="Real vs. Fake Face Detector",
|
157 |
+
description=f"Upload an image of a face to classify it using the fine-tuned ViT model loaded from the '{hf_model_directory}' directory.",
|
158 |
+
)
|
159 |
+
|
160 |
+
print("Launching Gradio interface...")
|
161 |
+
# Set share=True as requested
|
162 |
+
iface.launch(share=True, debug=True, show_error=True)
|
163 |
+
|
164 |
+
except Exception as e:
|
165 |
+
print(f"Error creating or launching Gradio interface: {e}")
|
166 |
+
|
167 |
+
else:
|
168 |
+
print("\nCould not launch Gradio interface because the Predictor failed to initialize.")
|
169 |
+
print("Please check the error messages above.")
|
170 |
+
|
171 |
+
|
172 |
+
# Optional: Add message for Colab/persistent running if needed
|
173 |
+
print("\nGradio setup finished. Interface should be running or an error reported above.")
|
174 |
+
# print("Stop this cell execution in Colab to shut down the Gradio server.")
|