ehtyalee commited on
Commit
751f405
·
verified ·
1 Parent(s): ffc2685

Upload 6 files

Browse files
app.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #@title 3. Load Model from HF Directory and Launch Gradio Interface
2
+
3
+ # --- Imports ---
4
+ import torch
5
+ import gradio as gr
6
+ from PIL import Image
7
+ import os
8
+ import torch.nn.functional as F
9
+ from transformers import AutoFeatureExtractor, ViTForImageClassification # Still need ViT class
10
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
11
+ from torch import device, cuda
12
+ import numpy as np
13
+
14
+ # --- Configuration ---
15
+ # REMOVED: pkl_file_path = '/content/finetune_vit_model.pkl'
16
+ # ADDED: Path to the directory created by unzipping
17
+ hf_model_directory = './best-model-hf'
18
+ # IMPORTANT: This MUST match the base model used for fine-tuning to load the correct feature extractor
19
+ model_checkpoint = "google/vit-base-patch16-224"
20
+ device_to_use = device('cuda' if cuda.is_available() else 'cpu')
21
+ print(f"Using device: {device_to_use}")
22
+
23
+ # --- Global variables for loaded components ---
24
+ inference_model = None
25
+ inference_feature_extractor = None
26
+ inference_transforms = None
27
+ inference_id2label = None
28
+ num_labels = 0
29
+
30
+ # --- Load Feature Extractor (Needed for Preprocessing) ---
31
+ # (This part remains the same as it's needed for transforms regardless of model loading method)
32
+ try:
33
+ print(f"Loading feature extractor for: {model_checkpoint}")
34
+ inference_feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
35
+ print("Feature extractor loaded successfully.")
36
+
37
+ # --- Define Image Transforms (Must match inference transforms from training) ---
38
+ normalize = Normalize(mean=inference_feature_extractor.image_mean, std=inference_feature_extractor.image_std)
39
+ if isinstance(inference_feature_extractor.size, dict):
40
+ image_size = inference_feature_extractor.size.get('shortest_edge', inference_feature_extractor.size.get('height', 224))
41
+ else:
42
+ image_size = inference_feature_extractor.size
43
+ print(f"Using image size: {image_size}")
44
+
45
+ inference_transforms = Compose([
46
+ Resize(image_size),
47
+ CenterCrop(image_size),
48
+ ToTensor(),
49
+ normalize,
50
+ ])
51
+ print("Inference transforms defined.")
52
+
53
+ except Exception as e:
54
+ print(f"Error loading feature extractor or defining transforms: {e}")
55
+ print("Cannot proceed without feature extractor and transforms.")
56
+ raise SystemExit("Feature extractor/transforms loading failed.")
57
+
58
+
59
+ # --- Load the Fine-Tuned Model from Hugging Face Save Directory --- ## MODIFIED BLOCK ##
60
+ if not os.path.isdir(hf_model_directory):
61
+ print(f"ERROR: Hugging Face model directory not found at '{hf_model_directory}'.")
62
+ print("Please ensure you uploaded the zip file and ran the 'Unzip' cell successfully.")
63
+ inference_model = None # Ensure model is None if dir not found
64
+ else:
65
+ print(f"Attempting to load model from directory: {hf_model_directory}")
66
+ try:
67
+ # Load the model using from_pretrained with the directory path
68
+ inference_model = ViTForImageClassification.from_pretrained(hf_model_directory)
69
+
70
+ # --- Post-Load Setup ---
71
+ inference_model.to(device_to_use)
72
+ inference_model.eval() # Set model to evaluation mode
73
+ print("Model loaded successfully from directory and moved to device.")
74
+
75
+ # Try to get label mapping from the loaded model's config (this usually works well with from_pretrained)
76
+ if hasattr(inference_model, 'config') and hasattr(inference_model.config, 'id2label'):
77
+ inference_id2label = inference_model.config.id2label
78
+ # Ensure keys are integers if loaded from JSON/dict
79
+ inference_id2label = {int(k): v for k, v in inference_id2label.items()}
80
+ num_labels = len(inference_id2label)
81
+ print(f"Loaded id2label mapping from model config: {inference_id2label}")
82
+ print(f"Number of labels: {num_labels}")
83
+ else:
84
+ # Fallback if id2label isn't in the config for some reason
85
+ print("WARNING: Could not find 'id2label' in the loaded model's config.")
86
+ # --- !! MANUALLY DEFINE LABELS HERE IF NEEDED !! ---
87
+ # Example: Replace with your actual labels and order
88
+ inference_id2label = {0: 'fake', 1: 'real'} # Make sure this matches your training
89
+ num_labels = len(inference_id2label)
90
+ print(f"Using manually defined id2label: {inference_id2label}")
91
+ # -----------------------------------------------------
92
+
93
+ if num_labels == 0:
94
+ print("ERROR: Number of labels is zero. Cannot proceed.")
95
+ inference_model = None # Prevent Gradio launch
96
+
97
+ except Exception as e:
98
+ # Catch errors during from_pretrained (e.g., missing files, config errors)
99
+ print(f"An unexpected error occurred loading the model from directory: {e}")
100
+ inference_model = None # Ensure model is None on error
101
+ ## --- END OF MODIFIED BLOCK --- ##
102
+
103
+
104
+ # --- Define the Prediction Function for Gradio ---
105
+ # (This function remains the same)
106
+ def predict(image: Image.Image):
107
+ """
108
+ Takes a PIL image, preprocesses it, and returns label probabilities.
109
+ """
110
+ # Ensure model and necessary components are loaded
111
+ if inference_model is None:
112
+ return {"Error": "Model not loaded. Please check loading logs."}
113
+ if inference_transforms is None:
114
+ return {"Error": "Inference transforms not defined."}
115
+ if inference_id2label is None:
116
+ return {"Error": "Label mapping (id2label) not available."}
117
+ if image is None:
118
+ return None # Gradio handles None input gracefully sometimes
119
+
120
+ try:
121
+ # Preprocess the image
122
+ image = image.convert("RGB") # Ensure 3 channels
123
+ pixel_values = inference_transforms(image).unsqueeze(0).to(device_to_use)
124
+
125
+ # Perform inference
126
+ with torch.no_grad():
127
+ outputs = inference_model(pixel_values=pixel_values)
128
+ logits = outputs.logits
129
+
130
+ # Get probabilities and format output
131
+ probabilities = F.softmax(logits, dim=-1)[0] # Get probabilities for the first (only) image
132
+ confidences = {inference_id2label[i]: float(prob) for i, prob in enumerate(probabilities)}
133
+ return confidences
134
+
135
+ except Exception as e:
136
+ print(f"Error during prediction: {e}")
137
+ # Return error in a format Gradio Label can display
138
+ return {"Error": f"Prediction failed: {str(e)}"}
139
+
140
+
141
+ # --- Create and Launch the Gradio Interface ---
142
+ # (This part remains the same, but title/description updated slightly)
143
+ if inference_model and inference_id2label and num_labels > 0:
144
+ print("\nSetting up Gradio Interface...")
145
+ try:
146
+ iface = gr.Interface(
147
+ fn=predict,
148
+ inputs=gr.Image(type="pil", label="Upload Face Image"),
149
+ outputs=gr.Label(num_top_classes=num_labels, label="Prediction (Real/Fake)"),
150
+ # Updated Title/Description
151
+ title="Real vs. Fake Face Detector (Loaded from HF Directory)",
152
+ description="Upload an image of a face to classify it as real or fake using the fine-tuned ViT model loaded from the 'best-model-hf' directory.",
153
+ # examples=[...] # Optional: Add example image paths if you upload some
154
+ )
155
+
156
+ print("Launching Gradio interface...")
157
+ print("Access the interface through the public URL generated below (if sharing is enabled) or the local URL.")
158
+ iface.launch(share=True, debug=True)
159
+
160
+ except Exception as e:
161
+ print(f"Error creating or launching Gradio interface: {e}")
162
+
163
+ else:
164
+ print("\nCould not launch Gradio interface because the model or label mapping failed to load.")
165
+ print("Please check the error messages above.")
166
+
167
+ # Keep the cell running to keep the Gradio interface active
168
+ print("\nGradio setup finished. Interface should be running or an error reported above.")
169
+ print("Stop this cell execution in Colab to shut down the Gradio server.")
best-model-hf/config.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "ViTForImageClassification"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.0,
6
+ "encoder_stride": 16,
7
+ "hidden_act": "gelu",
8
+ "hidden_dropout_prob": 0.0,
9
+ "hidden_size": 768,
10
+ "id2label": {
11
+ "0": "training_fake",
12
+ "1": "training_real"
13
+ },
14
+ "image_size": 224,
15
+ "initializer_range": 0.02,
16
+ "intermediate_size": 3072,
17
+ "label2id": {
18
+ "training_fake": 0,
19
+ "training_real": 1
20
+ },
21
+ "layer_norm_eps": 1e-12,
22
+ "model_type": "vit",
23
+ "num_attention_heads": 12,
24
+ "num_channels": 3,
25
+ "num_hidden_layers": 12,
26
+ "patch_size": 16,
27
+ "pooler_act": "tanh",
28
+ "pooler_output_size": 768,
29
+ "problem_type": "single_label_classification",
30
+ "qkv_bias": true,
31
+ "torch_dtype": "float32",
32
+ "transformers_version": "4.51.1"
33
+ }
best-model-hf/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:30f9d1285ccca5943c7a5d7e077a84aac64bff767b531606722cec680b67d8b0
3
+ size 343223968
best-model-hf/preprocessor_config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "do_convert_rgb": null,
3
+ "do_normalize": true,
4
+ "do_rescale": true,
5
+ "do_resize": true,
6
+ "image_mean": [
7
+ 0.5,
8
+ 0.5,
9
+ 0.5
10
+ ],
11
+ "image_processor_type": "ViTFeatureExtractor",
12
+ "image_std": [
13
+ 0.5,
14
+ 0.5,
15
+ 0.5
16
+ ],
17
+ "resample": 2,
18
+ "rescale_factor": 0.00392156862745098,
19
+ "size": {
20
+ "height": 224,
21
+ "width": 224
22
+ }
23
+ }
best-model-hf/training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:701595fdeadb1794aea348ac18c140fc7acf99a103c1f1030642b7f487ee4a71
3
+ size 5368
requirements.txt ADDED
File without changes