File size: 4,217 Bytes
4200d56
 
f86bc23
411ded6
4200d56
 
f86bc23
4200d56
da8d67c
4200d56
 
da8d67c
4200d56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da8d67c
 
 
 
 
 
 
f86bc23
da8d67c
4200d56
411ded6
 
 
da8d67c
 
 
 
 
 
 
 
 
411ded6
f86bc23
4200d56
411ded6
4200d56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f86bc23
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from transformers import pipeline
from PIL import Image, ImageFilter
import gradio as gr
import torch
import numpy as np
# Depth-estimation pipeline built once at import time (module-level side
# effect) from the Depth-Anything-V2-Small checkpoint; it is reused by
# compute_depth_map_pipeline for every processed image.
depth_pipe = pipeline(task="depth-estimation", model="depth-anything/Depth-Anything-V2-Small-hf")

def compute_depth_map_pipeline(image: Image.Image, scale_factor: float) -> np.ndarray:
    """
    Computes a depth map using the HF pipeline.

    The pipeline's "depth" output is first min-max normalised to [0, 1],
    then inverted (so near=0 and far=1) and finally multiplied by
    ``scale_factor``.

    Args:
        image: Image handed to the module-level ``depth_pipe``.
        scale_factor: Multiplier applied to the inverted, normalised depth.

    Returns:
        A float array with values in ``[0, scale_factor]``.
    """
    result = depth_pipe(image)
    # A single-image call may hand back one result dict instead of a list of
    # them, so accept both shapes rather than indexing blindly.
    if isinstance(result, (list, tuple)):
        result = result[0]

    depth_map = np.asarray(result["depth"], dtype=np.float64)

    # The raw values are not guaranteed to be in [0, 1] (an 8-bit image spans
    # 0..255), so "1.0 - depth" alone would yield mostly negative numbers.
    # Normalise first so that the inversion really produces near=0 / far=1.
    lowest = float(depth_map.min())
    span = float(depth_map.max()) - lowest
    if span > 0:
        depth_map = (depth_map - lowest) / span
    else:
        # A flat map carries no depth ordering; treat every pixel as nearest.
        depth_map = np.ones_like(depth_map)

    depth_map = 1.0 - depth_map
    depth_map *= scale_factor
    return depth_map

def layered_blur(image: Image.Image, depth_map: np.ndarray, num_layers: int, max_blur: float) -> Image.Image:
    """
    Blend progressively blurred copies of ``image`` according to ``depth_map``.

    The interval from 0 to ``depth_map.max()`` is cut into ``num_layers``
    equal bands. Band ``i`` is painted with the ``i``-th of ``num_layers``
    Gaussian blur radii spaced evenly between 0 and ``max_blur``. Pixels that
    land in no band (values below 0, or exactly at the maximum because each
    band's upper edge is exclusive) keep the strongest blur.

    Args:
        image: Source image to blur.
        depth_map: Per-pixel depth values used to pick a band.
        num_layers: Number of depth bands / blur levels.
        max_blur: Blur radius used for the last band.

    Returns:
        The composited image.
    """
    radii = np.linspace(0, max_blur, num_layers)
    layers = [image.filter(ImageFilter.GaussianBlur(radius)) for radius in radii]
    edges = np.linspace(0, depth_map.max(), num_layers + 1)

    # Begin with the most blurred copy and paste each band over it, walking
    # from the highest band down to the lowest.
    composite = layers[-1]
    for band in reversed(range(num_layers)):
        lower, upper = edges[band], edges[band + 1]
        in_band = np.logical_and(depth_map >= lower, depth_map < upper)
        # 255 selects the band's layer, 0 keeps what has been composited so far.
        mask = Image.fromarray(in_band.astype(np.uint8) * 255, mode="L")
        composite = Image.composite(layers[band], composite, mask)
    return composite

def process_depth_blur_pipeline(uploaded_image, max_blur_value, scale_factor, num_layers):
    """
    Apply depth-dependent ("lens") blur to an uploaded image.

    Args:
        uploaded_image: A PIL image, or anything ``Image.open`` accepts.
        max_blur_value: Blur radius used for the last depth layer.
        scale_factor: Multiplier forwarded to ``compute_depth_map_pipeline``.
        num_layers: Number of depth layers; converted with ``int()``.

    Returns:
        The blurred image, 512x512 pixels in RGB mode.

    Raises:
        gr.Error: If no image was supplied.
    """
    # Gradio passes None when the button is pressed without an upload; fail
    # with a readable message instead of an opaque error from Image.open.
    if uploaded_image is None:
        raise gr.Error("Please upload an image first.")
    if not isinstance(uploaded_image, Image.Image):
        uploaded_image = Image.open(uploaded_image)

    # Work on a fixed-size RGB copy.
    image = uploaded_image.convert("RGB").resize((512, 512))
    depth_map = compute_depth_map_pipeline(image, scale_factor)
    return layered_blur(image, depth_map, int(num_layers), max_blur_value)

