ehtyalee committed
Commit 9abf394 · verified · 1 Parent(s): 69fd3e6

Upload main.py

Files changed (1)
  1. main.py +174 -0
main.py ADDED
@@ -0,0 +1,174 @@
+ #@title 3. Load Model from HF Directory and Launch Gradio Interface
+
+ # --- Imports ---
+ import torch
+ import gradio as gr
+ from PIL import Image
+ import os
+ import torch.nn.functional as F
+ from transformers import AutoFeatureExtractor, ViTForImageClassification
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
+ from torch import device, cuda
+ import numpy as np
+
+ # --- Configuration ---
+ hf_model_directory = 'best-model-hf'  # Corrected path (no leading dot)
+ model_checkpoint = "google/vit-base-patch16-224"
+ device_to_use = device('cuda' if cuda.is_available() else 'cpu')
+ print(f"Using device: {device_to_use}")
+
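+ # NOTE (assumption): 'best-model-hf' is expected to be a local directory produced by
+ # model.save_pretrained() / Trainer.save_model(), i.e. containing config.json and the
+ # model weights (pytorch_model.bin or model.safetensors) that from_pretrained() loads below.
+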
+ # --- Predictor Class ---
+ class ImagePredictor:
+     def __init__(self, model_dir, base_checkpoint, device):
+         self.model_dir = model_dir
+         self.base_checkpoint = base_checkpoint
+         self.device = device
+         self.model = None
+         self.feature_extractor = None
+         self.transforms = None
+         self.id2label = None
+         self.num_labels = 0
+         self._load_resources()  # Load everything during initialization
+
+     def _load_resources(self):
+         print("--- Loading Predictor Resources ---")
+         # --- Load Feature Extractor (Needed for Preprocessing) ---
+         try:
+             print(f"Loading feature extractor for: {self.base_checkpoint}")
+             self.feature_extractor = AutoFeatureExtractor.from_pretrained(self.base_checkpoint)
+             print("Feature extractor loaded.")
+
+             # --- Define Image Transforms ---
+             normalize = Normalize(mean=self.feature_extractor.image_mean, std=self.feature_extractor.image_std)
+             if isinstance(self.feature_extractor.size, dict):
+                 image_size = self.feature_extractor.size.get('shortest_edge', self.feature_extractor.size.get('height', 224))
+             else:
+                 image_size = self.feature_extractor.size
+             print(f"Using image size: {image_size}")
+
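+             # The transforms below resize the shorter side to image_size, center-crop to a
+             # square, convert to a [0, 1] tensor, and normalize with the extractor's mean/std.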
+             self.transforms = Compose([
+                 Resize(image_size),
+                 CenterCrop(image_size),
+                 ToTensor(),
+                 normalize,
+             ])
+             print("Inference transforms defined.")
+
+         except Exception as e:
+             print(f"FATAL: Error loading feature extractor or defining transforms: {e}")
+             # Re-raise to prevent using a partially initialized object
+             raise RuntimeError("Feature extractor/transforms loading failed.") from e
+
+         # --- Load the Fine-Tuned Model ---
+         if not os.path.isdir(self.model_dir):
+             print(f"FATAL: Model directory not found at '{self.model_dir}'.")
+             raise FileNotFoundError(f"Model directory not found: {self.model_dir}")
+
+         print(f"Attempting to load model from directory: {self.model_dir}")
+         try:
+             self.model = ViTForImageClassification.from_pretrained(self.model_dir)
+             self.model.to(self.device)
+             self.model.eval()  # Set model to evaluation mode
+             print("Model loaded successfully from directory and moved to device.")
+
+             # --- Load Label Mapping ---
+             if hasattr(self.model, 'config') and hasattr(self.model.config, 'id2label'):
+                 self.id2label = {int(k): v for k, v in self.model.config.id2label.items()}
+                 self.num_labels = len(self.id2label)
+                 print(f"Loaded id2label mapping from model config: {self.id2label}")
+                 print(f"Number of labels: {self.num_labels}")
+             else:
+                 print("WARNING: Could not find 'id2label' in the loaded model's config.")
+                 # --- !! MANUALLY DEFINE FALLBACK IF NEEDED !! ---
+                 self.id2label = {0: 'fake', 1: 'real'}  # ENSURE THIS MATCHES TRAINING
+                 self.num_labels = len(self.id2label)
+                 print(f"Using manually defined id2label: {self.id2label}")
+                 # ----------------------------------------------
+
+             if self.num_labels == 0:
+                 raise ValueError("Number of labels is zero after loading.")
+
+             print("--- Predictor Resources Loaded Successfully ---")
+
+         except Exception as e:
+             print(f"FATAL: An unexpected error occurred loading the model: {e}")
+             # Reset model attribute to indicate failure clearly
+             self.model = None
+             # Re-raise to prevent using a partially initialized object
+             raise RuntimeError("Model loading failed.") from e
+
+     # --- Prediction Method ---
+     def predict(self, image: Image.Image):
+         """
+         Takes a PIL image, preprocesses it, and returns label probabilities.
+         Uses the loaded instance attributes (self.model, self.transforms, etc.)
+         """
+         # Check if initialization succeeded (should be caught by __init__ exceptions, but good practice)
+         if self.model is None or self.transforms is None or self.id2label is None:
+             raise gr.Error("Predictor not initialized correctly. Check logs.")
+         if image is None:
+             return None  # Gradio handles None input
+
+         try:
+             # Preprocess the image
+             image = image.convert("RGB")  # Ensure 3 channels
+             pixel_values = self.transforms(image).unsqueeze(0).to(self.device)
+
+             # Perform inference
+             with torch.no_grad():
+                 outputs = self.model(pixel_values=pixel_values)
+                 logits = outputs.logits
+
+             # Get probabilities and format output
+             probabilities = F.softmax(logits, dim=-1)[0]  # Probabilities for the first (only) image in the batch
+             confidences = {self.id2label[i]: float(prob) for i, prob in enumerate(probabilities)}
+             return confidences
+
+         except Exception as e:
+             print(f"Error during prediction: {e}")
+             raise gr.Error(f"Prediction failed: {str(e)}")
+
+
+ # --- Main Execution Logic ---
+ predictor = None
+ try:
+     # Instantiate the predictor ONCE globally
+     # This loads the model, feature extractor, transforms, etc. immediately
+     predictor = ImagePredictor(
+         model_dir=hf_model_directory,
+         base_checkpoint=model_checkpoint,
+         device=device_to_use
+     )
+ except Exception as e:
+     print(f"Failed to initialize ImagePredictor: {e}")
+     # predictor remains None
+
+
+ # --- Create and Launch the Gradio Interface ---
+ if predictor and predictor.model:  # Check if predictor initialized successfully
+     print("\nSetting up Gradio Interface...")
+     try:
+         iface = gr.Interface(
+             # Pass the INSTANCE METHOD to fn
+             fn=predictor.predict,
+             inputs=gr.Image(type="pil", label="Upload Face Image"),
+             outputs=gr.Label(num_top_classes=predictor.num_labels, label="Prediction (Real/Fake)"),
+             title="Real vs. Fake Face Detector",
+             description=f"Upload an image of a face to classify it using the fine-tuned ViT model loaded from the '{hf_model_directory}' directory.",
+         )
+
+         print("Launching Gradio interface...")
+         # Set share=True as requested
+         iface.launch(share=True, debug=True, show_error=True)
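+         # Note: share=True also exposes the app on a temporary public *.gradio.live link;
+         # debug=True keeps this cell blocking so server logs and errors stream to the output.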
+
+     except Exception as e:
+         print(f"Error creating or launching Gradio interface: {e}")
+
+ else:
+     print("\nCould not launch Gradio interface because the Predictor failed to initialize.")
+     print("Please check the error messages above.")
+
+
+ # Optional: Add message for Colab/persistent running if needed
+ print("\nGradio setup finished. Interface should be running or an error reported above.")
+ # print("Stop this cell execution in Colab to shut down the Gradio server.")