ruminasval commited on
Commit
3f264c3
·
verified ·
1 Parent(s): 966bfbf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -26
app.py CHANGED
@@ -26,6 +26,15 @@ glasses_images = {
26
  "Pilot (Aviator)": "glasses/aviator.jpg"
27
  }
28
 
 
 
 
 
 
 
 
 
 
29
  id2label = {0: 'Heart', 1: 'Oblong', 2: 'Oval', 3: 'Round', 4: 'Square'}
30
  label2id = {v: k for k, v in id2label.items()}
31
 
@@ -42,10 +51,14 @@ model = SwinForImageClassification.from_pretrained(
42
  )
43
 
44
  # Load your trained weights
45
- state_dict = torch.load('LR-0001-adamW-32-64swin.pth', map_location=device)
46
- model.load_state_dict(state_dict, strict=False)
47
- model.to(device)
48
- model.eval()
 
 
 
 
49
 
50
  # Initialize Mediapipe
51
  mp_face_detection = mp.solutions.face_detection.FaceDetection(model_selection=1, min_detection_confidence=0.5)
@@ -105,8 +118,18 @@ def predict(image):
105
  # --- Use decision tree for recommendations
106
  frame_recommendations = recommend_glasses_tree(pred_label)
107
 
108
- description = face_shape_descriptions[pred_label]
109
- frame_image_path = glasses_images.get(pred_label)
 
 
 
 
 
 
 
 
 
 
110
 
111
  # Build explanation text
112
  if frame_recommendations:
@@ -118,29 +141,32 @@ def predict(image):
118
  explanation = (f"Bentuk wajah kamu adalah {pred_label} ({pred_prob:.2f}%). "
119
  f"Tidak ada rekomendasi frame untuk bentuk wajah ini.")
120
 
121
- # Load frame image if available
122
- if frame_image_path and os.path.exists(frame_image_path):
123
- frame_image = Image.open(frame_image_path)
124
- else:
125
- frame_image = None
126
-
127
- return pred_label, explanation, frame_image
128
 
 
 
129
  except Exception as e:
130
- return "Error", str(e), None
131
 
132
  # Gradio Interface
133
- iface = gr.Interface(
134
- fn=predict,
135
- inputs=gr.Image(type="pil"),
136
- outputs=[
137
- gr.Textbox(label="Bentuk Wajah Terdeteksi"),
138
- gr.Textbox(label="Rekomendasi dan Penjelasan"),
139
- gr.Image(label="Gambar Frame Rekomendasi")
140
- ],
141
- title="Rekomendasi Kacamata Berdasarkan Bentuk Wajah",
142
- description="Upload foto wajahmu untuk mendapatkan rekomendasi bentuk kacamata yang sesuai!"
143
- )
 
 
 
 
 
 
 
144
 
145
  if __name__ == "__main__":
146
- iface.launch()
 
26
  "Pilot (Aviator)": "glasses/aviator.jpg"
27
  }
28
 
29
+ # Ensure the 'glasses' directory exists and contains the images
30
+ if not os.path.exists("glasses"):
31
+ os.makedirs("glasses")
32
+ # Create dummy image files if they don't exist
33
+ for _, path in glasses_images.items():
34
+ if not os.path.exists(path):
35
+ dummy_image = Image.new('RGB', (200, 100), color='gray')
36
+ dummy_image.save(path)
37
+
38
  id2label = {0: 'Heart', 1: 'Oblong', 2: 'Oval', 3: 'Round', 4: 'Square'}
39
  label2id = {v: k for k, v in id2label.items()}
40
 
 
51
  )
52
 
53
  # Load your trained weights
54
+ # Ensure 'LR-0001-adamW-32-64swin.pth' is in the same directory or provide the correct path
55
+ if os.path.exists('LR-0001-adamW-32-64swin.pth'):
56
+ state_dict = torch.load('LR-0001-adamW-32-64swin.pth', map_location=device)
57
+ model.load_state_dict(state_dict, strict=False)
58
+ model.to(device)
59
+ model.eval()
60
+ else:
61
+ print("Warning: Trained weights file 'LR-0001-adamW-32-64swin.pth' not found. Using pre-trained weights only.")
62
 
63
  # Initialize Mediapipe
64
  mp_face_detection = mp.solutions.face_detection.FaceDetection(model_selection=1, min_detection_confidence=0.5)
 
118
  # --- Use decision tree for recommendations
119
  frame_recommendations = recommend_glasses_tree(pred_label)
120
 
121
+ description = face_shape_descriptions.get(pred_label, "tidak dikenali")
122
+ frame_images = []
123
+
124
+ # Load images for all recommended frames
125
+ for frame in frame_recommendations:
126
+ frame_image_path = glasses_images.get(frame)
127
+ if frame_image_path and os.path.exists(frame_image_path):
128
+ try:
129
+ frame_image = Image.open(frame_image_path)
130
+ frame_images.append(frame_image)
131
+ except Exception as e:
132
+ print(f"Error loading image for {frame}: {e}")
133
 
134
  # Build explanation text
135
  if frame_recommendations:
 
141
  explanation = (f"Bentuk wajah kamu adalah {pred_label} ({pred_prob:.2f}%). "
142
  f"Tidak ada rekomendasi frame untuk bentuk wajah ini.")
143
 
144
+ return pred_label, explanation, frame_images
 
 
 
 
 
 
145
 
146
+ except ValueError as ve:
147
+ return "Error", str(ve), []
148
  except Exception as e:
149
+ return "Error", f"Terjadi kesalahan: {str(e)}", []
150
 
151
  # Gradio Interface
152
+ with gr.Blocks(theme=gr.themes.Soft()) as iface:
153
+ gr.Markdown("# Program Rekomendasi Kacamata Berdasarkan Bentuk Wajah")
154
+ gr.Markdown("Upload foto wajahmu untuk mendapatkan rekomendasi bentuk kacamata yang sesuai.")
155
+
156
+ with gr.Row():
157
+ with gr.Column():
158
+ image_input = gr.Image(type="pil")
159
+ with gr.Row():
160
+ upload_button = gr.UploadButton("Unggah Gambar", file_types=["image"])
161
+ clear_button = gr.Button("Ganti")
162
+ with gr.Column():
163
+ detected_shape = gr.Textbox(label="Bentuk Wajah Terdeteksi")
164
+ explanation_output = gr.Textbox(label="Penjelasan")
165
+ recommendation_gallery = gr.Gallery(label="Rekomendasi Kacamata", columns=3)
166
+
167
+ upload_button.upload(predict, inputs=upload_button, outputs=[detected_shape, explanation_output, recommendation_gallery])
168
+ clear_button.click(lambda: (None, "", []), inputs=None, outputs=[image_input, detected_shape, recommendation_gallery])
169
+ image_input.upload(predict, inputs=image_input, outputs=[detected_shape, explanation_output, recommendation_gallery])
170
 
171
  if __name__ == "__main__":
172
+ iface.launch()