Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -26,6 +26,15 @@ glasses_images = {
|
|
26 |
"Pilot (Aviator)": "glasses/aviator.jpg"
|
27 |
}
|
28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
id2label = {0: 'Heart', 1: 'Oblong', 2: 'Oval', 3: 'Round', 4: 'Square'}
|
30 |
label2id = {v: k for k, v in id2label.items()}
|
31 |
|
@@ -42,10 +51,14 @@ model = SwinForImageClassification.from_pretrained(
|
|
42 |
)
|
43 |
|
44 |
# Load your trained weights
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
model.
|
|
|
|
|
|
|
|
|
49 |
|
50 |
# Initialize Mediapipe
|
51 |
mp_face_detection = mp.solutions.face_detection.FaceDetection(model_selection=1, min_detection_confidence=0.5)
|
@@ -105,8 +118,18 @@ def predict(image):
|
|
105 |
# --- Use decision tree for recommendations
|
106 |
frame_recommendations = recommend_glasses_tree(pred_label)
|
107 |
|
108 |
-
description = face_shape_descriptions
|
109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
|
111 |
# Build explanation text
|
112 |
if frame_recommendations:
|
@@ -118,29 +141,32 @@ def predict(image):
|
|
118 |
explanation = (f"Bentuk wajah kamu adalah {pred_label} ({pred_prob:.2f}%). "
|
119 |
f"Tidak ada rekomendasi frame untuk bentuk wajah ini.")
|
120 |
|
121 |
-
|
122 |
-
if frame_image_path and os.path.exists(frame_image_path):
|
123 |
-
frame_image = Image.open(frame_image_path)
|
124 |
-
else:
|
125 |
-
frame_image = None
|
126 |
-
|
127 |
-
return pred_label, explanation, frame_image
|
128 |
|
|
|
|
|
129 |
except Exception as e:
|
130 |
-
return "Error", str(e),
|
131 |
|
132 |
# Gradio Interface
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
gr.
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
144 |
|
145 |
if __name__ == "__main__":
|
146 |
-
iface.launch()
|
|
|
26 |
"Pilot (Aviator)": "glasses/aviator.jpg"
|
27 |
}
|
28 |
|
29 |
+
# Ensure the 'glasses' directory exists and contains the images
|
30 |
+
if not os.path.exists("glasses"):
|
31 |
+
os.makedirs("glasses")
|
32 |
+
# Create dummy image files if they don't exist
|
33 |
+
for _, path in glasses_images.items():
|
34 |
+
if not os.path.exists(path):
|
35 |
+
dummy_image = Image.new('RGB', (200, 100), color='gray')
|
36 |
+
dummy_image.save(path)
|
37 |
+
|
38 |
id2label = {0: 'Heart', 1: 'Oblong', 2: 'Oval', 3: 'Round', 4: 'Square'}
|
39 |
label2id = {v: k for k, v in id2label.items()}
|
40 |
|
|
|
51 |
)
|
52 |
|
53 |
# Load your trained weights
|
54 |
+
# Ensure 'LR-0001-adamW-32-64swin.pth' is in the same directory or provide the correct path
|
55 |
+
if os.path.exists('LR-0001-adamW-32-64swin.pth'):
|
56 |
+
state_dict = torch.load('LR-0001-adamW-32-64swin.pth', map_location=device)
|
57 |
+
model.load_state_dict(state_dict, strict=False)
|
58 |
+
model.to(device)
|
59 |
+
model.eval()
|
60 |
+
else:
|
61 |
+
print("Warning: Trained weights file 'LR-0001-adamW-32-64swin.pth' not found. Using pre-trained weights only.")
|
62 |
|
63 |
# Initialize Mediapipe
|
64 |
mp_face_detection = mp.solutions.face_detection.FaceDetection(model_selection=1, min_detection_confidence=0.5)
|
|
|
118 |
# --- Use decision tree for recommendations
|
119 |
frame_recommendations = recommend_glasses_tree(pred_label)
|
120 |
|
121 |
+
description = face_shape_descriptions.get(pred_label, "tidak dikenali")
|
122 |
+
frame_images = []
|
123 |
+
|
124 |
+
# Load images for all recommended frames
|
125 |
+
for frame in frame_recommendations:
|
126 |
+
frame_image_path = glasses_images.get(frame)
|
127 |
+
if frame_image_path and os.path.exists(frame_image_path):
|
128 |
+
try:
|
129 |
+
frame_image = Image.open(frame_image_path)
|
130 |
+
frame_images.append(frame_image)
|
131 |
+
except Exception as e:
|
132 |
+
print(f"Error loading image for {frame}: {e}")
|
133 |
|
134 |
# Build explanation text
|
135 |
if frame_recommendations:
|
|
|
141 |
explanation = (f"Bentuk wajah kamu adalah {pred_label} ({pred_prob:.2f}%). "
|
142 |
f"Tidak ada rekomendasi frame untuk bentuk wajah ini.")
|
143 |
|
144 |
+
return pred_label, explanation, frame_images
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
|
146 |
+
except ValueError as ve:
|
147 |
+
return "Error", str(ve), []
|
148 |
except Exception as e:
|
149 |
+
return "Error", f"Terjadi kesalahan: {str(e)}", []
|
150 |
|
151 |
# Gradio Interface
|
152 |
+
with gr.Blocks(theme=gr.themes.Soft()) as iface:
|
153 |
+
gr.Markdown("# Program Rekomendasi Kacamata Berdasarkan Bentuk Wajah")
|
154 |
+
gr.Markdown("Upload foto wajahmu untuk mendapatkan rekomendasi bentuk kacamata yang sesuai.")
|
155 |
+
|
156 |
+
with gr.Row():
|
157 |
+
with gr.Column():
|
158 |
+
image_input = gr.Image(type="pil")
|
159 |
+
with gr.Row():
|
160 |
+
upload_button = gr.UploadButton("Unggah Gambar", file_types=["image"])
|
161 |
+
clear_button = gr.Button("Ganti")
|
162 |
+
with gr.Column():
|
163 |
+
detected_shape = gr.Textbox(label="Bentuk Wajah Terdeteksi")
|
164 |
+
explanation_output = gr.Textbox(label="Penjelasan")
|
165 |
+
recommendation_gallery = gr.Gallery(label="Rekomendasi Kacamata", columns=3)
|
166 |
+
|
167 |
+
upload_button.upload(predict, inputs=upload_button, outputs=[detected_shape, explanation_output, recommendation_gallery])
|
168 |
+
clear_button.click(lambda: (None, "", []), inputs=None, outputs=[image_input, detected_shape, recommendation_gallery])
|
169 |
+
image_input.upload(predict, inputs=image_input, outputs=[detected_shape, explanation_output, recommendation_gallery])
|
170 |
|
171 |
if __name__ == "__main__":
|
172 |
+
iface.launch()
|