File size: 5,835 Bytes
7622507
a4f7533
d95d803
cf3d8bd
 
b071b47
 
cf3d8bd
 
10a62a2
cf3d8bd
 
10a62a2
5de99ac
 
 
 
 
6696c1a
0d4e90d
 
cf3d8bd
 
 
 
 
 
0d4e90d
cf3d8bd
 
 
 
 
 
 
 
6696c1a
cf3d8bd
 
 
 
 
 
 
 
 
 
 
 
 
 
028e2a2
cf3d8bd
 
6696c1a
028e2a2
 
cf3d8bd
 
 
6696c1a
cf3d8bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
028e2a2
cf3d8bd
 
 
6696c1a
cf3d8bd
 
 
 
 
 
 
 
 
 
 
 
10a62a2
 
028e2a2
cf3d8bd
 
028e2a2
cf3d8bd
 
 
 
 
 
 
 
 
 
 
 
 
bed7b9c
 
 
 
cf3d8bd
bed7b9c
10a62a2
73a8aca
0a829d8
0369d29
10a62a2
 
 
 
af82372
10a62a2
 
0a829d8
0369d29
af82372
bed7b9c
0369d29
 
bed7b9c
 
0d4e90d
 
 
bed7b9c
0d4e90d
10a62a2
0369d29
 
0a829d8
 
 
 
 
bed7b9c
38038ee
f6af197
7622507
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import streamlit as st
st.set_page_config(page_title="Blur App", page_icon=":camera_with_flash:")

from PIL import Image, ImageFilter
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation, DepthProImageProcessorFast, DepthProForDepthEstimation
import numpy as np

# Pick the inference device once for the whole script. set_page_config above is
# deliberately issued before any other st.* call. The hosted deployment has no
# GPU (see the warning below), so in practice this resolves to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


# --- Page header and explanatory text -------------------------------------
st.title("Blur an image!")
st.subheader("Upload your image and choose a blur style")
st.text(
    "A background blur uses segmentation to completely blur all background objects.\n"
    " A lens blur applies gradual blur using depth estimation."
)
st.divider()

st.warning(
    "**Note**: The lens blur option takes a long time to process (>5min) "
    "since this space isn't linked to a GPU."
)

@st.cache_resource(show_spinner="Pushing pixels...")
def load_gblur_model():
    """Load the BiRefNet foreground-segmentation model used for background blur.

    Wrapped in ``st.cache_resource`` so the model is built once and reused on
    subsequent script reruns instead of being reloaded every time.

    Returns:
        The BiRefNet model, moved to ``device`` and switched to eval mode.
    """
    segmenter = AutoModelForImageSegmentation.from_pretrained(
        'ZhengPeng7/BiRefNet',
        trust_remote_code=True,  # BiRefNet ships its own modelling code on the Hub
    )
    segmenter.to(device)
    segmenter.eval()
    return segmenter

@st.cache_resource(show_spinner="Running with scissors...")
def load_lblur_model():
    """Load Apple's DepthPro depth estimator together with its image processor.

    Cached via ``st.cache_resource`` so reruns reuse the loaded objects.

    Returns:
        A ``(model, image_processor)`` tuple; the model is moved to ``device``.
    """
    checkpoint = "apple/DepthPro-hf"
    processor = DepthProImageProcessorFast.from_pretrained(checkpoint)
    # NOTE(review): no explicit .eval() here, unlike load_gblur_model; this relies
    # on transformers' from_pretrained returning the model in eval mode — confirm.
    depth_model = DepthProForDepthEstimation.from_pretrained(checkpoint)
    depth_model = depth_model.to(device)
    return depth_model, processor

# Load both models eagerly at script start (served from the Streamlit resource
# cache after the first run); the blur functions below read these globals.
gblur_model = load_gblur_model()
(lblur_model, lblur_img_proc) = load_lblur_model()

def gaussian_blur(image, blur_str):
    """Blur the background of *image* while keeping the segmented foreground sharp.

    BiRefNet (``gblur_model``) predicts a per-pixel foreground probability on a
    512x512 resize of the input. That map is resized back to the image size and
    used as a soft alpha matte. The original image is composited over a
    Gaussian-blurred copy of itself, so the foreground comes from the original
    and the background from the blurred copy.

    Args:
        image: PIL image (the caller passes an RGB image).
        blur_str: Gaussian blur radius, in pixels, applied to the background.

    Returns:
        A new PIL image with the same size and mode as *image*.
    """
    transform_image = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        # ImageNet mean / std normalisation.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # Foreground prediction: take the last output, squash logits to [0, 1].
    input_image = transform_image(image).unsqueeze(0).to(device)
    with torch.no_grad():
        preds = gblur_model(input_image)[-1].sigmoid().cpu()
    # ToPILImage converts the float [0, 1] map into an 8-bit single-channel
    # image, which is then resized to match the input.
    mask = transforms.ToPILImage()(preds[0].squeeze()).resize(image.size)

    blurred = image.filter(ImageFilter.GaussianBlur(radius=blur_str))
    # Use the matte as a soft alpha (255 -> original, 0 -> blurred). The
    # alternative of picking the original wherever the mask is non-zero would
    # make any pixel with even a 1/255 foreground score fully sharp, producing
    # hard edges and halos around the subject.
    return Image.composite(image, blurred, mask)

