joeWabbit commited on
Commit
da8d67c
·
verified ·
1 Parent(s): d56029f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -76
app.py CHANGED
@@ -1,92 +1,47 @@
1
  import gradio as gr
2
  import torch
3
- import numpy as np
4
- from transformers import AutoImageProcessor, AutoModelForDepthEstimation
5
  from PIL import Image, ImageFilter
6
 
7
- def load_depth_model():
8
- """
9
- Loads the depth estimation model and processor.
10
- Returns (processor, model, device).
11
- """
12
- global processor, model, device
13
- if "model" not in globals():
14
- processor = AutoImageProcessor.from_pretrained("depth-anything/Depth-Anything-V2")
15
- model = AutoModelForDepthEstimation.from_pretrained("depth-anything/Depth-Anything-V2")
16
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
- model.to(device).eval()
18
- return processor, model, device
19
-
20
- def compute_depth_map(image: Image.Image, scale_factor: float) -> np.ndarray:
21
- """
22
- Computes the depth map for a PIL image.
23
- Inverts the map (i.e. force invert_depth=True) and scales it.
24
- Returns a NumPy array in [0, 1]*scale_factor.
25
- """
26
- processor, model, device = load_depth_model()
27
- inputs = processor(images=image, return_tensors="pt").to(device)
28
- with torch.no_grad():
29
- outputs = model(**inputs)
30
- predicted_depth = outputs.predicted_depth
31
-
32
- prediction = torch.nn.functional.interpolate(
33
- predicted_depth.unsqueeze(1),
34
- size=image.size[::-1], # PIL image size: (width, height)
35
- mode="bicubic",
36
- align_corners=False,
37
- )
38
- depth_min = prediction.min()
39
- depth_max = prediction.max()
40
- depth_vis = (prediction - depth_min) / (depth_max - depth_min + 1e-8)
41
- depth_map = depth_vis.squeeze().cpu().numpy()
42
- # Always invert depth so that near=0 and far=1
43
- depth_map = 1.0 - depth_map
44
- depth_map *= scale_factor
45
- return depth_map
46
-
47
- def layered_blur(image: Image.Image, depth_map: np.ndarray, num_layers: int, max_blur: float) -> Image.Image:
48
- """
49
- Creates multiple blurred versions of 'image' (radii from 0 to max_blur)
50
- and composites them based on the depth map split into num_layers bins.
51
- """
52
- blur_radii = np.linspace(0, max_blur, num_layers)
53
- blur_versions = [image.filter(ImageFilter.GaussianBlur(r)) for r in blur_radii]
54
- upper_bound = depth_map.max()
55
- thresholds = np.linspace(0, upper_bound, num_layers + 1)
56
- final_image = blur_versions[-1]
57
- for i in range(num_layers - 1, -1, -1):
58
- mask_array = np.logical_and(
59
- depth_map >= thresholds[i],
60
- depth_map < thresholds[i + 1]
61
- ).astype(np.uint8) * 255
62
- mask_image = Image.fromarray(mask_array, mode="L")
63
- final_image = Image.composite(blur_versions[i], final_image, mask_image)
64
- return final_image
65
 
66
- def process_depth_blur(uploaded_image, max_blur_value, scale_factor, num_layers):
67
  """
68
- Processes the image with a depth-based blur.
69
- The image is resized to 512x512, its depth is computed (with invert_depth always True),
70
- and a layered blur is applied.
71
  """
72
  if not isinstance(uploaded_image, Image.Image):
73
  uploaded_image = Image.open(uploaded_image)
74
  image = uploaded_image.convert("RGB").resize((512, 512))
75
- depth_map = compute_depth_map(image, scale_factor)
76
- final_image = layered_blur(image, depth_map, int(num_layers), max_blur_value)
 
 
 
 
 
 
 
77
  return final_image
78
 
79
  with gr.Blocks() as demo:
80
- gr.Markdown("# Depth-Based Lens Blur")
81
- depth_img = gr.Image(type="pil", label="Upload Image")
82
- depth_max_blur = gr.Slider(1.0, 5.0, value=3.0, step=0.1, label="Maximum Blur Radius")
83
- depth_scale = gr.Slider(0.1, 1.0, value=0.5, step=0.1, label="Depth Scale Factor")
84
- depth_layers = gr.Slider(2, 20, value=8, step=1, label="Number of Layers")
85
- depth_out = gr.Image(label="Depth-Based Blurred Image")
86
- depth_button = gr.Button("Process Depth Blur")
87
- depth_button.click(process_depth_blur,
88
- inputs=[depth_img, depth_max_blur, depth_scale, depth_layers],
89
- outputs=depth_out)
90
 
91
  if __name__ == "__main__":
92
  demo.launch()
 
1
  import gradio as gr
2
  import torch
 
 
3
  from PIL import Image, ImageFilter
4
 
5
+ def load_segmentation_model():
6
+ """
7
+ Loads and caches the segmentation model from BEN2.
8
+ Ensure you have ben2 installed and accessible in your path.
9
+ """
10
+ global seg_model, seg_device
11
+ if "seg_model" not in globals():
12
+ from ben2 import BEN_Base # Import BEN2
13
+ seg_model = BEN_Base.from_pretrained("PramaLLC/BEN2")
14
+ seg_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+ seg_model.to(seg_device).eval()
16
+ return seg_model, seg_device
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
+ def process_segmentation_blur(uploaded_image, seg_blur_radius: float):
19
  """
20
+ Processes the image with segmentation-based blur.
21
+ The image is resized to 512x512. A Gaussian blur with the specified radius is applied,
22
+ then the segmentation mask is computed to composite the sharp foreground over the blurred background.
23
  """
24
  if not isinstance(uploaded_image, Image.Image):
25
  uploaded_image = Image.open(uploaded_image)
26
  image = uploaded_image.convert("RGB").resize((512, 512))
27
+ seg_model, seg_device = load_segmentation_model()
28
+ blurred_image = image.filter(ImageFilter.GaussianBlur(seg_blur_radius))
29
+
30
+ # Generate segmentation mask (foreground)
31
+ foreground = seg_model.inference(image, refine_foreground=False)
32
+ foreground_rgba = foreground.convert("RGBA")
33
+ _, _, _, alpha = foreground_rgba.split()
34
+ binary_mask = alpha.point(lambda x: 255 if x > 128 else 0, mode="L")
35
+ final_image = Image.composite(image, blurred_image, binary_mask)
36
  return final_image
37
 
38
  with gr.Blocks() as demo:
39
+ gr.Markdown("# Segmentation-Based Blur using BEN2")
40
+ seg_img = gr.Image(type="pil", label="Upload Image")
41
+ seg_blur = gr.Slider(5, 30, value=15, step=1, label="Segmentation Blur Radius")
42
+ seg_out = gr.Image(label="Segmentation-Based Blurred Image")
43
+ seg_button = gr.Button("Process Segmentation Blur")
44
+ seg_button.click(process_segmentation_blur, inputs=[seg_img, seg_blur], outputs=seg_out)
 
 
 
 
45
 
46
  if __name__ == "__main__":
47
  demo.launch()