File size: 5,501 Bytes
f15e528
 
f1c3b2e
 
 
 
6ae5574
1b6d3df
52345c4
f15e528
0c77d36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a36420f
0c77d36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f15e528
228c14e
 
 
 
f15e528
 
 
 
 
 
 
 
4e2270c
900327c
f15e528
35d4033
 
 
0c77d36
b6d6a39
 
0c77d36
35d4033
23dbd51
228c14e
0c77d36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e272d7
0c77d36
 
 
 
 
 
e349594
 
0c77d36
 
 
 
35d4033
 
d18b91f
0c77d36
 
 
 
 
 
 
 
35d4033
0c77d36
35d4033
d18b91f
35d4033
 
 
d18b91f
f4c0f63
35d4033
0c77d36
 
d18b91f
35d4033
f15e528
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import streamlit as st
from PIL import Image, ImageFilter
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
from transformers import pipeline
import numpy as np
import os

def depth_based_blur(orig_image: Image.Image, depth_map: Image.Image, max_blur: float = 15,
                     num_bands: int = 10, invert_depth: bool = True) -> Image.Image:
    """
    Apply a depth-dependent blur to ``orig_image`` using ``depth_map``.

    The depth map is min-max normalised to [0, 1] (then flipped to ``1 - d`` when
    ``invert_depth`` is true) and split into ``num_bands`` equal-width bands.
    Each pixel is taken from a Gaussian-blurred copy of the image whose radius is
    ``(1 - mid) * max_blur``, where ``mid`` is the centre of the pixel's band. High
    normalised values therefore stay (nearly) sharp, and low values get close to
    ``max_blur``.

    Args:
      orig_image: Image to blur. Converted to RGBA internally.
      depth_map: Depth image. Multi-band images (e.g. a grey map saved as RGB)
        are collapsed to one luminance channel. The map is resized (bilinear)
        to ``orig_image.size`` when the sizes differ.
      max_blur: Largest Gaussian blur radius, approached by the lowest
        normalised values.
      num_bands: Number of discrete blur levels. Each non-empty band costs one
        full-image Gaussian blur.
      invert_depth: Use True when larger depth values mean farther away, so that
        farther pixels get more blur. Use False when larger values mean closer
        (e.g. disparity-style maps).

    Returns:
      PIL.Image.Image: The final image with background (farther areas) blurred,
      in RGB mode (any alpha channel is dropped).
    """
    if len(depth_map.getbands()) > 1:
        # Otherwise the HxWxC array would be misread as an HxW "L" mask below.
        depth_map = depth_map.convert("L")

    depth_array = np.asarray(depth_map, dtype=np.float32)
    if depth_array.shape != (orig_image.height, orig_image.width):
        # A mismatched map would make the per-band paste fail. Resize through a
        # float ("F") image so high-bit-depth maps keep their precision.
        depth_array = np.asarray(
            Image.fromarray(depth_array).resize(orig_image.size, resample=Image.BILINEAR),
            dtype=np.float32,
        )

    # Normalise to [0, 1]. The epsilon avoids division by zero for a constant map.
    d_min, d_max = depth_array.min(), depth_array.max()
    depth_norm = (depth_array - d_min) / (d_max - d_min + 1e-8)
    if invert_depth:
        depth_norm = 1.0 - depth_norm

    orig_rgba = orig_image.convert("RGBA")
    final_image = orig_rgba.copy()

    band_edges = np.linspace(0, 1, num_bands + 1)
    last_band = num_bands - 1
    for i in range(num_bands):
        band_min, band_max = band_edges[i], band_edges[i + 1]
        # Bands are half-open [min, max), except the last one, which is closed.
        # This way pixels at exactly 1.0 are not left out of every band.
        below_max = depth_norm <= band_max if i == last_band else depth_norm < band_max
        in_band = (depth_norm >= band_min) & below_max
        if not in_band.any():
            continue  # no pixel here: skip the (expensive) full-image blur

        # Lower normalised values (the blurrier end) get a larger radius.
        mid = (band_min + band_max) / 2.0
        blur_radius = (1 - mid) * max_blur
        blurred_version = orig_rgba.filter(ImageFilter.GaussianBlur(blur_radius))

        # A 2-D uint8 array maps to an "L" image; 255 selects the blurred pixels.
        band_mask = Image.fromarray(in_band.astype(np.uint8) * 255)
        final_image.paste(blurred_version, mask=band_mask)

    return final_image.convert("RGB")