def lens_blur(image, blur_str):
    """Apply a depth-dependent ("lens") blur to *image*.

    DepthPro (``lblur_model`` / ``lblur_img_proc``) estimates a depth map at
    the image's resolution. The map is min-max normalised to [0, 1], so pixels
    at the smallest predicted depth stay sharp and pixels at the largest
    predicted depth get radius *blur_str*. The image is pre-blurred at
    ``num_levels`` evenly spaced radii. Each output pixel is then linearly
    interpolated between the two levels that bracket its normalised depth.

    Args:
        image: PIL image (the caller passes an RGB image).
        blur_str: Maximum Gaussian blur radius in pixels.

    Returns:
        A new uint8 PIL image the same size as *image*.
    """
    # Depth estimation, resized back to the input resolution.
    inputs = lblur_img_proc(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = lblur_model(**inputs)
    post_processed_output = lblur_img_proc.post_process_depth_estimation(
        outputs, target_sizes=[(image.height, image.width)],
    )
    depth = post_processed_output[0]["predicted_depth"].detach().cpu().numpy()
    depth = depth.astype(np.float32)

    # Min-max normalise to [0, 1]. A flat (or NaN-containing) depth map would
    # otherwise divide by zero. The resulting NaNs would then be cast to
    # garbage level indices below, so such a map is treated as all-in-focus.
    d_min, d_max = float(depth.min()), float(depth.max())
    if d_max > d_min:
        depth = (depth - d_min) / (d_max - d_min)
    else:
        depth = np.zeros_like(depth)

    # Number of discrete blur levels and the blur radius of the farthest one.
    num_levels = 15
    max_radius = blur_str

    # Pre-compute one blurred copy of the image per level.
    image_array = np.array(image)
    blurred_images = []
    for i in range(num_levels):
        radius = (i / (num_levels - 1)) * max_radius
        if radius < 0.1:
            # A sub-0.1px radius is visually a no-op; reuse the original pixels.
            blurred_images.append(image_array)
        else:
            blurred_images.append(np.array(image.filter(ImageFilter.GaussianBlur(radius))))
    blurred_stack = np.stack(blurred_images, axis=0)  # (levels, H, W, channels)

    # Linear interpolation along the depth axis between the two pre-blurred
    # levels that bracket each pixel's normalised depth.
    h, w = depth.shape
    y_coords, x_coords = np.indices((h, w))

    depth_levels = depth * (num_levels - 1)
    low_levels = np.floor(depth_levels).astype(np.intp)
    high_levels = np.minimum(low_levels + 1, num_levels - 1)
    alpha = (depth_levels - low_levels)[..., np.newaxis]

    pixel_low = blurred_stack[low_levels, y_coords, x_coords, :].astype(np.float32)
    pixel_high = blurred_stack[high_levels, y_coords, x_coords, :].astype(np.float32)
    output = (1 - alpha) * pixel_low + alpha * pixel_high

    return Image.fromarray(np.clip(output, 0, 255).astype(np.uint8))

left, right = st.columns(2)
with left:
    st.header("Upload your image")
    up_img = st.file_uploader("Upload image", type=["jpg", "jpeg", "png", "dng", "tiff"])

with right:
    st.header("Set blur settings.")
    st.text("Choose a blur style, set the blur strength, and hit 'Apply'!")
    with st.form("Blur_form"):
        options = ["Background", "Lens"]
        selection = st.radio(
            "Choose a blur type:",
            options, index=None,
            captions=[
                "Use segmentation",
                "Use depth estimation"],
            horizontal=True
        )
        blur_level = st.select_slider("Blur Strength", options=["very low","low","medium","high","very high"])
        submitted = st.form_submit_button("Apply Blur", disabled=(up_img is None), type="primary")
st.divider()

# Slider label -> blur radius in pixels passed to gaussian_blur / lens_blur.
blur_levels = {"very low":5, "low":10, "medium":15, "high":20, "very high":30}

# Decode the upload. Pillow raises UnidentifiedImageError (an OSError subclass)
# for formats it has no decoder for, and the uploader also accepts .dng, which
# Pillow may not be able to read. It raises OSError for truncated files and
# DecompressionBombError for oversized ones. Report these to the user instead
# of crashing the script with a traceback.
image = None
if up_img is not None:
    try:
        image = Image.open(up_img).convert("RGB")
    except (OSError, Image.DecompressionBombError):
        st.error("Could not read this image. Please upload a valid JPEG, PNG or TIFF file.")

if image is not None:
    disp_left, disp_right = st.columns(2)
    with disp_left:
        st.image(image, caption="Original Image", width=300)
    with disp_right:
        if submitted and selection in options:
            blur_str = blur_levels[blur_level]
            if selection == "Background":
                with st.spinner("Spinning violently around the y-axis..."):
                    result = gaussian_blur(image, blur_str)
            else:  # selection == "Lens"; `options` has exactly two entries
                with st.spinner("One mississippi, two mississippi..."):
                    result = lens_blur(image, blur_str)
            st.image(result, caption="Blurred Image", width=300)
        else:
            st.write("Waiting for you to select a blur type...")