import torch import torchvision.transforms as T import numpy as np import cv2 import streamlit as st import mediapipe as mp from PIL import Image import os torch.classes.__path__ = [] class FaceHairSegmenter: def __init__(self): # Use MediaPipe for face detection self.mp_face_detection = mp.solutions.face_detection self.face_detection = self.mp_face_detection.FaceDetection( model_selection=1, # Use full range model min_detection_confidence=0.6 ) # Load BiSeNet model self.model = self.load_model() # Define transforms - adjust according to BiSeNet requirements self.transform = T.Compose([ T.Resize((512, 512)), T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) # CelebAMask-HQ classes - focus on the categories we want to keep self.keep_classes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15, 17, 18] # All except 0, 14, 16 def load_model(self): try: # Import locally to avoid dependency issues if model isn't present from model import BiSeNet # Initialize BiSeNet with 19 classes (for CelebAMask-HQ) model = BiSeNet(n_classes=19) # Try to load the pretrained weights using a safer approach try: # First attempt: standard loading model.load_state_dict(torch.load('bisenet.pth', map_location=torch.device('cpu'))) except RuntimeError as e: if "__path__._path" in str(e): # Alternative loading approach if we encounter the class path error print("Using alternative model loading approach...") checkpoint = torch.load('bisenet.pth', map_location='cpu', weights_only=True) model.load_state_dict(checkpoint) else: # Other type of RuntimeError, re-raise raise model.eval() if torch.cuda.is_available(): model = model.cuda() print("BiSeNet model loaded successfully") return model except Exception as e: print(f"Error loading model: {e}") # Let's provide a more detailed error message to help with debugging import traceback traceback.print_exc() return None def detect_faces(self, image): """Detect faces using MediaPipe (expects image in RGB).""" # Since image from cv2 is BGR, convert to RGB for MediaPipe image_rgb = image if len(image.shape) == 3 and image.shape[2] == 3 else cv2.cvtColor(image, cv2.COLOR_BGR2RGB) h, w = image.shape[:2] # Process with MediaPipe results = self.face_detection.process(image_rgb) bboxes = [] if results.detections: for detection in results.detections: bbox = detection.location_data.relative_bounding_box x_min = max(0, int(bbox.xmin * w)) y_min = max(0, int(bbox.ymin * h)) x_max = min(w, int((bbox.xmin + bbox.width) * w)) y_max = min(h, int((bbox.ymin + bbox.height) * h)) bboxes.append((x_min, y_min, x_max, y_max)) if len(bboxes) > 1: bboxes = self.remove_overlapping_boxes(bboxes) return len(bboxes), bboxes def remove_overlapping_boxes(self, boxes, overlap_threshold=0.5): if not boxes: return [] def box_area(box): return (box[2] - box[0]) * (box[3] - box[1]) boxes = sorted(boxes, key=box_area, reverse=True) keep = [] for current in boxes: is_duplicate = False for kept_box in keep: x1 = max(current[0], kept_box[0]) y1 = max(current[1], kept_box[1]) x2 = min(current[2], kept_box[2]) y2 = min(current[3], kept_box[3]) if x1 < x2 and y1 < y2: intersection = (x2 - x1) * (y2 - y1) area1 = box_area(current) area2 = box_area(kept_box) union = area1 + area2 - intersection iou = intersection / union if iou > overlap_threshold: is_duplicate = True break if not is_duplicate: keep.append(current) return keep def segment_face_hair(self, image): """Segment face using BiSeNet trained on CelebAMask-HQ.""" if self.model is None: return image, "Model not loaded correctly." if image is None or image.size == 0: return image, "Invalid image provided." # Detect faces num_faces, bboxes = self.detect_faces(image) if num_faces == 0: return image, "No face detected! Please upload an image with a clear face." elif num_faces > 1: debug_img = image.copy() for (x_min, y_min, x_max, y_max) in bboxes: cv2.rectangle(debug_img, (x_min, y_min), (x_max, y_max), (255, 0, 0), 2) return debug_img, f"{num_faces} faces detected! Please upload an image with exactly ONE face." # Get the face bounding box (we'll use this only for ROI, not for final segmentation) bbox = bboxes[0] x_min, y_min, x_max, y_max = bbox h, w = image.shape[:2] # Expand bounding box for better segmentation face_height = y_max - y_min + 550 face_width = x_max - x_min + 550 y_min_exp = max(0, y_min - int(face_height * 0.5)) # Expand more for hair x_min_exp = max(0, x_min - int(face_width * 0.3)) x_max_exp = min(w, x_max + int(face_width * 0.3)) y_max_exp = min(h, y_max + int(face_height * 0.2)) # Crop and prepare image for BiSeNet face_region = image[y_min_exp:y_max_exp, x_min_exp:x_max_exp] original_face_size = face_region.shape[:2] # Ensure RGB format for PIL if face_region.shape[2] == 3: pil_face = Image.fromarray(face_region) else: pil_face = Image.fromarray(cv2.cvtColor(face_region, cv2.COLOR_BGR2RGB)) # Apply transformations and run model input_tensor = self.transform(pil_face).unsqueeze(0) if torch.cuda.is_available(): input_tensor = input_tensor.cuda() with torch.no_grad(): out = self.model(input_tensor)[0] parsing = out.squeeze(0).argmax(0).byte().cpu().numpy() # Resize parsing map back to original size parsing = cv2.resize(parsing, (original_face_size[1], original_face_size[0]), interpolation=cv2.INTER_NEAREST) # Create mask that keeps only the classes we want mask = np.zeros_like(parsing, dtype=np.uint8) for cls_id in self.keep_classes: mask[parsing == cls_id] = 255 # Refine the mask kernel = np.ones((3, 3), np.uint8) mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel) # Create full image mask (initialize with zeros) full_mask = np.zeros((h, w), dtype=np.uint8) # Place the face mask in the right position full_mask[y_min_exp:y_max_exp, x_min_exp:x_max_exp] = mask # Create the RGBA output if image.shape[2] == 3: # RGB rgba = np.dstack((image, np.zeros((h, w), dtype=np.uint8))) # Copy only the face region with its alpha rgba[y_min_exp:y_max_exp, x_min_exp:x_max_exp, 3] = mask else: # Already RGBA or other format rgba = np.dstack((cv2.cvtColor(image, cv2.COLOR_BGR2RGB), np.zeros((h, w), dtype=np.uint8))) rgba[y_min_exp:y_max_exp, x_min_exp:x_max_exp, 3] = mask return rgba, "Face segmented successfully!" # Streamlit app def main(): st.set_page_config(page_title="Face Segmentation Tool", layout="wide") st.title("Face Segmentation Tool") st.markdown(""" Upload an image to extract the face with a transparent background. ## Guidelines: - Upload an image with **exactly one face** - The face should be clearly visible - For best results, use images with good lighting """) col1, col2 = st.columns(2) with col1: st.header("Input Image") uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"]) if uploaded_file is not None: # Convert to numpy array file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8) image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) st.image(image, caption="Uploaded Image", use_container_width=True) if st.button("Segment Face"): with st.spinner("Processing..."): segmenter = FaceHairSegmenter() result, message = segmenter.segment_face_hair(image) with col2: st.header("Segmented Result") st.image(result, caption="Segmented Face", use_container_width=True) st.text(message) # Add download button for the result if "No face detected" not in message and "faces detected" not in message: # Convert numpy array to PIL Image result_img = Image.fromarray(result) # Create a BytesIO object from io import BytesIO buf = BytesIO() result_img.save(buf, format="PNG") # Add download button st.download_button( label="Download Segmented Face", data=buf.getvalue(), file_name="segmented_face.png", mime="image/png" ) if __name__ == "__main__": main()