# Faceshape / app.py — Hugging Face Space by ruminasval
# Last commit: 93bb675 "Update app.py" (verified)
import gradio as gr
import torch
from transformers import SwinForImageClassification, AutoFeatureExtractor
import mediapipe as mp
import cv2
from PIL import Image
import numpy as np
import os
# Face shape descriptions.
# Each value is an Indonesian sentence fragment that predict() appends after
# "Kamu memiliki bentuk wajah " in its explanation, so each one starts
# lowercase and ends with a full stop.
face_shape_descriptions = {
    "Heart": "dengan dahi lebar dan dagu yang runcing.",
    "Oblong": "yang lebih panjang dari lebar dengan garis pipi lurus.",
    "Oval": "dengan proporsi seimbang dan dagu sedikit melengkung.",
    "Round": "dengan garis rahang melengkung dan pipi penuh.",
    "Square": "dengan rahang tegas dan dahi lebar.",
}
# Frame images path.
# Maps each frame style (the display name used in recommendations and
# gallery captions) to its catalogue image under the "glasses/" directory.
glasses_images = {
    frame_name: "glasses/" + file_name
    for frame_name, file_name in (
        ("Oval", "oval.jpg"),
        ("Round", "round.jpg"),
        ("Square", "square.jpg"),
        ("Octagon", "octagon.jpg"),
        ("Cat Eye", "cat eye.jpg"),
        ("Pilot (Aviator)", "aviator.jpg"),
    )
}
# Ensure the 'glasses' directory exists and that every catalogue image listed
# in glasses_images is present. Any missing file is replaced with a plain
# 200x100 grey placeholder so the recommendation gallery still has an image.
if not os.path.exists("glasses"):
    os.makedirs("glasses")
for frame_path in glasses_images.values():
    if os.path.exists(frame_path):
        continue
    placeholder = Image.new('RGB', (200, 100), color='gray')
    placeholder.save(frame_path)
# Class index <-> face-shape label mapping for the classifier head.
# NOTE(review): presumably this order must match the label order used when
# the weights were trained — confirm before reordering.
id2label = dict(enumerate(("Heart", "Oblong", "Oval", "Round", "Square")))
label2id = dict(zip(id2label.values(), id2label.keys()))
# Load model
# Use the GPU when available; predict() moves its inputs to this same device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_checkpoint = "microsoft/swin-tiny-patch4-window7-224"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
# Swin backbone with a 5-way face-shape head. ignore_mismatched_sizes=True
# lets the checkpoint's original classifier weights (different size) be
# dropped, so the new head starts randomly initialised until trained
# weights are loaded below.
model = SwinForImageClassification.from_pretrained(
    model_checkpoint,
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True,
)

# Load your trained weights.
# The path is resolved relative to the current working directory.
weights_path = 'LR-0001-adamW-32-64swin.pth'
if os.path.exists(weights_path):
    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    state_dict = torch.load(weights_path, map_location=device)
    # strict=False tolerates key mismatches. A mismatch usually means part of
    # the fine-tuned model (e.g. the head) was NOT loaded, so report it
    # instead of failing silently.
    load_result = model.load_state_dict(state_dict, strict=False)
    mismatched = list(load_result.missing_keys) + list(load_result.unexpected_keys)
    if mismatched:
        print(
            f"Warning: {len(load_result.missing_keys)} missing / "
            f"{len(load_result.unexpected_keys)} unexpected keys when loading "
            f"'{weights_path}' (e.g. {mismatched[:5]})."
        )
else:
    print("Warning: Trained weights file 'LR-0001-adamW-32-64swin.pth' not found. Using pre-trained weights only.")

