Upload 6 files
- app.py +169 -0
- best-model-hf/config.json +33 -0
- best-model-hf/model.safetensors +3 -0
- best-model-hf/preprocessor_config.json +23 -0
- best-model-hf/training_args.bin +3 -0
- requirements.txt +0 -0
app.py
ADDED
@@ -0,0 +1,169 @@
#@title 3. Load Model from HF Directory and Launch Gradio Interface

# --- Imports ---
import torch
import gradio as gr
from PIL import Image
import os
import torch.nn.functional as F
from transformers import AutoFeatureExtractor, ViTForImageClassification  # Still need ViT class
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from torch import device, cuda
import numpy as np

# --- Configuration ---
# REMOVED: pkl_file_path = '/content/finetune_vit_model.pkl'
# ADDED: Path to the directory created by unzipping
hf_model_directory = './best-model-hf'
# IMPORTANT: This MUST match the base model used for fine-tuning to load the correct feature extractor
model_checkpoint = "google/vit-base-patch16-224"
device_to_use = device('cuda' if cuda.is_available() else 'cpu')
print(f"Using device: {device_to_use}")

# --- Global variables for loaded components ---
inference_model = None
inference_feature_extractor = None
inference_transforms = None
inference_id2label = None
num_labels = 0

# --- Load Feature Extractor (Needed for Preprocessing) ---
# (This part remains the same, as it's needed for the transforms regardless of model loading method)
try:
    print(f"Loading feature extractor for: {model_checkpoint}")
    inference_feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
    print("Feature extractor loaded successfully.")

    # --- Define Image Transforms (Must match inference transforms from training) ---
    normalize = Normalize(mean=inference_feature_extractor.image_mean, std=inference_feature_extractor.image_std)
    if isinstance(inference_feature_extractor.size, dict):
        image_size = inference_feature_extractor.size.get('shortest_edge', inference_feature_extractor.size.get('height', 224))
    else:
        image_size = inference_feature_extractor.size
    print(f"Using image size: {image_size}")

    inference_transforms = Compose([
        Resize(image_size),
        CenterCrop(image_size),
        ToTensor(),
        normalize,
    ])
    print("Inference transforms defined.")

except Exception as e:
    print(f"Error loading feature extractor or defining transforms: {e}")
    print("Cannot proceed without feature extractor and transforms.")
    raise SystemExit("Feature extractor/transforms loading failed.")


# --- Load the Fine-Tuned Model from Hugging Face Save Directory --- ## MODIFIED BLOCK ##
if not os.path.isdir(hf_model_directory):
    print(f"ERROR: Hugging Face model directory not found at '{hf_model_directory}'.")
    print("Please ensure you uploaded the zip file and ran the 'Unzip' cell successfully.")
    inference_model = None  # Ensure model is None if dir not found
else:
    print(f"Attempting to load model from directory: {hf_model_directory}")
    try:
        # Load the model using from_pretrained with the directory path
        inference_model = ViTForImageClassification.from_pretrained(hf_model_directory)

        # --- Post-Load Setup ---
        inference_model.to(device_to_use)
        inference_model.eval()  # Set model to evaluation mode
        print("Model loaded successfully from directory and moved to device.")

        # Try to get label mapping from the loaded model's config (this usually works well with from_pretrained)
        if hasattr(inference_model, 'config') and hasattr(inference_model.config, 'id2label'):
            inference_id2label = inference_model.config.id2label
            # Ensure keys are integers if loaded from JSON/dict
            inference_id2label = {int(k): v for k, v in inference_id2label.items()}
            num_labels = len(inference_id2label)
            print(f"Loaded id2label mapping from model config: {inference_id2label}")
            print(f"Number of labels: {num_labels}")
        else:
            # Fallback if id2label isn't in the config for some reason
            print("WARNING: Could not find 'id2label' in the loaded model's config.")
            # --- !! MANUALLY DEFINE LABELS HERE IF NEEDED !! ---
            # Example: Replace with your actual labels and order
            inference_id2label = {0: 'fake', 1: 'real'}  # Make sure this matches your training
            num_labels = len(inference_id2label)
            print(f"Using manually defined id2label: {inference_id2label}")
            # -----------------------------------------------------

        if num_labels == 0:
            print("ERROR: Number of labels is zero. Cannot proceed.")
            inference_model = None  # Prevent Gradio launch

    except Exception as e:
        # Catch errors during from_pretrained (e.g., missing files, config errors)
        print(f"An unexpected error occurred loading the model from directory: {e}")
        inference_model = None  # Ensure model is None on error
## --- END OF MODIFIED BLOCK --- ##


# --- Define the Prediction Function for Gradio ---
# (This function remains the same)
def predict(image: Image.Image):
    """
    Takes a PIL image, preprocesses it, and returns label probabilities.
    """
    # Ensure model and necessary components are loaded
    if inference_model is None:
        return {"Error": "Model not loaded. Please check loading logs."}
    if inference_transforms is None:
        return {"Error": "Inference transforms not defined."}
    if inference_id2label is None:
        return {"Error": "Label mapping (id2label) not available."}
    if image is None:
        return None  # Gradio handles None input gracefully sometimes

    try:
        # Preprocess the image
        image = image.convert("RGB")  # Ensure 3 channels
        pixel_values = inference_transforms(image).unsqueeze(0).to(device_to_use)

        # Perform inference
        with torch.no_grad():
            outputs = inference_model(pixel_values=pixel_values)
            logits = outputs.logits

        # Get probabilities and format output
        probabilities = F.softmax(logits, dim=-1)[0]  # Probabilities for the first (only) image
        confidences = {inference_id2label[i]: float(prob) for i, prob in enumerate(probabilities)}
        return confidences

    except Exception as e:
        print(f"Error during prediction: {e}")
        # Return the error in a format the Gradio Label can display
        return {"Error": f"Prediction failed: {str(e)}"}


# --- Create and Launch the Gradio Interface ---
# (This part remains the same, but title/description updated slightly)
if inference_model and inference_id2label and num_labels > 0:
    print("\nSetting up Gradio Interface...")
    try:
        iface = gr.Interface(
            fn=predict,
            inputs=gr.Image(type="pil", label="Upload Face Image"),
            outputs=gr.Label(num_top_classes=num_labels, label="Prediction (Real/Fake)"),
            # Updated Title/Description
            title="Real vs. Fake Face Detector (Loaded from HF Directory)",
            description="Upload an image of a face to classify it as real or fake using the fine-tuned ViT model loaded from the 'best-model-hf' directory.",
            # examples=[...]  # Optional: Add example image paths if you upload some
        )

        print("Launching Gradio interface...")
        print("Access the interface through the public URL generated below (if sharing is enabled) or the local URL.")
        iface.launch(share=True, debug=True)

    except Exception as e:
        print(f"Error creating or launching Gradio interface: {e}")

else:
    print("\nCould not launch Gradio interface because the model or label mapping failed to load.")
    print("Please check the error messages above.")

# Keep the cell running to keep the Gradio interface active
print("\nGradio setup finished. Interface should be running or an error reported above.")
print("Stop this cell execution in Colab to shut down the Gradio server.")
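
Note: app.py rebuilds its preprocessing from the base checkpoint via AutoFeatureExtractor (deprecated in recent transformers releases in favor of AutoImageProcessor), even though this commit also ships best-model-hf/preprocessor_config.json. A minimal sketch of loading both the model and the preprocessing from the uploaded directory instead — "face.jpg" is a placeholder path, not a file in this commit:

# Sketch only: load model and preprocessing from the local directory.
import torch
from PIL import Image
from transformers import AutoImageProcessor, ViTForImageClassification

processor = AutoImageProcessor.from_pretrained("./best-model-hf")
model = ViTForImageClassification.from_pretrained("./best-model-hf").eval()

image = Image.open("face.jpg").convert("RGB")  # placeholder input
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
print(model.config.id2label[int(logits.argmax(-1))])
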
best-model-hf/config.json
ADDED
@@ -0,0 +1,33 @@
{
  "architectures": [
    "ViTForImageClassification"
  ],
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "id2label": {
    "0": "training_fake",
    "1": "training_real"
  },
  "image_size": 224,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "training_fake": 0,
    "training_real": 1
  },
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "pooler_act": "tanh",
  "pooler_output_size": 768,
  "problem_type": "single_label_classification",
  "qkv_bias": true,
  "torch_dtype": "float32",
  "transformers_version": "4.51.1"
}
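
This id2label / label2id pair is what app.py reads at startup. Note the labels are "training_fake" / "training_real" (apparently the dataset folder names), not the "fake" / "real" fallback hardcoded in app.py, so the fallback branch would mislabel outputs if it ever ran. JSON forces the id2label keys to be strings; recent transformers versions normalize them back to ints when the config is loaded, and app.py repeats that conversion defensively. A quick way to inspect the mapping (a sketch, not part of the app):

from transformers import ViTConfig

cfg = ViTConfig.from_pretrained("./best-model-hf")
print(cfg.id2label)  # expected: {0: 'training_fake', 1: 'training_real'}
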
best-model-hf/model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:30f9d1285ccca5943c7a5d7e077a84aac64bff767b531606722cec680b67d8b0
size 343223968
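
This is a Git LFS pointer file, not the weights themselves; the real ~343 MB file lives in LFS storage and is fetched on clone or download. The size is consistent with the float32 ViT-Base model described in config.json — a rough check, assuming 4 bytes per parameter:

# 343_223_968 bytes / 4 bytes per float32 parameter:
print(343_223_968 / 4)  # 85805992.0, ~85.8M params, in line with
                        # ViT-Base (~86M) with the head swapped for 2 classes
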
best-model-hf/preprocessor_config.json
ADDED
@@ -0,0 +1,23 @@
{
  "do_convert_rgb": null,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "ViTFeatureExtractor",
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 224,
    "width": 224
  }
}
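
These values are what the torchvision pipeline in app.py has to reproduce: resample 2 is PIL bilinear (torchvision's Resize default), rescale_factor is 1/255 (what ToTensor applies), and the 0.5 mean/std match the Normalize call. A sketch of the equivalence:

from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize

transforms = Compose([
    Resize(224),                # "size": 224x224, bilinear ("resample": 2)
    CenterCrop(224),
    ToTensor(),                 # rescale by 1/255 ("do_rescale" / "rescale_factor")
    Normalize(mean=[0.5] * 3, std=[0.5] * 3),  # "image_mean" / "image_std"
])

One caveat: Resize(224) with an int scales the shortest edge and CenterCrop then crops, whereas the image processor resizes directly to 224x224; the two agree exactly only for square inputs.
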
best-model-hf/training_args.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:701595fdeadb1794aea348ac18c140fc7acf99a103c1f1030642b7f487ee4a71
size 5368
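
training_args.bin is the pickled TrainingArguments object saved by the transformers Trainer; it is not needed for inference and app.py never reads it. To inspect it (a sketch; only unpickle files you trust):

import torch

# weights_only=False is required because this is a pickled Python object
args = torch.load("best-model-hf/training_args.bin", weights_only=False)
print(args.learning_rate, args.num_train_epochs)
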
requirements.txt
ADDED
File without changes
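
Note: requirements.txt was uploaded empty (+0 -0). For the Space to build an environment that can run app.py, the imported packages will likely need to be listed there. A plausible minimal set, inferred from app.py's imports (an assumption — the Gradio SDK already provides gradio, and versions should be pinned as appropriate):

torch
torchvision
transformers
numpy
Pillow
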