# --- Segmentation-Based Blur using BEN2 ---
def load_segmentation_model():
    """
    Return the BEN2 segmentation model and its device, loading on first use.

    The model and device are cached in the module globals ``seg_model`` and
    ``seg_device``; later calls return the cached pair without reloading.

    Returns:
        A ``(seg_model, seg_device)`` tuple.
    """
    global seg_model, seg_device
    module_state = globals()
    if "seg_model" not in module_state or "seg_device" not in module_state:
        from ben2 import BEN_Base  # Imported lazily: only needed for this tab

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = BEN_Base.from_pretrained("PramaLLC/BEN2")
        model.to(device).eval()
        # Publish the globals only once setup has fully succeeded, so that a
        # failure above cannot leave a half-initialised model cached for
        # subsequent calls.
        seg_model, seg_device = model, device
    return seg_model, seg_device

def process_segmentation_blur(uploaded_image, seg_blur_radius: float):
    """
    Blur the background of an image while keeping the foreground sharp.

    The foreground is taken from the alpha channel of the BEN2 model's
    output, thresholded into a binary mask.

    Args:
        uploaded_image: A PIL image, or anything ``Image.open`` accepts.
        seg_blur_radius: Gaussian blur radius applied to the background.

    Returns:
        The composited image, 512x512 pixels in RGB mode.

    Raises:
        gr.Error: If no image was supplied.
    """
    # Gradio passes None when the button is pressed without an upload; fail
    # with a readable message instead of an opaque error from Image.open.
    if uploaded_image is None:
        raise gr.Error("Please upload an image first.")
    if not isinstance(uploaded_image, Image.Image):
        uploaded_image = Image.open(uploaded_image)

    image = uploaded_image.convert("RGB").resize((512, 512))
    seg_model, seg_device = load_segmentation_model()
    blurred_image = image.filter(ImageFilter.GaussianBlur(seg_blur_radius))

    # Generate segmentation mask (foreground)
    foreground = seg_model.inference(image, refine_foreground=False)
    alpha = foreground.convert("RGBA").split()[3]
    # Alpha above 128 counts as foreground (255); everything else is background.
    binary_mask = alpha.point(lambda x: 255 if x > 128 else 0, mode="L")

    # Sharp pixels where the mask is set, blurred pixels elsewhere.
    return Image.composite(image, blurred_image, binary_mask)

# --- Merged Gradio Interface ---
# Two-tab UI: one tab for depth-based lens blur, one for segmentation-based
# background blur. Each tab wires its inputs to the matching process_* function.
with gr.Blocks() as demo:
    gr.Markdown("# Lens Blur & Gaussian Blur")
    with gr.Tabs():
        with gr.Tab("Lens Blur"):
            depth_img = gr.Image(type="pil", label="Upload Image")
            depth_max_blur = gr.Slider(1.0, 5.0, value=3.0, step=0.1, label="Maximum Blur Radius")
            depth_scale = gr.Slider(0.1, 1.0, value=0.5, step=0.1, label="Depth Scale Factor")
            depth_layers = gr.Slider(2, 20, value=8, step=1, label="Number of Layers")
            depth_out = gr.Image(label="Lens Blurred Image")
            depth_button = gr.Button("Process Lens Blur")
            depth_button.click(process_depth_blur_pipeline,
                               inputs=[depth_img, depth_max_blur, depth_scale, depth_layers],
                               outputs=depth_out)
        # Tab label previously read "Guassian Blur" (typo).
        with gr.Tab("Gaussian Blur"):
            seg_img = gr.Image(type="pil", label="Upload Image")
            seg_blur = gr.Slider(5, 30, value=15, step=1, label="Segmentation Blur Radius")
            seg_out = gr.Image(label="Gaussian Blurred Image")
            seg_button = gr.Button("Gaussian Blur")
            seg_button.click(process_segmentation_blur, inputs=[seg_img, seg_blur], outputs=seg_out)

# Start the Gradio app only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    demo.launch()