# Always place the model on `device` and switch to inference mode, whether or
# not trained weights were found. predict() sends inputs to `device`, and
# eval() disables training-only behaviour such as dropout.
model.to(device)
model.eval()
# Initialize Mediapipe.
# A single face detector is created at import time and reused by every call
# to preprocess_image(). model_selection=1 selects MediaPipe's full-range
# model; detections scoring below 0.5 are discarded.
mp_face_detection = mp.solutions.face_detection.FaceDetection(
    model_selection=1,
    min_detection_confidence=0.5,
)
# --- Decision rules: face shape -> recommended frame styles
def recommend_glasses_tree(face_shape):
    """Return the frame styles recommended for ``face_shape``.

    The lookup is case-insensitive ("Square", "square" and "SQUARE" are
    equivalent). An unrecognised shape yields an empty list. A new list is
    built on every call, so callers may mutate the result freely.
    """
    rules = {
        "square": ("Oval", "Round"),
        "round": ("Square", "Octagon", "Cat Eye"),
        "oval": ("Oval", "Pilot (Aviator)", "Cat Eye", "Round"),
        "heart": ("Pilot (Aviator)", "Cat Eye", "Round"),
        "oblong": ("Square", "Oval", "Pilot (Aviator)", "Cat Eye"),
    }
    return list(rules.get(face_shape.lower(), ()))
# Preprocess function
def _face_crop_box(bbox, width, height):
    """Convert a MediaPipe relative bounding box to clamped pixel coordinates.

    MediaPipe can report boxes that extend past the image border (negative
    ``xmin``/``ymin`` or a right/bottom edge beyond 1.0). A negative slice
    index would make NumPy wrap around and produce a wrong or empty crop, so
    every coordinate is clamped into ``[0, width]`` / ``[0, height]``.

    Returns ``(x1, y1, x2, y2)`` for slicing as ``img[y1:y2, x1:x2]``.
    """
    x1 = min(max(int(bbox.xmin * width), 0), width)
    y1 = min(max(int(bbox.ymin * height), 0), height)
    x2 = min(max(int((bbox.xmin + bbox.width) * width), 0), width)
    y2 = min(max(int((bbox.ymin + bbox.height) * height), 0), height)
    return x1, y1, x2, y2


def preprocess_image(image):
    """Detect the face in ``image``, crop it and turn it into model input.

    Args:
        image: The uploaded picture, as a PIL image or an RGB array.

    Returns:
        The ``pixel_values`` tensor produced by ``feature_extractor`` for the
        224x224 face crop, with the batch dimension removed.

    Raises:
        ValueError: If no face is detected or the detected box is empty after
            clamping to the image ("Wajah tidak terdeteksi.").
    """
    # Force 3-channel RGB for PIL input; RGBA/grayscale/palette images would
    # otherwise give arrays with the wrong channel count for MediaPipe/OpenCV.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")
    # Everything below stays in RGB: MediaPipe expects RGB, cv2.resize is
    # channel-order agnostic, and the feature extractor receives RGB.
    img = np.array(image, dtype=np.uint8)

    results = mp_face_detection.process(img)
    if not results.detections:
        raise ValueError("Wajah tidak terdeteksi.")

    bbox = results.detections[0].location_data.relative_bounding_box
    h, w = img.shape[:2]
    x1, y1, x2, y2 = _face_crop_box(bbox, w, h)
    if x2 <= x1 or y2 <= y1:
        # Box lies entirely outside the frame; treat as no usable face.
        raise ValueError("Wajah tidak terdeteksi.")

    face = cv2.resize(img[y1:y2, x1:x2], (224, 224))
    inputs = feature_extractor(images=face, return_tensors="pt")
    return inputs['pixel_values'].squeeze(0)
# Prediction function
def _load_gallery_items(frame_names):
    """Load the catalogue image for each recommended frame style.

    Returns a list of ``(PIL image, caption)`` tuples for the Gradio gallery,
    where the caption is the frame name. Frames with no known path, or whose
    file is missing, are skipped. Images that fail to open are logged and
    skipped (best effort: one bad file must not break the whole result).
    """
    items = []
    for frame in frame_names:
        frame_image_path = glasses_images.get(frame)
        if not frame_image_path or not os.path.exists(frame_image_path):
            continue
        try:
            # Image.open is lazy and keeps the file open; copy() forces a full
            # load so the context manager can close the handle immediately.
            with Image.open(frame_image_path) as opened:
                frame_image = opened.copy()
        except Exception as e:
            print(f"Error loading image for {frame}: {e}")
            continue
        items.append((frame_image, frame))  # caption = frame name
    return items


