PDK32 committed · Commit 8b6a896 · verified · 1 parent: 685adfb

Upload 2 files

Add Gradio app and requirements

Files changed (2)
  1. app.py +302 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,302 @@
+import gradio as gr
+import torch
+import numpy as np
+import cv2
+from PIL import Image
+from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation
+from transformers import DPTImageProcessor, DPTForDepthEstimation
+import warnings
+warnings.filterwarnings("ignore")
+
+# Load segmentation model - SegFormer is compatible with AutoModelForSemanticSegmentation
+seg_processor = AutoImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
+seg_model = AutoModelForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
+
+# Load depth estimation model
+depth_processor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
+depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
+
+def safe_resize(image, target_size, interpolation=cv2.INTER_LINEAR):
+    """Safely resize an image with validation checks."""
+    if image is None:
+        return None
+
+    # Ensure image is a proper numpy array
+    if not isinstance(image, np.ndarray):
+        return None
+
+    # Check that dimensions are valid (non-zero)
+    h, w = target_size
+    if h <= 0 or w <= 0 or image.shape[0] <= 0 or image.shape[1] <= 0:
+        return image  # Return original if target dimensions are invalid
+
+    # cv2.resize expects (width, height); the same call handles grayscale and color
+    return cv2.resize(image, (w, h), interpolation=interpolation)
+
+def apply_gaussian_blur(image, mask, sigma=15):
+    """Apply Gaussian blur to the background of an image based on a mask."""
+    try:
+        # Convert mask to binary (0 and 255)
+        if mask.max() <= 1.0:
+            binary_mask = (mask * 255).astype(np.uint8)
+        else:
+            binary_mask = mask.astype(np.uint8)
+
+        # Create a blurred version of the entire image
+        blurred = cv2.GaussianBlur(image, (0, 0), sigma)
+
+        # Resize mask to match image dimensions if needed
+        if binary_mask.shape[:2] != image.shape[:2]:
+            binary_mask = safe_resize(binary_mask, (image.shape[0], image.shape[1]))
+
+        # Create a 3-channel mask if the input mask is single-channel
+        if len(binary_mask.shape) == 2:
+            mask_3ch = np.stack([binary_mask, binary_mask, binary_mask], axis=2)
+        else:
+            mask_3ch = binary_mask
+
+        # Normalize mask to range [0, 1]
+        mask_3ch = mask_3ch / 255.0
+
+        # Combine original image (foreground) with blurred image (background) using the mask
+        result = image * mask_3ch + blurred * (1 - mask_3ch)
+
+        return result.astype(np.uint8)
+    except Exception as e:
+        print(f"Error in apply_gaussian_blur: {e}")
+        return image  # Return original image if there's an error
+
+def apply_depth_blur(image, depth_map, max_sigma=25):
+    """Apply variable Gaussian blur based on depth map."""
+    try:
+        # Normalize depth map to range [0, 1]
+        if depth_map.max() > 1.0:
+            depth_norm = depth_map / depth_map.max()
+        else:
+            depth_norm = depth_map
+
+        # Resize depth map to match image dimensions if needed
+        if depth_norm.shape[:2] != image.shape[:2]:
+            depth_norm = safe_resize(depth_norm, (image.shape[0], image.shape[1]))
+
+        # Create output image
+        result = np.zeros_like(image)
+
+        # Instead of many small blurs, use a few blur levels for efficiency
+        blur_levels = 5
+        step = max_sigma / blur_levels
+
+        # Track which pixels have been assigned a level so none are counted twice
+        total_mask = np.zeros(depth_norm.shape[:2], dtype=np.float32)
+
+        for i in range(blur_levels):
+            sigma = (i + 1) * step
+
+            # Depth range for this blur level: half-open intervals (closed on the
+            # last level) so each boundary pixel falls into exactly one level
+            lower_bound = i / blur_levels
+            upper_bound = (i + 1) / blur_levels
+            if i < blur_levels - 1:
+                mask = np.logical_and(depth_norm >= lower_bound, depth_norm < upper_bound).astype(np.float32)
+            else:
+                mask = np.logical_and(depth_norm >= lower_bound, depth_norm <= upper_bound).astype(np.float32)
+            total_mask += mask
+
+            # Skip if no pixels in this range
+            if not np.any(mask):
+                continue
+
+            # Apply blur for this level
+            blurred = cv2.GaussianBlur(image, (0, 0), sigma)
+
+            # Create 3-channel mask
+            mask_3ch = np.stack([mask, mask, mask], axis=2) if len(mask.shape) == 2 else mask
+
+            # Add this level's pixels to the result
+            result += (blurred * mask_3ch).astype(np.uint8)
+
+        # Fill any pixels not covered by a depth level with the original image
+        missing_mask = (total_mask < 0.5).astype(np.float32)
+        if np.any(missing_mask):
+            missing_mask_3ch = np.stack([missing_mask, missing_mask, missing_mask], axis=2)
+            result += (image * missing_mask_3ch).astype(np.uint8)
+
+        return result
+    except Exception as e:
+        print(f"Error in apply_depth_blur: {e}")
+        return image  # Return original image if there's an error
+
+def get_segmentation_mask(image_pil):
+    """Get segmentation mask for person/foreground from an image."""
+    try:
+        # Process the image with the segmentation model
+        inputs = seg_processor(images=image_pil, return_tensors="pt")
+        with torch.no_grad():
+            outputs = seg_model(**inputs)
+
+        # Get the predicted segmentation mask
+        logits = outputs.logits
+        upsampled_logits = torch.nn.functional.interpolate(
+            logits,
+            size=image_pil.size[::-1],  # Resize directly to original (height, width)
+            mode="bilinear",
+            align_corners=False,
+        )
+
+        # Get the predicted class for each pixel
+        predicted_mask = upsampled_logits.argmax(dim=1)[0]
+
+        # Convert the mask to a numpy array
+        mask_np = predicted_mask.cpu().numpy()
+
+        # Create a foreground mask for humans and common foreground objects
+        # Class indices follow the ADE20K dataset
+        foreground_classes = [12]  # Person class (add more classes as needed)
+
+        # Create a binary mask for foreground classes
+        foreground_mask = np.zeros_like(mask_np)
+        for cls in foreground_classes:
+            foreground_mask[mask_np == cls] = 1
+
+        return foreground_mask
+    except Exception as e:
+        print(f"Error in get_segmentation_mask: {e}")
+        # Return a default mask (all ones) in case of error
+        return np.ones((image_pil.size[1], image_pil.size[0]), dtype=np.uint8)
+
+def get_depth_map(image_pil):
+    """Get depth map from an image."""
+    try:
+        # Process the image with the depth estimation model
+        inputs = depth_processor(images=image_pil, return_tensors="pt")
+        with torch.no_grad():
+            outputs = depth_model(**inputs)
+            predicted_depth = outputs.predicted_depth
+
+        # Interpolate to original size
+        prediction = torch.nn.functional.interpolate(
+            predicted_depth.unsqueeze(1),
+            size=image_pil.size[::-1],
+            mode="bicubic",
+            align_corners=False,
+        )
+
+        # Convert to numpy array
+        depth_map = prediction.squeeze().cpu().numpy()
+
+        # Normalize depth map to [0, 1]
+        depth_min = depth_map.min()
+        depth_max = depth_map.max()
+        if depth_max > depth_min:
+            depth_map = (depth_map - depth_min) / (depth_max - depth_min)
+        else:
+            depth_map = np.zeros_like(depth_map)
+
+        return depth_map
+    except Exception as e:
+        print(f"Error in get_depth_map: {e}")
+        # Return a default depth map (gradient from top to bottom) in case of error
+        h, w = image_pil.size[1], image_pil.size[0]
+        default_depth = np.tile(np.linspace(0.0, 1.0, h, endpoint=False, dtype=np.float32)[:, None], (1, w))
+        return default_depth
+
+def process_image(input_image, blur_sigma=15, depth_blur_sigma=25):
+    """Main function to process the input image."""
+    try:
+        # Input validation
+        if input_image is None:
+            print("No input image provided")
+            return [None, None, None, None, None]
+
+        # Convert to PIL Image if needed
+        if isinstance(input_image, np.ndarray):
+            # Make sure we have a valid image with at least 2 dimensions
+            if input_image.ndim < 2 or input_image.shape[0] <= 0 or input_image.shape[1] <= 0:
+                print("Invalid input image dimensions")
+                return [None, None, None, None, None]
+            pil_image = Image.fromarray(input_image)
+        else:
+            pil_image = input_image
+            input_image = np.array(pil_image)
+
+        # Get segmentation mask
+        print("Getting segmentation mask...")
+        seg_mask = get_segmentation_mask(pil_image)
+
+        # Get depth map
+        print("Getting depth map...")
+        depth_map = get_depth_map(pil_image)
+
+        # Apply gaussian blur to background
+        print("Applying gaussian blur...")
+        gaussian_result = apply_gaussian_blur(input_image, seg_mask, sigma=blur_sigma)
+
+        # Apply depth-based blur
+        print("Applying depth-based blur...")
+        depth_result = apply_depth_blur(input_image, depth_map, max_sigma=depth_blur_sigma)
+
+        # Render the depth map as a color image; applyColorMap returns BGR,
+        # so convert to RGB for correct display in Gradio
+        depth_visualization = (depth_map * 255).astype(np.uint8)
+        depth_colored = cv2.applyColorMap(depth_visualization, cv2.COLORMAP_INFERNO)
+        depth_colored = cv2.cvtColor(depth_colored, cv2.COLOR_BGR2RGB)
+
+        # Display segmentation mask
+        seg_visualization = (seg_mask * 255).astype(np.uint8)
+
+        print("Processing complete!")
+        return [
+            input_image,
+            seg_visualization,
+            gaussian_result,
+            depth_colored,
+            depth_result,
+        ]
+    except Exception as e:
+        print(f"Error processing image: {e}")
+        return [None, None, None, None, None]
+
+# Create Gradio interface
+with gr.Blocks(title="Image Blur Effects with Segmentation and Depth Estimation") as demo:
+    gr.Markdown("# Image Blur Effects App")
+    gr.Markdown("This app demonstrates two types of blur effects: background blur using segmentation and depth-based lens blur.")
+
+    with gr.Row():
+        with gr.Column():
+            input_image = gr.Image(label="Upload an image", type="numpy")
+            blur_sigma = gr.Slider(minimum=1, maximum=50, value=15, step=1, label="Background Blur Intensity")
+            depth_blur_sigma = gr.Slider(minimum=1, maximum=50, value=25, step=1, label="Depth Blur Max Intensity")
+            process_btn = gr.Button("Process Image")
+
+        with gr.Column():
+            with gr.Tab("Original Image"):
+                output_original = gr.Image(label="Original Image")
+            with gr.Tab("Segmentation Mask"):
+                output_segmentation = gr.Image(label="Segmentation Mask")
+            with gr.Tab("Background Blur"):
+                output_gaussian = gr.Image(label="Background Blur Result")
+            with gr.Tab("Depth Map"):
+                output_depth = gr.Image(label="Depth Map")
+            with gr.Tab("Depth-based Lens Blur"):
+                output_depth_blur = gr.Image(label="Depth-based Lens Blur Result")
+
+    process_btn.click(
+        fn=process_image,
+        inputs=[input_image, blur_sigma, depth_blur_sigma],
+        outputs=[output_original, output_segmentation, output_gaussian, output_depth, output_depth_blur],
+    )
+
+    gr.Markdown("""
+    ## How it works
+
+    1. **Background Blur**: Uses a SegFormer model to identify foreground objects (like people) and blurs only the background
+    2. **Depth-based Lens Blur**: Uses a DPT depth estimation model to apply variable blur based on estimated distance
+
+    Try uploading a photo of a person against a background to see the effects!
+    """)
+
+# Guard the launch so the module can be imported without starting the server
+if __name__ == "__main__":
+    demo.launch()
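Both models load on CPU by default. On a machine with a CUDA GPU, inference can be sped up by moving them to the device; a minimal sketch, under the assumption that the tensor inputs inside `get_segmentation_mask` and `get_depth_map` are moved as well (this snippet is illustrative and not part of the committed file):

```python
import torch

# Optional GPU placement; assumes a CUDA device is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
seg_model.to(device)
depth_model.to(device)

# The processor outputs inside get_segmentation_mask / get_depth_map would
# then need moving too, for example:
#   inputs = {k: v.to(device) for k, v in inputs.items()}
```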
requirements.txt ADDED
@@ -0,0 +1,6 @@
+gradio>=3.50.2
+torch>=2.0.0
+transformers>=4.30.0
+pillow>=9.0.0
+numpy>=1.24.0
+opencv-python>=4.7.0
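To try the app locally, install the dependencies with `pip install -r requirements.txt` and run `python app.py` to launch the UI. The pipeline can also be exercised without the interface; a minimal sketch, assuming the file above is saved as `app.py` (the `__main__` guard keeps the import from starting the server) and that a local photo `test.jpg` exists; both file names are illustrative:

```python
# Smoke test: run the full pipeline on one image, bypassing the Gradio UI.
# Importing app loads both models, which can take a while on first run.
import numpy as np
from PIL import Image

from app import process_image

img = np.array(Image.open("test.jpg").convert("RGB"))
original, seg_vis, bg_blur, depth_vis, lens_blur = process_image(
    img, blur_sigma=15, depth_blur_sigma=25
)
Image.fromarray(bg_blur).save("background_blur.jpg")
Image.fromarray(lens_blur).save("lens_blur.jpg")
```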