pavank007 committed
Commit bc24544 · verified · 1 Parent(s): ecb0ce5

Update app.py

Files changed (1)
  1. app.py +301 -0
app.py CHANGED
@@ -234,7 +234,308 @@ with gr.Blocks(title="Image Blur Effects with Segmentation and Depth Estimation"
 
     gr.Markdown("""
     ## How it works
+import gradio as gr
+import torch
+import numpy as np
+import cv2
+from PIL import Image
+from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation
+from transformers import DPTImageProcessor, DPTForDepthEstimation
+import warnings
+warnings.filterwarnings("ignore")
+
+# Load segmentation model - using SegFormer which is compatible with AutoModelForSemanticSegmentation
+seg_processor = AutoImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
+seg_model = AutoModelForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
+
+# Load depth estimation model
+depth_processor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
+depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
+
+def safe_resize(image, target_size, interpolation=cv2.INTER_LINEAR):
+    """Safely resize an image with validation checks."""
+    if image is None:
+        return None
+
+    # Ensure image is a proper numpy array
+    if not isinstance(image, np.ndarray):
+        return None
+
+    # Check that dimensions are valid (non-zero)
+    h, w = target_size
+    if h <= 0 or w <= 0 or image.shape[0] <= 0 or image.shape[1] <= 0:
+        return image  # Return original if target dimensions are invalid
+
+    # Handle grayscale images differently
+    if len(image.shape) == 2:
+        return cv2.resize(image, (w, h), interpolation=interpolation)
+    else:
+        return cv2.resize(image, (w, h), interpolation=interpolation)
+
+def apply_gaussian_blur(image, mask, sigma=15):
+    """Apply Gaussian blur to the background of an image based on a mask."""
+    try:
+        # Convert mask to binary (0 and 255)
+        if mask.max() <= 1.0:
+            binary_mask = (mask * 255).astype(np.uint8)
+        else:
+            binary_mask = mask.astype(np.uint8)
+
+        # Create a blurred version of the entire image
+        blurred = cv2.GaussianBlur(image, (0, 0), sigma)
+
+        # Resize mask to match image dimensions if needed
+        if binary_mask.shape[:2] != image.shape[:2]:
+            binary_mask = safe_resize(binary_mask, (image.shape[0], image.shape[1]))
+
+        # Create a 3-channel mask if the input mask is single-channel
+        if len(binary_mask.shape) == 2:
+            mask_3ch = np.stack([binary_mask, binary_mask, binary_mask], axis=2)
+        else:
+            mask_3ch = binary_mask
+
+        # Normalize mask to range [0, 1]
+        mask_3ch = mask_3ch / 255.0
+
+        # Combine original image (foreground) with blurred image (background) using the mask
+        result = image * mask_3ch + blurred * (1 - mask_3ch)
+
+        return result.astype(np.uint8)
+    except Exception as e:
+        print(f"Error in apply_gaussian_blur: {e}")
+        return image  # Return original image if there's an error
+
+def apply_depth_blur(image, depth_map, max_sigma=25):
+    """Apply variable Gaussian blur based on depth map."""
+    try:
+        # Normalize depth map to range [0, 1]
+        if depth_map.max() > 1.0:
+            depth_norm = depth_map / depth_map.max()
+        else:
+            depth_norm = depth_map
+
+        # Resize depth map to match image dimensions if needed
+        if depth_norm.shape[:2] != image.shape[:2]:
+            depth_norm = safe_resize(depth_norm, (image.shape[0], image.shape[1]))
+
+        # Create output image
+        result = np.zeros_like(image)
+
+        # Instead of many small blurs, use fewer blur levels for efficiency
+        blur_levels = 5
+        step = max_sigma / blur_levels
+
+        for i in range(blur_levels):
+            sigma = (i + 1) * step
+
+            # Calculate depth range for this blur level
+            lower_bound = i / blur_levels
+            upper_bound = (i + 1) / blur_levels
+
+            # Create mask for pixels in this depth range
+            mask = np.logical_and(depth_norm >= lower_bound, depth_norm <= upper_bound).astype(np.float32)
+
+            # Skip if no pixels in this range
+            if not np.any(mask):
+                continue
+
+            # Apply blur for this level
+            blurred = cv2.GaussianBlur(image, (0, 0), sigma)
+
+            # Create 3-channel mask
+            mask_3ch = np.stack([mask, mask, mask], axis=2) if len(mask.shape) == 2 else mask
+
+            # Add to result
+            result += (blurred * mask_3ch).astype(np.uint8)
+
+        # Check if there are any pixels not covered and fill with original
+        total_mask = np.zeros_like(depth_norm)
+        for i in range(blur_levels):
+            lower_bound = i / blur_levels
+            upper_bound = (i + 1) / blur_levels
+            mask = np.logical_and(depth_norm >= lower_bound, depth_norm <= upper_bound).astype(np.float32)
+            total_mask += mask
+
+        missing_mask = (total_mask < 0.5).astype(np.float32)
+        if np.any(missing_mask):
+            missing_mask_3ch = np.stack([missing_mask, missing_mask, missing_mask], axis=2)
+            result += (image * missing_mask_3ch).astype(np.uint8)
+
+        return result
+    except Exception as e:
+        print(f"Error in apply_depth_blur: {e}")
+        return image  # Return original image if there's an error
+
+def get_segmentation_mask(image_pil):
+    """Get segmentation mask for person/foreground from an image."""
+    try:
+        # Process the image with the segmentation model
+        inputs = seg_processor(images=image_pil, return_tensors="pt")
+        with torch.no_grad():
+            outputs = seg_model(**inputs)
+
+        # Get the predicted segmentation mask
+        logits = outputs.logits
+        upsampled_logits = torch.nn.functional.interpolate(
+            logits,
+            size=image_pil.size[::-1],  # Resize directly to original size
+            mode="bilinear",
+            align_corners=False,
+        )
+
+        # Get the predicted class for each pixel
+        predicted_mask = upsampled_logits.argmax(dim=1)[0]
+
+        # Convert the mask to a numpy array
+        mask_np = predicted_mask.cpu().numpy()
+
+        # Create a foreground mask - human and common foreground objects
+        # Classes based on ADE20K dataset
+        foreground_classes = [12]  # Person class (you can add more classes as needed)
+
+        # Create a binary mask for foreground classes
+        foreground_mask = np.zeros_like(mask_np)
+        for cls in foreground_classes:
+            foreground_mask[mask_np == cls] = 1
+
+        return foreground_mask
+    except Exception as e:
+        print(f"Error in get_segmentation_mask: {e}")
+        # Return a default mask (all ones) in case of error
+        return np.ones((image_pil.size[1], image_pil.size[0]), dtype=np.uint8)
+
+def get_depth_map(image_pil):
+    """Get depth map from an image."""
+    try:
+        # Process the image with the depth estimation model
+        inputs = depth_processor(images=image_pil, return_tensors="pt")
+        with torch.no_grad():
+            outputs = depth_model(**inputs)
+            predicted_depth = outputs.predicted_depth
+
+        # Interpolate to original size
+        prediction = torch.nn.functional.interpolate(
+            predicted_depth.unsqueeze(1),
+            size=image_pil.size[::-1],
+            mode="bicubic",
+            align_corners=False,
+        )
+
+        # Convert to numpy array
+        depth_map = prediction.squeeze().cpu().numpy()
+
+        # Normalize depth map
+        depth_min = depth_map.min()
+        depth_max = depth_map.max()
+        if depth_max > depth_min:
+            depth_map = (depth_map - depth_min) / (depth_max - depth_min)
+        else:
+            depth_map = np.zeros_like(depth_map)
+
+        return depth_map
+    except Exception as e:
+        print(f"Error in get_depth_map: {e}")
+        # Return a default depth map (gradient from top to bottom) in case of error
+        h, w = image_pil.size[1], image_pil.size[0]
+        default_depth = np.zeros((h, w), dtype=np.float32)
+        for i in range(h):
+            default_depth[i, :] = i / h
+        return default_depth
+
+def process_image(input_image, blur_sigma=15, depth_blur_sigma=25):
+    """Main function to process the input image."""
+    try:
+        # Input validation
+        if input_image is None:
+            print("No input image provided")
+            return [None, None, None, None, None]
+
+        # Convert to PIL Image if needed
+        if isinstance(input_image, np.ndarray):
+            # Make sure we have a valid image with at least 2 dimensions
+            if input_image.ndim < 2 or input_image.shape[0] <= 0 or input_image.shape[1] <= 0:
+                print("Invalid input image dimensions")
+                return [None, None, None, None, None]
+            pil_image = Image.fromarray(input_image)
+        else:
+            pil_image = input_image
+            input_image = np.array(pil_image)
+
+        # Get segmentation mask
+        print("Getting segmentation mask...")
+        seg_mask = get_segmentation_mask(pil_image)
+
+        # Get depth map
+        print("Getting depth map...")
+        depth_map = get_depth_map(pil_image)
+
+        # Apply gaussian blur to background
+        print("Applying gaussian blur...")
+        gaussian_result = apply_gaussian_blur(input_image, seg_mask, sigma=blur_sigma)
+
+        # Apply depth-based blur
+        print("Applying depth-based blur...")
+        depth_result = apply_depth_blur(input_image, depth_map, max_sigma=depth_blur_sigma)
+
+        # Display depth map as an image
+        depth_visualization = (depth_map * 255).astype(np.uint8)
+        depth_colored = cv2.applyColorMap(depth_visualization, cv2.COLORMAP_INFERNO)
+
+        # Display segmentation mask
+        seg_visualization = (seg_mask * 255).astype(np.uint8)
+
+        print("Processing complete!")
+        return [
+            input_image,
+            seg_visualization,
+            gaussian_result,
+            depth_colored,
+            depth_result
+        ]
+    except Exception as e:
+        print(f"Error processing image: {e}")
+        return [None, None, None, None, None]
+
+# Create Gradio interface
+with gr.Blocks(title="Image Blur Effects with Segmentation and Depth Estimation") as demo:
+    gr.Markdown("# Image Blur Effects App")
+    gr.Markdown("This app demonstrates two types of blur effects: background blur using segmentation and depth-based lens blur.")
+
+    with gr.Row():
+        with gr.Column():
+            input_image = gr.Image(label="Upload an image", type="numpy")
+            blur_sigma = gr.Slider(minimum=1, maximum=50, value=15, step=1, label="Background Blur Intensity")
+            depth_blur_sigma = gr.Slider(minimum=1, maximum=50, value=25, step=1, label="Depth Blur Max Intensity")
+            process_btn = gr.Button("Process Image")
+
+        with gr.Column():
+            with gr.Tab("Original Image"):
+                output_original = gr.Image(label="Original Image")
+            with gr.Tab("Segmentation Mask"):
+                output_segmentation = gr.Image(label="Segmentation Mask")
+            with gr.Tab("Background Blur"):
+                output_gaussian = gr.Image(label="Background Blur Result")
+            with gr.Tab("Depth Map"):
+                output_depth = gr.Image(label="Depth Map")
+            with gr.Tab("Depth-based Lens Blur"):
+                output_depth_blur = gr.Image(label="Depth-based Lens Blur Result")
+
+    process_btn.click(
+        fn=process_image,
+        inputs=[input_image, blur_sigma, depth_blur_sigma],
+        outputs=[output_original, output_segmentation, output_gaussian, output_depth, output_depth_blur]
+    )
 
+    gr.Markdown("""
+    ## How it works
+
+    1. **Background Blur**: Uses a SegFormer model to identify foreground objects (like people) and blurs only the background
+    2. **Depth-based Lens Blur**: Uses a DPT depth estimation model to apply variable blur based on estimated distance
+
+    Try uploading a photo of a person against a background to see the effects!
+    """)
+
+demo.launch()
     1. **Background Blur**: Uses a segmentation model to identify foreground objects and blurs only the background
     2. **Depth-based Lens Blur**: Uses a depth estimation model to apply variable blur based on estimated distance
 
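
For reviewers, the compositing rule at the heart of the background-blur path (`result = image * mask_3ch + blurred * (1 - mask_3ch)` in `apply_gaussian_blur`) can be exercised in isolation. The sketch below is not part of the commit: it assumes OpenCV and NumPy are installed, uses a hypothetical local test image `photo.jpg`, and substitutes a synthetic circular mask for the SegFormer person mask.

```python
import cv2
import numpy as np

# Hypothetical standalone demo of mask-based compositing (not from the commit).
image = cv2.imread("photo.jpg")  # assumption: a local test image exists
if image is None:
    raise FileNotFoundError("photo.jpg not found")

# Blurred copy of the whole frame, same call shape as in apply_gaussian_blur
blurred = cv2.GaussianBlur(image, (0, 0), 15)

# Synthetic foreground mask: a filled circle standing in for the person mask
mask = np.zeros(image.shape[:2], dtype=np.float32)
center = (image.shape[1] // 2, image.shape[0] // 2)
cv2.circle(mask, center, min(image.shape[:2]) // 3, 1.0, -1)
mask_3ch = np.dstack([mask, mask, mask])

# Foreground pixels come from the original, background pixels from the blurred copy
result = (image * mask_3ch + blurred * (1.0 - mask_3ch)).astype(np.uint8)
cv2.imwrite("background_blur_demo.jpg", result)
```

Swapping the synthetic circle for the mask returned by `get_segmentation_mask` should give essentially the same output as the app's "Background Blur" tab.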