import gradio as gr
from PIL import Image, ImageFilter
import numpy as np
import torch
from transformers import (
    SegformerFeatureExtractor, SegformerForSemanticSegmentation,
    DPTFeatureExtractor, DPTForDepthEstimation
)
import cv2

# ------------------------------------------------------------------
# load segmentation model
seg_model_name = "nvidia/segformer-b1-finetuned-ade-512-512"
seg_fe = SegformerFeatureExtractor.from_pretrained(seg_model_name)
seg_model = SegformerForSemanticSegmentation.from_pretrained(seg_model_name)

# load depth model
depth_model_name = "Intel/dpt-hybrid-midas"
depth_fe = DPTFeatureExtractor.from_pretrained(depth_model_name)
depth_model = DPTForDepthEstimation.from_pretrained(depth_model_name)
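# NOTE: recent transformers releases deprecate the *FeatureExtractor names in
# favor of SegformerImageProcessor / DPTImageProcessor; the old aliases still
# work, so they are kept here, but the swap is drop-in on newer versions.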

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
seg_model.to(device)
depth_model.to(device)

def process_image(image: Image.Image):
    """Return (original, segmentation-based blur, depth-based lens blur)."""
    # 1) prep: fix the working size at 512x512 so masks and depth maps line up
    image = image.convert("RGB").resize((512,512))

    # 2) segmentation β†’ binary mask
    seg_inputs = seg_fe(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        seg_logits = seg_model(**seg_inputs).logits
    seg_map = torch.argmax(seg_logits, dim=1)[0].cpu().numpy()
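    # note: this ADE20K checkpoint has no explicit background class (id 0 is
    # "wall"), so (seg_map > 0) is a rough "everything except wall" heuristic;
    # for selfies, keeping only id 12 ("person") would give a tighter mask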
    mask = (seg_map > 0).astype(np.uint8) * 255
    mask = Image.fromarray(mask).resize((512,512))

    # 3) gaussian-blur background: Image.composite keeps the sharp original
    #    where the mask is white (foreground) and the blurred copy elsewhere
    bg_blur = image.filter(ImageFilter.GaussianBlur(15))
    output_blur = Image.composite(image, bg_blur, mask)

    # 4) depth estimation
    depth_inputs = depth_fe(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        depth_pred = depth_model(**depth_inputs).predicted_depth.squeeze().cpu().numpy()
    # normalize & resize
    dmin, dmax = depth_pred.min(), depth_pred.max()
    depth_norm = (depth_pred - dmin) / (dmax - dmin + 1e-8)
    depth_norm = cv2.resize(depth_norm, (512,512))

    # 5) vectorized depth-based blur
    img_np = np.array(image).astype(np.float32)
    # two extremes: the sharp original for near pixels, a heavy blur for far ones
    near_sharp = img_np
    far_blur   = cv2.GaussianBlur(img_np, (81,81), 20)

    # MiDaS/DPT predicts relative *inverse* depth (larger = closer), so
    # depth_norm is already ~1 for near pixels and ~0 for far ones; use it
    # directly as the per-pixel sharpness weight
    alpha = depth_norm[..., None]

    # near (alpha ~ 1) stays sharp, far (alpha ~ 0) fades to far_blur
    combined = near_sharp * alpha + far_blur * (1.0 - alpha)
    lens_blur = Image.fromarray(np.clip(combined,0,255).astype(np.uint8))
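    # worked example: a pixel with alpha = 0.25 (fairly far away) ends up as
    # 0.25*sharp + 0.75*heavily blurred; blending two fixed extremes keeps this
    # cheap, at the cost of a single blur radius rather than one that grows
    # with distance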

    # 6) composite to keep the segmented foreground sharp regardless of depth
    mask_np = np.array(mask)
    inv_mask = (mask_np == 0).astype(np.uint8) * 255   # white where background
    bg_mask  = Image.fromarray(inv_mask)
    final_lens_blur = Image.composite(lens_blur, image, mask=bg_mask)

    # three return values matching the three Gradio outputs declared below
    return image, output_blur, final_lens_blur

iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=[
      gr.Image(type="pil", label="Original"),
      gr.Image(type="pil", label="Gaussian Blur"),
      gr.Image(type="pil", label="Depth-Based Lens Blur"),
    ],
    title="Image Blurring with CLAHE + Depth-Based Blur",
    description="Upload a selfie to see background blur and depth-based lens blur."
)

if __name__ == "__main__":
    iface.launch(share=True)
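
# To try this locally (a sketch; "app.py" is a placeholder filename, and the
# model weights download from the Hub on first run):
#   pip install gradio transformers torch pillow opencv-python
#   python app.py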