File size: 3,789 Bytes
71fa3fe
 
 
 
 
 
 
 
 
15ad333
 
 
71fa3fe
 
15ad333
 
 
 
71fa3fe
 
 
 
 
 
15ad333
71fa3fe
 
 
 
15ad333
71fa3fe
 
 
15ad333
 
 
 
 
 
71fa3fe
 
15ad333
 
 
 
71fa3fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47eddf9
 
71fa3fe
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import gradio as gr
import torch
import numpy as np
from PIL import Image, ImageFilter
import matplotlib.pyplot as plt
from torchvision import transforms
from transformers import AutoProcessor, AutoModelForImageSegmentation, AutoModelForDepthEstimation

def load_segmentation_model():
    """Load the BiRefNet image-segmentation model from the Hugging Face Hub.

    ``trust_remote_code=True`` lets transformers execute the modeling code
    shipped in the model repository.

    Returns:
        The model instance returned by
        ``AutoModelForImageSegmentation.from_pretrained``.
    """
    return AutoModelForImageSegmentation.from_pretrained(
        "ZhengPeng7/BiRefNet",
        trust_remote_code=True,
    )

def load_depth_model():
    """Load the Depth-Anything-V2 checkpoint and its matching processor.

    The checkpoint name suggests a metric-depth model tuned for indoor scenes
    (not verified here).

    Returns:
        ``(processor, model)``: the ``AutoProcessor`` used to build model
        inputs, and the ``AutoModelForDepthEstimation`` instance.
    """
    checkpoint = "depth-anything/Depth-Anything-V2-Metric-Indoor-Base-hf"
    depth_processor = AutoProcessor.from_pretrained(checkpoint)
    depth_model = AutoModelForDepthEstimation.from_pretrained(checkpoint)
    return depth_processor, depth_model

def process_segmentation_image(image, size=(512, 512)):
    """Prepare a PIL image as a batched input tensor for the segmentation model.

    The image is converted to RGB, so RGBA / palette / grayscale inputs still
    produce a 3-channel tensor. It is resized to ``size``, scaled to [0, 1]
    and normalized with the ImageNet mean/std, following the BiRefNet
    inference example.

    Args:
        image: Input ``PIL.Image``. Not modified.
        size: ``(height, width)`` passed to ``transforms.Resize``.

    Returns:
        ``(image, input_tensor)``: the untouched input image, and a float
        tensor of shape ``(1, 3, size[0], size[1])``.
    """
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        # ImageNet statistics, as used in the BiRefNet reference preprocessing.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    input_tensor = transform(image.convert("RGB")).unsqueeze(0)
    return image, input_tensor

def process_depth_image(image, processor):
    """Resize an image to 512x512 and run it through the depth processor.

    Args:
        image: Input image (anything with a ``resize`` method, e.g. PIL).
        processor: Callable accepting ``images=`` and ``return_tensors=``.

    Returns:
        ``(resized_image, model_inputs)`` where ``model_inputs`` is whatever
        the processor returns for ``return_tensors="pt"``.
    """
    resized = image.resize((512, 512))
    model_inputs = processor(images=resized, return_tensors="pt")
    return resized, model_inputs

def segment_image(image, input_tensor, model, threshold=0.5):
    """Run the segmentation model and binarize its prediction into a mask.

    The model output may be:
      * a list/tuple of prediction maps; the LAST element is used, since
        BiRefNet's reference usage takes ``model(x)[-1]`` as the final,
        finest prediction,
      * a plain tensor of logits, or
      * an object exposing ``.logits``.

    Args:
        image: Unused. Kept so existing callers keep working.
        input_tensor: Batched model input tensor.
        model: Callable segmentation model.
        threshold: Probability above which a pixel counts as foreground.

    Returns:
        ``numpy.uint8`` array (batch/channel singleton dims squeezed) with
        255 for foreground and 0 for background.
    """
    with torch.no_grad():
        outputs = model(input_tensor)
        if isinstance(outputs, (list, tuple)):
            logits = outputs[-1]
        elif isinstance(outputs, torch.Tensor):
            logits = outputs
        else:
            logits = outputs.logits
        probabilities = torch.sigmoid(logits).squeeze().cpu().numpy()
    return (probabilities > threshold).astype(np.uint8) * 255

def estimate_depth(inputs, model):
    """Run the depth model on prepared inputs and return a NumPy depth map.

    Args:
        inputs: Mapping of keyword arguments for the model (e.g. the
            processor output).
        model: Callable whose result exposes ``.predicted_depth``.

    Returns:
        ``predicted_depth`` with singleton dims squeezed, as a NumPy array.
    """
    with torch.no_grad():
        prediction = model(**inputs).predicted_depth
    return prediction.squeeze().cpu().numpy()

def normalize_depth_map(depth_map):
    """Min-max scale a depth map into the range [0, 1].

    A constant depth map (max == min) would divide by zero and yield NaN.
    In that case an all-zero array of the same shape is returned instead.

    Args:
        depth_map: Array-like of depth values. Must be non-empty.

    Returns:
        Floating-point array of the same shape, with values in [0, 1].
    """
    depth = np.asarray(depth_map)
    min_val = np.min(depth)
    max_val = np.max(depth)
    value_range = max_val - min_val
    if value_range == 0:
        return np.zeros(depth.shape, dtype=np.result_type(depth.dtype, np.float32))
    return (depth - min_val) / value_range

