#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@author: Nikhil Kunjoor
"""
import gradio as gr
from transformers import pipeline
from PIL import Image, ImageFilter, ImageOps
import numpy as np
import cv2

# Dictionary of available segmentation models
SEGMENTATION_MODELS = {
    "NVIDIA SegFormer (Cityscapes)": "nvidia/segformer-b1-finetuned-cityscapes-1024-1024",
    "NVIDIA SegFormer (ADE20K)": "nvidia/segformer-b0-finetuned-ade-512-512",
    "Facebook MaskFormer (COCO)": "facebook/maskformer-swin-base-ade",
    "OneFormer (COCO)": "shi-labs/oneformer_coco_swin_large",
    "NVIDIA SegFormer (B5)": "nvidia/segformer-b5-finetuned-cityscapes-1024-1024"
}

# Dictionary of available depth estimation models
DEPTH_MODELS = {
    "Intel ZoeDepth (NYU-KITTI)": "Intel/zoedepth-nyu-kitti",
    "DPT (Large)": "Intel/dpt-large",
    "DPT (Hybrid)": "Intel/dpt-hybrid-midas",
    "GLPDepth": "vinvino02/glpn-nyu"
}

# Initialize model placeholders
segmentation_model = None
depth_estimator = None

def load_segmentation_model(model_name):
    """Load the selected segmentation model"""
    global segmentation_model
    model_path = SEGMENTATION_MODELS[model_name]
    print(f"Loading segmentation model: {model_path}...")
    segmentation_model = pipeline("image-segmentation", model=model_path)
    return f"Loaded segmentation model: {model_name}"

def load_depth_model(model_name):
    """Load the selected depth estimation model"""
    global depth_estimator
    model_path = DEPTH_MODELS[model_name]
    print(f"Loading depth estimation model: {model_path}...")
    depth_estimator = pipeline("depth-estimation", model=model_path)
    return f"Loaded depth model: {model_name}"

def lens_blur(image, radius):
    """
    Apply a more realistic lens blur (bokeh effect) using OpenCV.
    """
    if radius < 1:
        return image
    
    # Convert PIL image to OpenCV format
    img_np = np.array(image)
    
    # Create a circular kernel for the bokeh effect
    kernel_size = 2 * radius + 1
    kernel = np.zeros((kernel_size, kernel_size), dtype=np.float32)
    center = radius
    for i in range(kernel_size):
        for j in range(kernel_size):
            # Create circular kernel
            if np.sqrt((i - center) ** 2 + (j - center) ** 2) <= radius:
                kernel[i, j] = 1.0
    
    # Normalize the kernel
    if kernel.sum() != 0:
        kernel = kernel / kernel.sum()
    
    # Apply the filter to each channel separately
    channels = cv2.split(img_np)
    blurred_channels = []
    
    for channel in channels:
        blurred_channel = cv2.filter2D(channel, -1, kernel)
        blurred_channels.append(blurred_channel)
    
    # Merge the channels back
    blurred_img = cv2.merge(blurred_channels)
    
    # Convert back to PIL image
    return Image.fromarray(blurred_img)
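
# An equivalent, vectorized construction of the same circular (disk) kernel,
# shown as a reference sketch (not wired into lens_blur above, where the
# nested Python loop is O(kernel_size^2)):
def circular_kernel(radius):
    """Return the normalized disk kernel of size (2 * radius + 1)."""
    size = 2 * radius + 1
    yy, xx = np.ogrid[:size, :size]
    kernel = ((yy - radius) ** 2 + (xx - radius) ** 2 <= radius ** 2).astype(np.float32)
    return kernel / kernel.sum()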

def process_image(input_image, method, blur_intensity, blur_type):
    """
    Process the input image using one of two methods:
    
    1. Segmented Background Blur:
       - Uses segmentation to extract a foreground mask.
       - Applies the selected blur (Gaussian or Lens) to the background.
       - Composites the final image.
       
    2. Depth-based Variable Blur:
       - Uses depth estimation to generate a depth map.
       - Normalizes the depth map to be used as a blending mask.
       - Blends a fully blurred version (using the selected blur) with the original image.
       
    Returns:
       - output_image: final composited image.
       - mask_image: the mask used (binary for segmentation, normalized depth for depth-based).
    """
    # Bail out early if no image was provided.
    if input_image is None:
        return None, None
    
    # Only require the model that the chosen method actually uses.
    if method == "Segmented Background Blur" and segmentation_model is None:
        return input_image, input_image.convert("L")
    if method == "Depth-based Variable Blur" and depth_estimator is None:
        return input_image, input_image.convert("L")
    
    # Ensure image is in RGB mode
    input_image = input_image.convert("RGB")
    
    # Select blur function based on blur_type
    if blur_type == "Gaussian Blur":
        blur_fn = lambda img, rad: img.filter(ImageFilter.GaussianBlur(radius=rad))
    elif blur_type == "Lens Blur":
        blur_fn = lens_blur
    else:
        # Fall back to Gaussian blur for any unrecognized selection.
        blur_fn = lambda img, rad: img.filter(ImageFilter.GaussianBlur(radius=rad))
    
    if method == "Segmented Background Blur":
        # Use segmentation to obtain a foreground mask.
        results = segmentation_model(input_image)
        # Heuristic: take the last segment's mask as the main foreground
        # object. Pipelines typically list it last, but this can fail on
        # multi-object scenes.
        foreground_mask = results[-1]["mask"]
        # Ensure the mask is grayscale.
        foreground_mask = foreground_mask.convert("L")
        # Threshold to create a binary mask.
        binary_mask = foreground_mask.point(lambda p: 255 if p > 128 else 0)
        
        # Blur the background using the selected blur function.
        blurred_background = blur_fn(input_image, blur_intensity)
        
        # Composite the final image: keep foreground and use blurred background elsewhere.
        output_image = Image.composite(input_image, blurred_background, binary_mask)
        mask_image = binary_mask
        
    elif method == "Depth-based Variable Blur":
        # Generate depth map.
        depth_results = depth_estimator(input_image)
        depth_map = depth_results["depth"]
        
        # Convert the depth map to a numpy array and normalize to [0, 255].
        # Note: some models output metric depth (higher = farther, e.g.
        # ZoeDepth) while others output inverse depth (higher = closer, e.g.
        # DPT); if the foreground comes out blurred instead of the
        # background, invert the normalized map.
        depth_array = np.array(depth_map).astype(np.float32)
        norm = (depth_array - depth_array.min()) / (depth_array.max() - depth_array.min() + 1e-8)
        normalized_depth = (norm * 255).astype(np.uint8)
        mask_image = Image.fromarray(normalized_depth)
        
        # Create fully blurred version using the selected blur function.
        blurred_image = blur_fn(input_image, blur_intensity)
        
        # Convert images to arrays for blending.
        orig_np = np.array(input_image).astype(np.float32)
        blur_np = np.array(blurred_image).astype(np.float32)
        # Reshape mask for broadcasting.
        alpha = normalized_depth[..., np.newaxis] / 255.0
        
        # Blend pixels: 0 = original; 1 = fully blurred.
        blended_np = (1 - alpha) * orig_np + alpha * blur_np
        blended_np = np.clip(blended_np, 0, 255).astype(np.uint8)
        output_image = Image.fromarray(blended_np)
    
    else:
        output_image = input_image
        mask_image = input_image.convert("L")
    
    return output_image, mask_image
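
# Illustrative usage outside the UI (kept as a comment so it does not run at
# import time; the file names are hypothetical):
#
#   load_segmentation_model("NVIDIA SegFormer B0 (ADE20K)")
#   load_depth_model("DPT (Large)")
#   img = Image.open("portrait.jpg")
#   out, mask = process_image(img, "Segmented Background Blur", 15, "Lens Blur")
#   out.save("portrait_blurred.jpg")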

# Build a Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## Image Processing App: Segmentation & Depth-based Blur")
    
    with gr.Tab("Model Selection"):
        with gr.Row():
            with gr.Column():
                seg_model_dropdown = gr.Dropdown(
                    label="Segmentation Model",
                    choices=list(SEGMENTATION_MODELS.keys()),
                    value=list(SEGMENTATION_MODELS.keys())[0]
                )
                seg_model_load_btn = gr.Button("Load Segmentation Model")
                seg_model_status = gr.Textbox(label="Status", value="No model loaded")
            
            with gr.Column():
                depth_model_dropdown = gr.Dropdown(
                    label="Depth Estimation Model",
                    choices=list(DEPTH_MODELS.keys()),
                    value=list(DEPTH_MODELS.keys())[0]
                )
                depth_model_load_btn = gr.Button("Load Depth Model")
                depth_model_status = gr.Textbox(label="Status", value="No model loaded")
    
    with gr.Tab("Image Processing"):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image", type="pil")
                method = gr.Radio(label="Processing Method", 
                                choices=["Segmented Background Blur", "Depth-based Variable Blur"],
                                value="Segmented Background Blur")
                blur_intensity = gr.Slider(label="Blur Intensity (Maximum Blur Radius)", 
                                        minimum=1, maximum=30, step=1, value=15)
                blur_type = gr.Dropdown(label="Blur Type", 
                                        choices=["Gaussian Blur", "Lens Blur"], 
                                        value="Gaussian Blur")
                run_button = gr.Button("Process Image")
            with gr.Column():
                output_image = gr.Image(label="Output Image")
                mask_output = gr.Image(label="Mask")
    
    # Set up event handlers
    seg_model_load_btn.click(
        fn=load_segmentation_model,
        inputs=[seg_model_dropdown],
        outputs=[seg_model_status]
    )
    
    depth_model_load_btn.click(
        fn=load_depth_model,
        inputs=[depth_model_dropdown],
        outputs=[depth_model_status]
    )
    
    run_button.click(
        fn=process_image, 
        inputs=[input_image, method, blur_intensity, blur_type], 
        outputs=[output_image, mask_output]
    )

    # Load default models on startup
    demo.load(
        fn=lambda: (
            load_segmentation_model(list(SEGMENTATION_MODELS.keys())[0]),
            load_depth_model(list(DEPTH_MODELS.keys())[0])
        ),
        inputs=None,
        outputs=[seg_model_status, depth_model_status]
    )

# Launch the app
demo.launch()
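
# Gradio also supports e.g. demo.launch(share=True) for a temporary public
# link, or demo.launch(server_port=7860) to pin the port.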