yrosenbloom commited on
Commit
ac60056
Β·
verified Β·
1 Parent(s): 8856f36

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -50
app.py CHANGED
@@ -1,78 +1,76 @@
1
  import gradio as gr
2
- from PIL import Image, ImageFilter, ImageOps
3
  import numpy as np
4
  import torch
5
- from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation, DPTFeatureExtractor, DPTForDepthEstimation
 
 
 
6
  import cv2
 
7
 
8
- # Load segmentation model
 
9
  seg_model_name = "nvidia/segformer-b1-finetuned-ade-512-512"
10
- seg_feature_extractor = SegformerFeatureExtractor.from_pretrained(seg_model_name)
11
  seg_model = SegformerForSemanticSegmentation.from_pretrained(seg_model_name)
12
 
13
- # Load depth estimation model
14
  depth_model_name = "Intel/dpt-hybrid-midas"
15
- depth_feature_extractor = DPTFeatureExtractor.from_pretrained(depth_model_name)
16
  depth_model = DPTForDepthEstimation.from_pretrained(depth_model_name)
17
 
18
- # Device configuration
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
  seg_model.to(device)
21
  depth_model.to(device)
22
 
23
- def process_image(image):
24
- # Ensure image is in RGB format and resize
25
- image = image.resize((512, 512))
26
-
27
- # Perform segmentation
28
- inputs = seg_feature_extractor(images=image, return_tensors="pt").to(device)
29
  with torch.no_grad():
30
- outputs = seg_model(**inputs)
31
- logits = outputs.logits
32
- segmentation = torch.argmax(logits, dim=1)[0].cpu().numpy()
33
- binary_mask = np.where(segmentation > 0, 255, 0).astype(np.uint8)
34
 
35
- # Apply Gaussian Blur to the background
36
- blurred_background = image.filter(ImageFilter.GaussianBlur(15))
37
- foreground = Image.fromarray(binary_mask).convert("L").resize(image.size)
38
- output_blur = Image.composite(image, blurred_background, foreground)
39
 
40
- # Depth estimation for lens blur
41
- depth_inputs = depth_feature_extractor(images=image, return_tensors="pt").to(device)
42
  with torch.no_grad():
43
- depth_outputs = depth_model(**depth_inputs)
44
- predicted_depth = depth_outputs.predicted_depth.squeeze().cpu().numpy()
45
-
46
- # Normalize depth map
47
- depth_min, depth_max = predicted_depth.min(), predicted_depth.max()
48
- normalized_depth = (predicted_depth - depth_min) / (depth_max - depth_min)
49
- normalized_depth_resized = cv2.resize(normalized_depth, (512, 512))
50
 
