import torch
import torchvision.transforms as T
import numpy as np
import cv2
import streamlit as st
import mediapipe as mp
from PIL import Image
import os
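# Workaround for a known Streamlit/PyTorch interaction: Streamlit's file watcher
# inspects torch.classes and can raise "Tried to instantiate class '__path__._path'".
# Clearing the module's __path__ sidesteps that probe.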
torch.classes.__path__ = []

class FaceHairSegmenter:
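    """Detect a single face with MediaPipe and segment face + hair with a BiSeNet face-parsing model."""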
    def __init__(self):
        # Use MediaPipe for face detection
        self.mp_face_detection = mp.solutions.face_detection
        self.face_detection = self.mp_face_detection.FaceDetection(
            model_selection=1,  # Use full range model
            min_detection_confidence=0.6
        )
        
        # Load BiSeNet model
        self.model = self.load_model()
        
        # Preprocessing for BiSeNet: 512x512 input with ImageNet mean/std normalization
        self.transform = T.Compose([
            T.Resize((512, 512)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        
        # CelebAMask-HQ parsing labels to keep: everything except background (0),
        # neck (14) and cloth (16), so skin, brows, eyes, ears, nose, lips, hair,
        # hat and accessories all stay in the output.
        self.keep_classes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15, 17, 18]

    def load_model(self):
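        """Load the BiSeNet face-parsing model.

        Assumes a local ``model.py`` defining ``BiSeNet`` and a ``bisenet.pth``
        weights file (CelebAMask-HQ pretrained, 19 classes) next to this script.
        Returns None if loading fails so the app can report the problem.
        """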
        try:
            # Import locally to avoid dependency issues if model isn't present
            from model import BiSeNet
            
            # Initialize BiSeNet with 19 classes (for CelebAMask-HQ)
            model = BiSeNet(n_classes=19)
            
            # Try to load the pretrained weights using a safer approach
            try:
                # First attempt: standard loading
                model.load_state_dict(torch.load('bisenet.pth', map_location=torch.device('cpu')))
            except RuntimeError as e:
                if "__path__._path" in str(e):
                    # Alternative loading approach if we encounter the class path error
                    print("Using alternative model loading approach...")
                    checkpoint = torch.load('bisenet.pth', map_location='cpu', weights_only=True)
                    model.load_state_dict(checkpoint)
                else:
                    # Other type of RuntimeError, re-raise
                    raise
            
            model.eval()
            
            if torch.cuda.is_available():
                model = model.cuda()
                
            print("BiSeNet model loaded successfully")
            return model
        except Exception as e:
            print(f"Error loading model: {e}")
            # Let's provide a more detailed error message to help with debugging
            import traceback
            traceback.print_exc()
            return None

    def detect_faces(self, image):
        """Detect faces using MediaPipe (expects image in RGB)."""
        # Since image from cv2 is BGR, convert to RGB for MediaPipe
        image_rgb = image if len(image.shape) == 3 and image.shape[2] == 3 else cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        h, w = image.shape[:2]
        
        # Process with MediaPipe
        results = self.face_detection.process(image_rgb)
        
        bboxes = []
        if results.detections:
            for detection in results.detections:
                bbox = detection.location_data.relative_bounding_box
                x_min = max(0, int(bbox.xmin * w))
                y_min = max(0, int(bbox.ymin * h))
                x_max = min(w, int((bbox.xmin + bbox.width) * w))
                y_max = min(h, int((bbox.ymin + bbox.height) * h))
                bboxes.append((x_min, y_min, x_max, y_max))
        
        if len(bboxes) > 1:
            bboxes = self.remove_overlapping_boxes(bboxes)
            
        return len(bboxes), bboxes

    def remove_overlapping_boxes(self, boxes, overlap_threshold=0.5):
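        """Greedy IoU suppression: keep larger boxes first and drop any box whose
        IoU with an already-kept box exceeds ``overlap_threshold``."""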
        if not boxes:
            return []
        def box_area(box):
            return (box[2] - box[0]) * (box[3] - box[1])
        boxes = sorted(boxes, key=box_area, reverse=True)
        keep = []
        for current in boxes:
            is_duplicate = False
            for kept_box in keep:
                x1 = max(current[0], kept_box[0])
                y1 = max(current[1], kept_box[1])
                x2 = min(current[2], kept_box[2])
                y2 = min(current[3], kept_box[3])
                if x1 < x2 and y1 < y2:
                    intersection = (x2 - x1) * (y2 - y1)
                    area1 = box_area(current)
                    area2 = box_area(kept_box)
                    union = area1 + area2 - intersection
                    iou = intersection / union
                    if iou > overlap_threshold:
                        is_duplicate = True
                        break
            if not is_duplicate:
                keep.append(current)
        return keep

    def segment_face_hair(self, image):
        """Segment face using BiSeNet trained on CelebAMask-HQ."""
        if self.model is None:
            return image, "Model not loaded correctly."
        if image is None or image.size == 0:
            return image, "Invalid image provided."
        
        # Detect faces
        num_faces, bboxes = self.detect_faces(image)
        if num_faces == 0:
            return image, "No face detected! Please upload an image with a clear face."
        elif num_faces > 1:
            debug_img = image.copy()
            for (x_min, y_min, x_max, y_max) in bboxes:
                cv2.rectangle(debug_img, (x_min, y_min), (x_max, y_max), (255, 0, 0), 2)
            return debug_img, f"{num_faces} faces detected! Please upload an image with exactly ONE face."
        
        # Get the face bounding box (we'll use this only for ROI, not for final segmentation)
        bbox = bboxes[0]
        x_min, y_min, x_max, y_max = bbox
        h, w = image.shape[:2]
        
        # Expand the bounding box so hair and chin are included in the crop.
        # The fixed 550 px padding is a heuristic that enlarges the proportional
        # margins computed below; the result is clamped to the image bounds.
        face_height = y_max - y_min + 550
        face_width = x_max - x_min + 550
        
        y_min_exp = max(0, y_min - int(face_height * 0.5))  # Expand more for hair
        x_min_exp = max(0, x_min - int(face_width * 0.3))
        x_max_exp = min(w, x_max + int(face_width * 0.3))
        y_max_exp = min(h, y_max + int(face_height * 0.2))
        
        # Crop and prepare image for BiSeNet
        face_region = image[y_min_exp:y_max_exp, x_min_exp:x_max_exp]
        original_face_size = face_region.shape[:2]
        
        # Ensure a 3-channel RGB array for PIL (the app supplies RGB; drop any alpha)
        if face_region.ndim == 3 and face_region.shape[2] == 3:
            pil_face = Image.fromarray(face_region)
        elif face_region.ndim == 3 and face_region.shape[2] == 4:
            pil_face = Image.fromarray(cv2.cvtColor(face_region, cv2.COLOR_RGBA2RGB))
        else:
            pil_face = Image.fromarray(cv2.cvtColor(face_region, cv2.COLOR_GRAY2RGB))
            
        # Apply transformations and run model
        input_tensor = self.transform(pil_face).unsqueeze(0)
        if torch.cuda.is_available():
            input_tensor = input_tensor.cuda()
            
        with torch.no_grad():
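            # The BiSeNet forward pass returns several outputs; index 0 is the main
            # parsing head. argmax over the class dimension gives a per-pixel label map.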
            out = self.model(input_tensor)[0]
            parsing = out.squeeze(0).argmax(0).byte().cpu().numpy()
        
        # Resize the parsing map back to the crop size (cv2.resize expects (width, height))
        parsing = cv2.resize(parsing, (original_face_size[1], original_face_size[0]), 
                            interpolation=cv2.INTER_NEAREST)
        
        # Create mask that keeps only the classes we want
        mask = np.zeros_like(parsing, dtype=np.uint8)
        for cls_id in self.keep_classes:
            mask[parsing == cls_id] = 255
            
        # Close small holes in the mask with a 3x3 morphological closing
        kernel = np.ones((3, 3), np.uint8)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
        
        # Build a full-image mask and place the face-region mask at its location
        full_mask = np.zeros((h, w), dtype=np.uint8)
        full_mask[y_min_exp:y_max_exp, x_min_exp:x_max_exp] = mask

        # Create the RGBA output: original colour channels plus the full mask as alpha
        if image.shape[2] == 3:  # RGB
            rgba = np.dstack((image, full_mask))
        else:  # Image already carries an alpha channel; keep only its colour channels
            rgba = np.dstack((image[:, :, :3], full_mask))
            
        return rgba, "Face segmented successfully!"

# Streamlit app
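# Optional helper: cache the segmenter across Streamlit reruns so the BiSeNet
# weights are not reloaded on every button click. A minimal sketch assuming a
# Streamlit version that provides st.cache_resource (>= 1.18); main() below calls
# it in place of constructing FaceHairSegmenter() directly.
@st.cache_resource
def get_segmenter():
    return FaceHairSegmenter()
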
def main():
    st.set_page_config(page_title="Face Segmentation Tool", layout="wide")
    
    st.title("Face Segmentation Tool")
    st.markdown("""
    Upload an image to extract the face with a transparent background.
    
    ## Guidelines:
    - Upload an image with **exactly one face**
    - The face should be clearly visible
    - For best results, use images with good lighting
    """)
    
    col1, col2 = st.columns(2)
    
    with col1:
        st.header("Input Image")
        uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
        
        if uploaded_file is not None:
            # Convert to numpy array
            file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
            image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
            if image is None:
                st.error("Could not decode the uploaded file as an image.")
                st.stop()
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            
            st.image(image, caption="Uploaded Image", use_container_width=True)
            
            if st.button("Segment Face"):
                with st.spinner("Processing..."):
                    segmenter = get_segmenter()
                    result, message = segmenter.segment_face_hair(image)
                    
                    with col2:
                        st.header("Segmented Result")
                        st.image(result, caption="Segmented Face", use_container_width=True)
                        st.text(message)
                        
                        # Add download button for the result
                        if "No face detected" not in message and "faces detected" not in message:
                            # Convert numpy array to PIL Image
                            result_img = Image.fromarray(result)
                            
                            # Create a BytesIO object
                            from io import BytesIO
                            buf = BytesIO()
                            result_img.save(buf, format="PNG")
                            
                            # Add download button
                            st.download_button(
                                label="Download Segmented Face",
                                data=buf.getvalue(),
                                file_name="segmented_face.png",
                                mime="image/png"
                            )

if __name__ == "__main__":
    main()
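
# To launch the app (assuming this file is saved as app.py):
#   streamlit run app.py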