def apply_blur(image, mask):
    """Keep the masked foreground sharp and Gaussian-blur everything else.

    Args:
        image: Source ``PIL.Image``.
        mask: 2-D ``uint8`` array (255 = keep sharp, 0 = blurred). It is
            bilinearly resized to the image size, so the mask's own
            resolution may differ.

    Returns:
        A new ``PIL.Image`` that composites the sharp original over a
        radius-15 Gaussian-blurred copy, using the mask as alpha.
    """
    background = image.filter(ImageFilter.GaussianBlur(15))
    alpha = Image.fromarray(mask)
    alpha = alpha.resize(image.size, Image.BILINEAR)
    return Image.composite(image, background, alpha)

def _depth_to_image_grid(depth, size):
    """Bilinearly resample a 2-D depth array to ``size`` = (width, height)."""
    depth_image = Image.fromarray(np.asarray(depth, dtype=np.float32))
    return np.asarray(depth_image.resize(size, Image.BILINEAR), dtype=np.float32)


def apply_depth_based_blur(image, depth_map, max_radius=20, levels=8):
    """Simulate a lens blur whose strength grows with normalized depth.

    The image is resized to 512x512. The depth map is min-max normalized and
    resampled onto that pixel grid. Each pixel is assigned one of ``levels``
    evenly spaced blur radii from 0 to ``max_radius``; larger normalized
    depth gets a stronger blur. Only ``levels - 1`` full-image Gaussian blurs
    are computed, and each output pixel is taken from the layer matching its
    depth. This replaces the previous per-pixel crop/blur/paste loop, which
    ran one blur per pixel and whose overlapping pastes did not produce a
    true per-pixel blur.

    Args:
        image: Source ``PIL.Image``. Modes other than L/RGB/RGBA are
            converted to RGB, since PIL cannot filter e.g. palette images.
        depth_map: 2-D array of depth values, of any resolution.
        max_radius: Gaussian radius applied at the largest normalized depth.
        levels: Number of discrete blur strengths (at least 1).

    Returns:
        A new 512x512 ``PIL.Image``.
    """
    # Flat depth maps can normalize to NaN; treat those pixels as in focus.
    normalized_depth = np.nan_to_num(normalize_depth_map(depth_map), nan=0.0)

    if image.mode not in ("L", "RGB", "RGBA"):
        image = image.convert("RGB")
    image = image.resize((512, 512))

    depth = np.clip(_depth_to_image_grid(normalized_depth, image.size), 0.0, 1.0)
    levels = max(int(levels), 1)
    radii = np.linspace(0.0, float(max_radius), levels)
    # Index of the blur layer each pixel should be taken from.
    level_of_pixel = np.rint(depth * (levels - 1)).astype(np.intp)

    source = np.asarray(image)
    result = source.copy()
    for level, radius in enumerate(radii):
        selected = level_of_pixel == level
        if radius <= 0 or not selected.any():
            continue  # radius 0 keeps the sharp original pixels already in `result`
        blurred = np.asarray(image.filter(ImageFilter.GaussianBlur(float(radius))))
        result[selected] = blurred[selected]
    return Image.fromarray(result)

# Models are expensive to build, so they are loaded on first use and reused.
_MODEL_CACHE = {}


def _get_models():
    """Return (segmentation_model, depth_processor, depth_model), loading them once."""
    if "models" not in _MODEL_CACHE:
        segmentation_model = load_segmentation_model()
        depth_processor, depth_model = load_depth_model()
        _MODEL_CACHE["models"] = (segmentation_model, depth_processor, depth_model)
    return _MODEL_CACHE["models"]


def process_image_pipeline(image):
    """Run segmentation and depth estimation, then build the three outputs.

    Args:
        image: Uploaded ``PIL.Image``, or ``None`` when nothing was uploaded.

    Returns:
        ``(segmentation_mask_image, depth_blur_image, gaussian_blur_image)``,
        or ``(None, None, None)`` when ``image`` is ``None``.
    """
    if image is None:
        return None, None, None
    # Normalize the mode so both models and the PIL filters see 3-channel RGB.
    image = image.convert("RGB")

    segmentation_model, depth_processor, depth_model = _get_models()

    _, input_tensor = process_segmentation_image(image)
    _, inputs = process_depth_image(image, depth_processor)

    segmentation_mask = segment_image(image, input_tensor, segmentation_model)
    depth_map = estimate_depth(inputs, depth_model)
    blurred_image = apply_depth_based_blur(image, depth_map)
    gaussian_blur_image = apply_blur(image, segmentation_mask)

    return Image.fromarray(segmentation_mask), blurred_image, gaussian_blur_image

# Gradio front end: one PIL image in, three images out, in the same order
# process_image_pipeline returns them (mask, depth-based blur, mask blur).
iface = gr.Interface(
    title="Segmentation and Image Effect Processing",
    description="Upload an image to get segmentation mask, lens blur effect, and Gaussian blur effect.",
    fn=process_image_pipeline,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(label="Segmentation Mask"),
        gr.Image(label="Lens Blur Effect"),
        gr.Image(label="Gaussian Blur Effect"),
    ],
)

if __name__ == "__main__":
    # share=True asks Gradio to also create a public share link for the app.
    iface.launch(share=True)