Spaces:
Running
Running
File size: 10,682 Bytes
7e36dea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 |
import torch
import torchvision.transforms as T
import numpy as np
import cv2
import streamlit as st
import mediapipe as mp
from PIL import Image
import os
# NOTE(review): appears to work around Streamlit's module watcher walking
# `torch.classes.__path__`; torch.classes resolves unknown attributes lazily,
# which surfaces as the "__path__._path" RuntimeError that load_model also
# checks for. Giving it a concrete empty __path__ short-circuits that lookup —
# confirm against the Streamlit/torch versions in use.
setattr(torch.classes, "__path__", [])
class FaceHairSegmenter:
    """Cut a single face (including hair) out of an RGB image.

    MediaPipe locates the face. A crop around it is labelled pixel-by-pixel by a
    19-class BiSeNet face parser loaded from ``bisenet.pth``. Pixels whose label
    is in ``keep_classes`` become opaque in an RGBA result.
    """

    # Side length of the square image fed to BiSeNet.
    INPUT_SIZE = 512
    # Fixed pixel padding added to the detected box size before the relative
    # expansion below. Because it is absolute, the crop can cover the whole
    # frame on small images.
    CROP_PAD_PX = 550
    # Fractions of the padded box size by which the crop grows on each side.
    # The top grows most so hair above the detected face box is included.
    CROP_EXPAND_TOP = 0.5
    CROP_EXPAND_SIDES = 0.3
    CROP_EXPAND_BOTTOM = 0.2

    def __init__(self):
        # MediaPipe face detector; model_selection=1 selects the full-range model.
        self.mp_face_detection = mp.solutions.face_detection
        self.face_detection = self.mp_face_detection.FaceDetection(
            model_selection=1,
            min_detection_confidence=0.6,
        )
        # BiSeNet face parser, or None if loading failed (see load_model).
        self.model = self.load_model()
        # Resize to a square INPUT_SIZE image and normalize with the ImageNet
        # mean/std.
        self.transform = T.Compose([
            T.Resize((self.INPUT_SIZE, self.INPUT_SIZE)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # Label ids rendered opaque: every class except 0, 14 and 16. In the usual
        # CelebAMask-HQ ordering those are presumably background, neck and cloth —
        # confirm against the checkpoint's label map.
        self.keep_classes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15, 17, 18]

    def load_model(self):
        """Build BiSeNet and load ``bisenet.pth`` onto CPU (or CUDA if available).

        Returns:
            The model in eval mode, or None on any failure. The traceback is
            printed instead of raised, so segment_face_hair can report
            "Model not loaded correctly." rather than crash.
        """
        try:
            # Imported lazily so this module still imports without the local model package.
            from model import BiSeNet

            # 19 output classes, matching the CelebAMask-HQ label set.
            model = BiSeNet(n_classes=19)
            try:
                # NOTE: torch.load unpickles the file; only load a trusted checkpoint.
                model.load_state_dict(torch.load('bisenet.pth', map_location=torch.device('cpu')))
            except RuntimeError as e:
                if "__path__._path" not in str(e):
                    raise
                # Retry with the tensor-only loader when the torch.classes path error appears.
                print("Using alternative model loading approach...")
                checkpoint = torch.load('bisenet.pth', map_location='cpu', weights_only=True)
                model.load_state_dict(checkpoint)
            model.eval()
            if torch.cuda.is_available():
                model = model.cuda()
            print("BiSeNet model loaded successfully")
            return model
        except Exception as e:
            print(f"Error loading model: {e}")
            import traceback
            traceback.print_exc()
            return None

    @staticmethod
    def _to_rgb(image):
        """Return ``image`` as a 3-channel RGB array.

        3-channel input is assumed to be RGB already; main() converts OpenCV's
        BGR before calling in. 4-channel input is treated as RGBA, and grayscale
        input is expanded to three channels.
        """
        if image.ndim == 2 or image.shape[2] == 1:
            return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if image.shape[2] == 4:
            # Drop alpha without swapping R and B: this pipeline is RGB-ordered.
            return cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
        return image

    def detect_faces(self, image):
        """Detect faces with MediaPipe.

        Args:
            image: uint8 image array. 3-channel input is treated as RGB; other
                layouts are normalized with _to_rgb first.

        Returns:
            ``(count, boxes)``. Each box is ``(x_min, y_min, x_max, y_max)`` in
            pixels, clipped to the image. Heavily overlapping duplicates are removed.
        """
        image_rgb = self._to_rgb(image)
        h, w = image_rgb.shape[:2]
        results = self.face_detection.process(image_rgb)

        bboxes = []
        # `detections` is falsy when nothing was found.
        for detection in results.detections or []:
            rel = detection.location_data.relative_bounding_box
            x_min = max(0, int(rel.xmin * w))
            y_min = max(0, int(rel.ymin * h))
            x_max = min(w, int((rel.xmin + rel.width) * w))
            y_max = min(h, int((rel.ymin + rel.height) * h))
            bboxes.append((x_min, y_min, x_max, y_max))

        if len(bboxes) > 1:
            bboxes = self.remove_overlapping_boxes(bboxes)
        return len(bboxes), bboxes

    def remove_overlapping_boxes(self, boxes, overlap_threshold=0.5):
        """Greedy non-maximum suppression ordered by box area.

        Boxes are visited largest-first. A box is dropped when its IoU with an
        already-kept box exceeds ``overlap_threshold``. Boxes that do not
        intersect are never treated as duplicates.

        Args:
            boxes: sequence of ``(x_min, y_min, x_max, y_max)`` tuples.
            overlap_threshold: IoU above which a box counts as a duplicate.

        Returns:
            The kept boxes as a list, largest first.
        """
        if not boxes:
            return []

        def box_area(box):
            return (box[2] - box[0]) * (box[3] - box[1])

        def is_duplicate(a, b):
            x1, y1 = max(a[0], b[0]), max(a[1], b[1])
            x2, y2 = min(a[2], b[2]), min(a[3], b[3])
            if x1 >= x2 or y1 >= y2:
                return False
            intersection = (x2 - x1) * (y2 - y1)
            union = box_area(a) + box_area(b) - intersection
            return intersection / union > overlap_threshold

        keep = []
        for current in sorted(boxes, key=box_area, reverse=True):
            if not any(is_duplicate(current, kept) for kept in keep):
                keep.append(current)
        return keep

    def _crop_bounds(self, bbox, h, w):
        """Expand a face box into the region fed to the parser, clipped to an h×w image.

        Returns:
            ``(y_min, y_max, x_min, x_max)`` slice bounds.
        """
        x_min, y_min, x_max, y_max = bbox
        face_height = y_max - y_min + self.CROP_PAD_PX
        face_width = x_max - x_min + self.CROP_PAD_PX
        return (
            max(0, y_min - int(face_height * self.CROP_EXPAND_TOP)),
            min(h, y_max + int(face_height * self.CROP_EXPAND_BOTTOM)),
            max(0, x_min - int(face_width * self.CROP_EXPAND_SIDES)),
            min(w, x_max + int(face_width * self.CROP_EXPAND_SIDES)),
        )

    def _parse_mask(self, face_region):
        """Run BiSeNet on an RGB crop and return a uint8 0/255 mask of the crop's size."""
        crop_h, crop_w = face_region.shape[:2]
        input_tensor = self.transform(Image.fromarray(face_region)).unsqueeze(0)
        if torch.cuda.is_available():
            input_tensor = input_tensor.cuda()
        with torch.no_grad():
            # Element 0 of the model output is used as the logits; presumably
            # shaped (1, n_classes, H, W) — confirm against the BiSeNet implementation.
            out = self.model(input_tensor)[0]
        parsing = out.squeeze(0).argmax(0).byte().cpu().numpy()
        # Nearest-neighbour keeps the label ids intact while resizing back to the crop.
        parsing = cv2.resize(parsing, (crop_w, crop_h), interpolation=cv2.INTER_NEAREST)

        mask = np.where(np.isin(parsing, self.keep_classes), 255, 0).astype(np.uint8)
        # Morphological close fills small holes between neighbouring kept regions.
        kernel = np.ones((3, 3), np.uint8)
        return cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)

    def segment_face_hair(self, image):
        """Cut the single face (with hair) out of ``image`` using BiSeNet.

        Args:
            image: uint8 image array; 3-channel input is treated as RGB.

        Returns:
            ``(result, message)``. On success, ``result`` is an RGBA array the
            size of ``image``. Its alpha is 255 on kept parsing classes inside
            the expanded face crop and 0 everywhere else. On failure, ``result``
            is the input itself; when several faces are found it is a copy with
            the detections outlined. ``message`` explains what happened.
        """
        if self.model is None:
            return image, "Model not loaded correctly."
        if image is None or image.size == 0:
            return image, "Invalid image provided."

        num_faces, bboxes = self.detect_faces(image)
        if num_faces == 0:
            return image, "No face detected! Please upload an image with a clear face."
        if num_faces > 1:
            # Outline every detection so the user can see what was found.
            debug_img = image.copy()
            for (x_min, y_min, x_max, y_max) in bboxes:
                cv2.rectangle(debug_img, (x_min, y_min), (x_max, y_max), (255, 0, 0), 2)
            return debug_img, f"{num_faces} faces detected! Please upload an image with exactly ONE face."

        image_rgb = self._to_rgb(image)
        h, w = image_rgb.shape[:2]
        y_min, y_max, x_min, x_max = self._crop_bounds(bboxes[0], h, w)
        mask = self._parse_mask(image_rgb[y_min:y_max, x_min:x_max])

        # Everything outside the crop stays fully transparent.
        rgba = np.dstack((image_rgb, np.zeros((h, w), dtype=np.uint8)))
        rgba[y_min:y_max, x_min:x_max, 3] = mask
        return rgba, "Face segmented successfully!"
# Streamlit app


def _get_segmenter():
    """Return this browser session's FaceHairSegmenter, building it on first use.

    Streamlit reruns the script on every interaction. Constructing the segmenter
    inside the button handler would therefore reload BiSeNet and recreate the
    MediaPipe detector on every click. Keeping one instance per session in
    ``st.session_state`` avoids that, without sharing the detector between
    sessions that Streamlit may run concurrently.
    """
    segmenter = st.session_state.get("face_segmenter")
    if segmenter is None or segmenter.model is None:
        # Rebuild after a failed model load too, so clicking again retries the load.
        segmenter = FaceHairSegmenter()
        st.session_state["face_segmenter"] = segmenter
    return segmenter


def main():
    """Run the Streamlit UI: upload an image, segment the face, offer a PNG download."""
    from io import BytesIO

    st.set_page_config(page_title="Face Segmentation Tool", layout="wide")
    st.title("Face Segmentation Tool")
    st.markdown("""
Upload an image to extract the face with a transparent background.
## Guidelines:
- Upload an image with **exactly one face**
- The face should be clearly visible
- For best results, use images with good lighting
""")
    col1, col2 = st.columns(2)

    with col1:
        st.header("Input Image")
        uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
        if uploaded_file is None:
            return
        # getvalue() returns the whole upload regardless of the buffer's read position.
        file_bytes = np.asarray(bytearray(uploaded_file.getvalue()), dtype=np.uint8)
        image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
        if image is None:
            # imdecode signals undecodable data by returning None rather than raising.
            st.error("Could not decode the uploaded file as an image.")
            return
        # OpenCV decodes to BGR; the segmenter and st.image expect RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        st.image(image, caption="Uploaded Image", use_container_width=True)

        if not st.button("Segment Face"):
            return
        with st.spinner("Processing..."):
            result, message = _get_segmenter().segment_face_hair(image)

    with col2:
        st.header("Segmented Result")
        st.image(result, caption="Segmented Face", use_container_width=True)
        st.text(message)
        # Only a successful segmentation adds an alpha channel. Every error path
        # returns the 3-channel RGB input (or an annotated copy of it), which is
        # not worth offering as a download.
        if result.ndim == 3 and result.shape[2] == 4:
            buf = BytesIO()
            Image.fromarray(result).save(buf, format="PNG")
            st.download_button(
                label="Download Segmented Face",
                data=buf.getvalue(),
                file_name="segmented_face.png",
                mime="image/png",
            )
# Script entry point (the stray " |" page residue after main() was a syntax error).
if __name__ == "__main__":
    main()