ruminasval committed on
Commit
9374fca
·
verified ·
1 Parent(s): 54d75f2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +124 -0
app.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+
3
+ import torch
4
+ from transformers import SwinForImageClassification, AutoFeatureExtractor
5
+ import cv2
6
+ import mediapipe as mp
7
+ import matplotlib.pyplot as plt
8
+ from PIL import Image
9
+ import gradio as gr
10
+ import os
11
+ import numpy as np
12
+
13
# Class-index <-> face-shape name lookup tables handed to the classifier.
id2label = dict(enumerate(["Heart", "Oblong", "Oval", "Round", "Square"]))
label2id = {shape: index for index, shape in id2label.items()}

# Recommended frame (user-facing text) for each face shape.
glasses_recommendations = dict(
    Heart="Frame Rimless",
    Oblong="Frame Persegi Panjang",
    Oval="Frame Bulat",
    Round="Frame Kotak",
    Square="Frame Oval",
)

# Relative path of the illustration shown for each face shape's recommended
# frame; the files are named after the frame, not the face shape.
glasses_images = dict(
    Heart="glasses/RimlessFrame.jpg",
    Oblong="glasses/RectangleFrame.jpg",
    Oval="glasses/RoundFrame.jpg",
    Round="glasses/SquareFrame.jpg",
    Square="glasses/OvalFrame.jpg",
)
34
+
35
# Load model: start from the base Swin checkpoint, then overlay the fine-tuned
# weights stored next to this script.
model_checkpoint = "microsoft/swin-tiny-patch4-window7-224"
device = "cuda" if torch.cuda.is_available() else "cpu"

# ignore_mismatched_sizes=True lets the classification head be re-created for
# the labels in id2label when its size differs from the base checkpoint's.
model = SwinForImageClassification.from_pretrained(
    model_checkpoint,
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True,
)

# Load your fine-tuned model weights (uploaded into Space!)
# NOTE(review): torch.load unpickles the file; only ship checkpoints you trust.
_state_dict = torch.load('LR-0001-adamW-32-64swin.pth', map_location=device)

# strict=False tolerates key differences, but doing so silently can leave part
# of the network (e.g. the new head) at its initial weights. Surface any
# mismatch so a bad checkpoint is visible in the logs instead of only showing
# up as poor predictions.
_load_result = model.load_state_dict(_state_dict, strict=False)
if _load_result.missing_keys or _load_result.unexpected_keys:
    print(
        "Warning: fine-tuned checkpoint does not match the model exactly. "
        f"Missing keys: {list(_load_result.missing_keys)}; "
        f"unexpected keys: {list(_load_result.unexpected_keys)}"
    )

model = model.to(device)
model.eval()

# Initialize feature extractor (built from the base checkpoint's preprocessing
# configuration).
feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)

# Initialize Mediapipe Face Detection; detections scoring below 0.5 are dropped.
mp_face_detection = mp.solutions.face_detection.FaceDetection(
    model_selection=1,
    min_detection_confidence=0.5,
)
56
+
57
# Preprocess image
def preprocess_image(image):
    """Crop the first detected face and convert it to model input.

    Args:
        image: A PIL image or an RGB(A) array-like, as delivered by the
            Gradio image input.

    Returns:
        The ``pixel_values`` tensor produced by ``feature_extractor`` for the
        224x224 face crop, with the leading batch dimension removed.

    Raises:
        ValueError: If no image was supplied, no face is detected, or the
            detected face box lies entirely outside the image.
    """
    if image is None:
        # Gradio passes None when the form is submitted without an upload.
        raise ValueError("No image provided.")

    # Normalise to a 3-channel RGB array. PIL handles every source mode
    # (RGBA, grayscale, palette); for raw arrays only an alpha channel is cut.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")
    rgb = np.asarray(image)
    if rgb.ndim == 3 and rgb.shape[2] == 4:
        rgb = rgb[..., :3]
    rgb = np.ascontiguousarray(rgb)

    # The detector is fed RGB directly, so no BGR round-trip is needed.
    results = mp_face_detection.process(rgb)
    if not results.detections:
        raise ValueError("No face detected in the image.")

    # Only the first detection is used.
    bbox = results.detections[0].location_data.relative_bounding_box
    h, w = rgb.shape[:2]

    # The box is relative to the image size and can extend past the borders
    # for faces near an edge. Clamp it so the slice never wraps around
    # (negative index) or comes back empty.
    x1 = max(0, int(bbox.xmin * w))
    y1 = max(0, int(bbox.ymin * h))
    x2 = min(w, int((bbox.xmin + bbox.width) * w))
    y2 = min(h, int((bbox.ymin + bbox.height) * h))
    if x2 <= x1 or y2 <= y1:
        raise ValueError("Detected face region is empty.")

    face = cv2.resize(rgb[y1:y2, x1:x2], (224, 224))
    pixel_values = feature_extractor(images=face, return_tensors="pt")['pixel_values']

    return pixel_values.squeeze(0)
81
+
82
# Prediction
def predict(image):
    """Classify the face shape in *image* and pick a glasses recommendation.

    Returns a ``(text, picture)`` pair: ``text`` lists the top prediction and
    every class probability with its recommended frame; ``picture`` is the
    opened illustration for the top class, or ``None`` when no file exists at
    the configured path. Any exception is reported as ``("Error: ...", None)``
    so the UI shows a message instead of failing.
    """
    try:
        batch = preprocess_image(image).unsqueeze(0).to(device)

        with torch.no_grad():
            logits = model(batch).logits
            probabilities = torch.nn.functional.softmax(logits, dim=1).squeeze(0)
            # (label, percentage) pairs, most likely class first.
            ranked = sorted(
                (
                    (id2label[index], probabilities[index].item() * 100)
                    for index in range(len(probabilities))
                ),
                key=lambda pair: pair[1],
                reverse=True,
            )

        best_label, best_percent = ranked[0]

        # Prepare result text
        pieces = [
            f"Bentuk Wajah: {best_label} ({best_percent:.2f}%)\n\n",
            "Probabilitas Setiap Kelas:\n",
        ]
        pieces.extend(
            f"{label}: {percent:.2f}% - Rekomendasi Kacamata: {glasses_recommendations[label]}\n"
            for label, percent in ranked
        )
        result_text = "".join(pieces)

        # Prepare glasses image
        picture_path = glasses_images.get(best_label)
        if picture_path and os.path.exists(picture_path):
            picture = Image.open(picture_path)
        else:
            picture = None

        return result_text, picture

    except Exception as exc:
        return f"Error: {str(exc)}", None
113
+
114
# Gradio Interface: one image in (delivered to predict as a PIL image), a text
# summary and a recommendation picture out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Textbox(label="Hasil Prediksi"),
        gr.Image(label="Rekomendasi Kacamata"),
    ],
    title="Deteksi Bentuk Wajah & Rekomendasi Kacamata",
    description="Upload gambar wajahmu untuk mendapatkan bentuk wajah dan rekomendasi kacamata!",
)

# Start the web server only when run as a script.
if __name__ == "__main__":
    demo.launch()