51
- # Lens blur using depth map
52
- blurred_image = np.array(image).astype(np.float32)
53
- blur_intensity = normalized_depth_resized * 20
54
- for y in range(image.size[1]):
55
- for x in range(image.size[0]):
56
- sigma = blur_intensity[y, x]
57
- kernel_size = int(2 * sigma + 1)
58
- if kernel_size > 1:
59
- patch = image.crop((x - kernel_size//2, y - kernel_size//2, x + kernel_size//2 + 1, y + kernel_size//2 + 1))
60
- patch = patch.filter(ImageFilter.GaussianBlur(sigma))
61
- blurred_image[y, x, :] = np.array(patch)[kernel_size//2, kernel_size//2, :]
62
- lens_blur_image = Image.fromarray(np.clip(blurred_image, 0, 255).astype(np.uint8))
63
 
64
- return image, output_blur, lens_blur_image
65
 
66
  iface = gr.Interface(
67
  fn=process_image,
68
- inputs=gr.Image(type="pil", label="Upload an Image"),
69
  outputs=[
70
- gr.Image(label="Original Image"),
71
- gr.Image(label="Gaussian Blur Effect"),
72
- gr.Image(label="Depth-Based Lens Blur Effect")
73
  ],
74
- title="Image Blurring with Gaussian and Depth-Based Lens Blur",
75
- description="Upload an image to see Gaussian blur and depth-based lens blur effects."
76
  )
77
 
78
- iface.launch()
 
 
1
  import gradio as gr
2
+ from PIL import Image, ImageFilter
3
  import numpy as np
4
  import torch
5
+ from transformers import (
6
+ SegformerFeatureExtractor, SegformerForSemanticSegmentation,
7
+ DPTFeatureExtractor, DPTForDepthEstimation
8
+ )
9
  import cv2
10
+ import os, json
11
 
12
+ # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
13
+ # load segmentation model
14
  seg_model_name = "nvidia/segformer-b1-finetuned-ade-512-512"
15
+ seg_fe = SegformerFeatureExtractor.from_pretrained(seg_model_name)
16
  seg_model = SegformerForSemanticSegmentation.from_pretrained(seg_model_name)
17
 
18
+ # load depth model
19
  depth_model_name = "Intel/dpt-hybrid-midas"
20
+ depth_fe = DPTFeatureExtractor.from_pretrained(depth_model_name)
21
  depth_model = DPTForDepthEstimation.from_pretrained(depth_model_name)
22
 
 
23
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
  seg_model.to(device)
25
  depth_model.to(device)
26
 
27
+ def process_image(image: Image.Image):
28
+ # 1) prep
29
+ image = image.convert("RGB").resize((512,512))
30
+
31
+ # 2) segmentation β†’ binary mask
32
+ seg_inputs = seg_fe(images=image, return_tensors="pt").to(device)
33
  with torch.no_grad():
34
+ seg_logits = seg_model(**seg_inputs).logits
35
+ seg_map = torch.argmax(seg_logits, dim=1)[0].cpu().numpy()
36
+ mask = (seg_map > 0).astype(np.uint8) * 255
37
+ mask = Image.fromarray(mask).resize((512,512))
38
 
39
+ # 3) gaussian-blur background
40
+ bg_blur = image.filter(ImageFilter.GaussianBlur(15))
41
+ output_blur = Image.composite(image, bg_blur, mask)
 
42
 
43
+ # 4) depth estimation
44
+ depth_inputs = depth_fe(images=image, return_tensors="pt").to(device)
45
  with torch.no_grad():
46
+ depth_pred = depth_model(**depth_inputs).predicted_depth.squeeze().cpu().numpy()
47
+ # normalize & resize
48
+ dmin, dmax = depth_pred.min(), depth_pred.max()
49
+ depth_norm = (depth_pred - dmin) / (dmax - dmin + 1e-8)
50
+ depth_norm = cv2.resize(depth_norm, (512,512))
 
 
51
 
52
+ # 5) vectorized depth-based blur
53
+ img_np = np.array(image).astype(np.float32)
54
+ # two extremes
55
+ near_blur = cv2.GaussianBlur(img_np, (21,21), 5)
56
+ far_blur = cv2.GaussianBlur(img_np, (81,81), 20)
57
+ alpha = depth_norm[...,None]
58
+ combined = near_blur * (1 - alpha) + far_blur * alpha
59
+ lens_blur = Image.fromarray(np.clip(combined,0,255).astype(np.uint8))
 
 
 
 
60
 
61
+ return image, output_blur, lens_blur
62
 
63
  iface = gr.Interface(
64
  fn=process_image,
65
+ inputs=gr.Image(type="pil", label="Upload Image"),
66
  outputs=[
67
+ gr.Image(type="pil", label="Original"),
68
+ gr.Image(type="pil", label="Gaussian Blur"),
69
+ gr.Image(type="pil", label="Depth-Based Lens Blur"),
70
  ],
71
+ title="Image Blurring with CLAHE + Depth-Based Blur",
72
+ description="Upload a selfie to see background blur and depth-based lens blur."
73
  )
74
 
75
+ if __name__ == "__main__":
76
+ iface.launch(share=True)