# Spaces: Sleeping (page-status residue from the hosting site; not part of the program)
import os
import traceback
from io import BytesIO

import cv2
import mediapipe as mp
import numpy as np
import streamlit as st
import torch
import torchvision.transforms as T
from PIL import Image

# NOTE(review): presumably a workaround for Streamlit's file watcher raising
# when it probes `torch.classes.__path__._path` — confirm before removing.
torch.classes.__path__ = []
class FaceHairSegmenter:
    """Cut a single face out of an RGB image and return it as an RGBA array.

    MediaPipe face detection locates the face, a BiSeNet parsing network
    labels every pixel of an enlarged crop around it, and pixels whose label
    is listed in ``keep_classes`` become opaque in the output alpha channel.
    Everything else is fully transparent.
    """

    # Region-of-interest tuning. The detected box is first padded by a fixed
    # number of pixels, then grown by a fraction of the padded size on each
    # side (most at the top, so that hair is included in the crop).
    ROI_PADDING_PX = 550
    ROI_EXPAND_TOP = 0.5
    ROI_EXPAND_SIDE = 0.3
    ROI_EXPAND_BOTTOM = 0.2

    def __init__(self):
        """Create the face detector, load the parsing model and build the
        preprocessing pipeline."""
        # Use MediaPipe for face detection
        self.mp_face_detection = mp.solutions.face_detection
        self.face_detection = self.mp_face_detection.FaceDetection(
            model_selection=1,  # Use full range model
            min_detection_confidence=0.6,
        )
        # Load BiSeNet model (None when loading failed; see load_model)
        self.model = self.load_model()
        # Define transforms - adjust according to BiSeNet requirements
        self.transform = T.Compose([
            T.Resize((512, 512)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # CelebAMask-HQ classes - focus on the categories we want to keep
        # All except 0, 14, 16
        self.keep_classes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15, 17, 18]

    def load_model(self, weights_path="bisenet.pth"):
        """Build BiSeNet and load its pretrained weights.

        Args:
            weights_path: Checkpoint file containing the model ``state_dict``.

        Returns:
            The model in eval mode (on the GPU when CUDA is available), or
            ``None`` if the import, the checkpoint load or the device move
            fails. The error and its traceback are printed in that case.
        """
        try:
            # Import locally to avoid dependency issues if model isn't present
            from model import BiSeNet

            # Initialize BiSeNet with 19 classes (for CelebAMask-HQ)
            model = BiSeNet(n_classes=19)
            # weights_only=True restricts unpickling to tensors and plain
            # containers, which is all a state_dict needs, so a tampered
            # checkpoint cannot execute code while being loaded.
            state_dict = torch.load(
                weights_path, map_location="cpu", weights_only=True
            )
            model.load_state_dict(state_dict)
            model.eval()
            if torch.cuda.is_available():
                model = model.cuda()
            print("BiSeNet model loaded successfully")
            return model
        except Exception as e:
            # Best-effort boundary: the caller reports "model not loaded"
            # to the user instead of crashing the app.
            print(f"Error loading model: {e}")
            traceback.print_exc()
            return None

    @staticmethod
    def _to_rgb(image):
        """Return ``image`` as a 3-channel array, assuming RGB channel order.

        Grayscale input is replicated across three channels and an alpha
        channel is dropped. A 3-channel array is returned unchanged.

        Raises:
            ValueError: If the array is not 2-D or 3-D with 1, 3 or 4 channels.
        """
        if image.ndim == 2:
            return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if image.ndim == 3:
            channels = image.shape[2]
            if channels == 3:
                return image
            if channels == 4:
                return cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
            if channels == 1:
                return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        raise ValueError(f"Unsupported image shape: {image.shape}")

    def detect_faces(self, image):
        """Detect faces using MediaPipe (expects image in RGB).

        Returns:
            ``(count, bboxes)`` where each box is ``(x_min, y_min, x_max,
            y_max)`` in pixels, clamped to the image, with near-duplicate
            detections removed.
        """
        image_rgb = self._to_rgb(image)
        h, w = image_rgb.shape[:2]
        # Process with MediaPipe
        results = self.face_detection.process(image_rgb)
        bboxes = []
        for detection in results.detections or []:
            # MediaPipe reports the box relative to the image size.
            rel = detection.location_data.relative_bounding_box
            x_min = max(0, int(rel.xmin * w))
            y_min = max(0, int(rel.ymin * h))
            x_max = min(w, int((rel.xmin + rel.width) * w))
            y_max = min(h, int((rel.ymin + rel.height) * h))
            bboxes.append((x_min, y_min, x_max, y_max))
        if len(bboxes) > 1:
            bboxes = self.remove_overlapping_boxes(bboxes)
        return len(bboxes), bboxes

    @staticmethod
    def _box_area(box):
        """Area of an ``(x_min, y_min, x_max, y_max)`` box."""
        return (box[2] - box[0]) * (box[3] - box[1])

    def remove_overlapping_boxes(self, boxes, overlap_threshold=0.5):
        """Drop boxes that overlap an already kept, larger box too much.

        Boxes are visited from largest to smallest area. A box is discarded
        when its intersection-over-union with any kept box exceeds
        ``overlap_threshold``.

        Returns:
            The surviving boxes, largest first (ties keep input order).
        """
        if not boxes:
            return []
        keep = []
        for current in sorted(boxes, key=self._box_area, reverse=True):
            is_duplicate = False
            for kept_box in keep:
                x1 = max(current[0], kept_box[0])
                y1 = max(current[1], kept_box[1])
                x2 = min(current[2], kept_box[2])
                y2 = min(current[3], kept_box[3])
                if x1 >= x2 or y1 >= y2:
                    continue  # no overlap at all
                intersection = (x2 - x1) * (y2 - y1)
                # A positive intersection implies both areas are positive,
                # so the union below is never zero.
                union = (
                    self._box_area(current)
                    + self._box_area(kept_box)
                    - intersection
                )
                if intersection / union > overlap_threshold:
                    is_duplicate = True
                    break
            if not is_duplicate:
                keep.append(current)
        return keep

    def _expand_bbox(self, bbox, height, width):
        """Grow a face box into the crop fed to BiSeNet, clamped to the image.

        Returns:
            ``(x_min, y_min, x_max, y_max)`` of the expanded region.
        """
        x_min, y_min, x_max, y_max = bbox
        padded_h = y_max - y_min + self.ROI_PADDING_PX
        padded_w = x_max - x_min + self.ROI_PADDING_PX
        top = max(0, y_min - int(padded_h * self.ROI_EXPAND_TOP))  # Expand more for hair
        left = max(0, x_min - int(padded_w * self.ROI_EXPAND_SIDE))
        right = min(width, x_max + int(padded_w * self.ROI_EXPAND_SIDE))
        bottom = min(height, y_max + int(padded_h * self.ROI_EXPAND_BOTTOM))
        return left, top, right, bottom

    def segment_face_hair(self, image):
        """Segment face using BiSeNet trained on CelebAMask-HQ.

        Args:
            image: RGB image array (grayscale and RGBA are also accepted).

        Returns:
            ``(result, message)``. On success ``result`` is an RGBA array of
            the input size whose alpha is 255 on kept classes and 0
            elsewhere. If the model or the image is unusable, or no face is
            found, the input is returned unchanged. If several faces are
            found, a copy with the detected boxes drawn on it is returned.
        """
        if self.model is None:
            return image, "Model not loaded correctly."
        if image is None or image.size == 0:
            return image, "Invalid image provided."
        try:
            rgb = self._to_rgb(image)
        except ValueError:
            return image, "Invalid image provided."

        # Detect faces
        num_faces, bboxes = self.detect_faces(rgb)
        if num_faces == 0:
            return image, "No face detected! Please upload an image with a clear face."
        if num_faces > 1:
            debug_img = rgb.copy()
            for (x_min, y_min, x_max, y_max) in bboxes:
                cv2.rectangle(debug_img, (x_min, y_min), (x_max, y_max), (255, 0, 0), 2)
            return debug_img, f"{num_faces} faces detected! Please upload an image with exactly ONE face."

        # The detected box only selects the region handed to BiSeNet; the
        # final alpha comes from the parsing map, not from the box.
        h, w = rgb.shape[:2]
        left, top, right, bottom = self._expand_bbox(bboxes[0], h, w)

        # Crop and prepare image for BiSeNet
        face_region = rgb[top:bottom, left:right]
        roi_h, roi_w = face_region.shape[:2]
        pil_face = Image.fromarray(face_region)

        # Apply transformations and run model on whatever device holds it
        device = next(self.model.parameters()).device
        input_tensor = self.transform(pil_face).unsqueeze(0).to(device)
        with torch.no_grad():
            out = self.model(input_tensor)[0]
        parsing = out.squeeze(0).argmax(0).byte().cpu().numpy()

        # Resize parsing map back to the crop size; nearest-neighbour keeps
        # the values valid class ids.
        parsing = cv2.resize(parsing, (roi_w, roi_h), interpolation=cv2.INTER_NEAREST)

        # Create mask that keeps only the classes we want
        mask = np.where(np.isin(parsing, self.keep_classes), 255, 0).astype(np.uint8)

        # Refine the mask (close small holes)
        kernel = np.ones((3, 3), np.uint8)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)

        # Create the RGBA output: transparent everywhere except the mask
        rgba = np.dstack((rgb, np.zeros((h, w), dtype=np.uint8)))
        rgba[top:bottom, left:right, 3] = mask
        return rgba, "Face segmented successfully!"
# Streamlit app
def main():
    """Render the page: upload an image, segment the face, offer a PNG.

    The left column holds the uploader, the preview and the "Segment Face"
    button. The right column shows the result and its status message, plus
    a download button when the result is an RGBA image.
    """
    st.set_page_config(page_title="Face Segmentation Tool", layout="wide")
    st.title("Face Segmentation Tool")
    st.markdown("""
Upload an image to extract the face with a transparent background.
## Guidelines:
- Upload an image with **exactly one face**
- The face should be clearly visible
- For best results, use images with good lighting
""")
    col1, col2 = st.columns(2)

    with col1:
        st.header("Input Image")
        uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
        if uploaded_file is None:
            return

        # getvalue() returns the whole upload regardless of the stream
        # position, unlike read().
        file_bytes = np.frombuffer(uploaded_file.getvalue(), dtype=np.uint8)
        # imdecode raises on an empty buffer and returns None for data it
        # cannot decode; treat both as an unreadable upload.
        decoded = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR) if file_bytes.size else None
        if decoded is None:
            st.error("Could not read the uploaded file as an image.")
            return
        image = cv2.cvtColor(decoded, cv2.COLOR_BGR2RGB)
        st.image(image, caption="Uploaded Image", use_container_width=True)

        if not st.button("Segment Face"):
            return
        with st.spinner("Processing..."):
            # NOTE(review): a new segmenter (model load included) is built
            # on every click; consider st.cache_resource if the instance is
            # safe to share across sessions.
            segmenter = FaceHairSegmenter()
            result, message = segmenter.segment_face_hair(image)

    with col2:
        st.header("Segmented Result")
        st.image(result, caption="Segmented Face", use_container_width=True)
        st.text(message)
        # The input here is always 3-channel, so a 4-channel result can
        # only be the RGBA cut-out produced by a successful segmentation.
        if result.ndim == 3 and result.shape[2] == 4:
            buf = BytesIO()
            Image.fromarray(result).save(buf, format="PNG")
            st.download_button(
                label="Download Segmented Face",
                data=buf.getvalue(),
                file_name="segmented_face.png",
                mime="image/png",
            )
# Run the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()