# Hugging Face model ids used by the demo.
_SEGMENTATION_MODEL_ID = "briaai/RMBG-2.0"
_DEPTH_MODEL_ID = "depth-anything/Depth-Anything-V2-Small-hf"


@st.cache_resource
def _load_segmentation_model(hf_token: str, device: str):
    """Load the RMBG-2.0 segmentation model once per server process, in eval mode.

    Streamlit re-executes the script on every widget interaction. Without this
    cache, every slider move would re-instantiate the model from scratch.
    """
    model = AutoModelForImageSegmentation.from_pretrained(
        _SEGMENTATION_MODEL_ID, trust_remote_code=True, token=hf_token
    )
    model.to(device)
    model.eval()
    return model


@st.cache_resource
def _load_depth_pipeline(device: str):
    """Load the Depth-Anything-V2 (small) depth-estimation pipeline once per process."""
    return pipeline("depth-estimation", model=_DEPTH_MODEL_ID, device=device)


def _predict_subject_mask(model, image: Image.Image, device: str,
                          image_size=(512, 512), threshold: float = 0.5) -> Image.Image:
    """Return an "L" mask of ``image.size``: 255 on the predicted subject, 0 elsewhere.

    The model runs at ``image_size``. The binary prediction is resized back with
    bilinear resampling, so the mask edges get intermediate grey values.
    """
    transform_image = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    input_images = transform_image(image).unsqueeze(0).to(device)

    with torch.no_grad():
        # Take the last model output and apply sigmoid to obtain values in [0, 1].
        preds = model(input_images)[-1].sigmoid().cpu()

    foreground = (preds[0].squeeze() > threshold).numpy()
    mask = Image.fromarray(foreground.astype(np.uint8) * 255)
    return mask.resize(image.size, resample=Image.BILINEAR)


def main():
    """Streamlit entry point: upload an image and show two background-blur variants.

    Left column: the RMBG-2.0 subject mask keeps the subject sharp over a
    uniformly Gaussian-blurred background. Right column: ``depth_based_blur``
    driven by a Depth-Anything-V2 depth map ("lens" blur).

    Raises:
      RuntimeError: if the ``HF_ACCESS_TOKEN`` environment variable is not set.
        The token is passed when loading RMBG-2.0.
    """
    from PIL import ImageOps  # used only here, to honour EXIF orientation

    hf_token = os.environ.get("HF_ACCESS_TOKEN")
    if hf_token is None:
        raise RuntimeError("HF_ACCESS_TOKEN is not set. Please add it as a secret.")

    st.title("Custom Background Blur Demo")

    # 1. Upload an image
    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    # 2. Open and display the original image. exif_transpose applies the camera's
    #    orientation tag so phone photos are not displayed and processed sideways.
    image = ImageOps.exif_transpose(Image.open(uploaded_file)).convert("RGB")
    st.image(image, caption="Original Image", use_container_width=True)

    st.write("---")
    st.subheader("Blur Settings")
    col1, col2 = st.columns(2)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # "high" permits faster, reduced-internal-precision float32 matmuls where supported.
    torch.set_float32_matmul_precision("high")

    model = _load_segmentation_model(hf_token, device)
    mask_pil = _predict_subject_mask(model, image, device)

    depth_pipeline = _load_depth_pipeline(device)
    depth_map_image = depth_pipeline(image)["depth"]

    with col1:
        gauss_radius = st.slider("Gaussian Blur Radius", 0, 30, 10, key="gauss")
        blurred_image = image.filter(ImageFilter.GaussianBlur(gauss_radius))
        # Mask white (255) -> sharp original pixel; black (0) -> blurred pixel.
        final_image = Image.composite(image, blurred_image, mask_pil)
        st.image(
            final_image,
            caption=f"Gaussian Blur (radius={gauss_radius})",
            use_container_width=True
        )

    with col2:
        blur_max = st.slider("Lens Blur Radius", 0, 5, 1, key="lens")
        # invert_depth=False keeps larger depth-map values sharper. This presumes
        # Depth-Anything's output is larger for nearer pixels — confirm.
        output_image = depth_based_blur(
            image, depth_map_image, max_blur=blur_max, num_bands=40, invert_depth=False
        )
        st.image(
            output_image,
            caption=f"Lens Blur (blur={blur_max})",
            use_container_width=True
        )

# Entry point when this file is executed as a script.
if __name__ == "__main__":
    main()