def _build_explanation(pred_label, pred_prob, frame_recommendations):
    """Compose the Indonesian explanation text shown to the user."""
    if not frame_recommendations:
        return (f"Bentuk wajah kamu adalah {pred_label} ({pred_prob:.2f}%). "
                f"Tidak ada rekomendasi frame untuk bentuk wajah ini.")
    description = face_shape_descriptions.get(pred_label, "tidak dikenali")
    recommended_frames_text = ', '.join(frame_recommendations)
    return (f"Bentuk wajah kamu adalah {pred_label} ({pred_prob:.2f}%). "
            f"Kamu memiliki bentuk wajah {description} "
            f"Rekomendasi bentuk kacamata yang sesuai dengan wajah kamu adalah: {recommended_frames_text}.")


def predict(image):
    """Classify the face shape in ``image`` and recommend glasses frames.

    Args:
        image: The uploaded picture from the Gradio image input, or ``None``
            when nothing was uploaded.

    Returns:
        A ``(label, explanation, gallery_items)`` tuple for the three output
        components. On any failure it returns ``("Error", message, [])``
        instead of raising, so the UI always gets displayable values.
    """
    # Clicking "Konfirmasi" with no upload passes None; answer clearly instead
    # of surfacing an internal conversion error.
    if image is None:
        return "Error", "Silakan unggah foto wajah terlebih dahulu.", []
    try:
        pixel_values = preprocess_image(image).unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = model(pixel_values)
        probs = torch.nn.functional.softmax(outputs.logits, dim=1)
        pred_idx = torch.argmax(probs, dim=1).item()
        pred_label = id2label[pred_idx]
        pred_prob = probs[0][pred_idx].item() * 100

        frame_recommendations = recommend_glasses_tree(pred_label)
        gallery_items = _load_gallery_items(frame_recommendations)
        explanation = _build_explanation(pred_label, pred_prob, frame_recommendations)
        return pred_label, explanation, gallery_items
    except ValueError as ve:
        # Expected, user-facing failures (e.g. no face detected).
        return "Error", str(ve), []
    except Exception as e:
        # UI boundary: report anything else instead of crashing the handler.
        return "Error", f"Terjadi kesalahan: {str(e)}", []
# Gradio Interface
with gr.Blocks(theme=gr.themes.Soft()) as iface:
    gr.Markdown("# Program Rekomendasi Kacamata Berdasarkan Bentuk Wajah")
    gr.Markdown("Upload foto wajahmu untuk mendapatkan rekomendasi bentuk kacamata yang sesuai.")
    with gr.Row():
        # Left column: input image and action buttons.
        with gr.Column():
            image_input = gr.Image(type="pil")
            confirm_button = gr.Button("Konfirmasi")
            restart_button = gr.Button("Restart")
        # Right column: prediction, explanation and frame gallery.
        with gr.Column():
            detected_shape = gr.Textbox(label="Bentuk Wajah Terdeteksi")
            explanation_output = gr.Textbox(label="Penjelasan")
            recommendation_gallery = gr.Gallery(label="Rekomendasi Kacamata", columns=3, show_label=False)

    confirm_button.click(
        predict,
        inputs=image_input,
        outputs=[detected_shape, explanation_output, recommendation_gallery],
    )
    # Reset every component to its empty value, in `outputs` order. Both
    # Textboxes get "" — a list would be stringified and shown as "[]".
    restart_button.click(
        lambda: (None, "", "", []),
        inputs=None,
        outputs=[image_input, detected_shape, explanation_output, recommendation_gallery],
    )

    # Add source statement under the gallery
    gr.Markdown("**Sumber gambar kacamata**: Katalog dari [glassdirect.co.uk](https://www.glassdirect.co.uk)")

if __name__ == "__main__":
    iface.launch()