File size: 4,120 Bytes
2a8387c
 
951e3bb
61d3bf0
2a8387c
 
61d3bf0
2a8387c
61d3bf0
2a8387c
b9bd2e1
61d3bf0
2a8387c
 
 
61d3bf0
2a8387c
 
 
 
 
 
 
 
 
 
b9bd2e1
2a8387c
 
 
 
 
 
31ca554
2a8387c
31ca554
2a8387c
 
31ca554
2a8387c
31ca554
2a8387c
 
31ca554
 
 
 
 
 
 
 
 
b9bd2e1
5993699
 
 
2a8387c
 
5993699
 
 
 
 
 
 
2a8387c
5993699
 
 
a604886
 
5993699
 
 
a604886
61d3bf0
2a8387c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5993699
2a8387c
9055ebf
2a8387c
 
 
 
 
 
5993699
2a8387c
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import os
import torch
import gradio as gr
from PIL import Image, ImageFilter
import torchvision.transforms as transforms
from transformers import AutoModelForImageSegmentation, DepthProImageProcessorFast, DepthProForDepthEstimation
import numpy as np
import io

# Load Models
# Prefer CUDA when it is available; otherwise everything runs on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# BiRefNet segmentation model. trust_remote_code=True lets transformers run the
# model code shipped in the Hub repo. Moved to `device`, set to evaluation mode.
HF_model_name = 'BiRefNet'
_BIREFNET_REPO = f'zhengpeng7/{HF_model_name}'
birefnet = AutoModelForImageSegmentation.from_pretrained(_BIREFNET_REPO, trust_remote_code=True)
birefnet = birefnet.to(device).eval()
print('BiRefNet (Segmentation) is ready to use.')

# DepthPro processor and depth-estimation model (used by the blur pipeline).
_DEPTH_CHECKPOINT = "apple/DepthPro-hf"
depth_processor = DepthProImageProcessorFast.from_pretrained(_DEPTH_CHECKPOINT)
depth_model = DepthProForDepthEstimation.from_pretrained(_DEPTH_CHECKPOINT)
depth_model = depth_model.to(device).eval()
print('DepthPro (Blur) is ready to use.')

# BiRefNet preprocessing: resize to a fixed 1024x1024, convert to a float
# tensor in [0, 1], then normalize with the standard ImageNet mean/std values.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)

# Refine Foreground (Placeholder)
def refine_foreground(image, mask):
    """Hook for refining the foreground before it is composited.

    Currently a no-op: ``mask`` is ignored and ``image`` is handed back
    unchanged (the very same object, not a copy).
    """
    # TODO: implement the actual refinement logic here.
    refined = image
    return refined

# Segmentation Function
def segment_image(image, blur_radius=15):
    """Keep the BiRefNet-segmented foreground sharp and blur the background.

    Args:
        image: PIL image to process (``process_image`` converts it to RGB
            before calling this).
        blur_radius: Gaussian blur radius applied to the background. Defaults
            to 15, the value this function previously hard-coded.

    Returns:
        A new PIL image the same size and mode as ``image``. Where the
        predicted mask is high, pixels come from the (refined) original.
        Elsewhere they come from a Gaussian-blurred copy. The soft mask blends
        the two.
    """
    print("Starting segmentation with background blur...")
    input_image = transform_image(image).unsqueeze(0).to(device)
    print("Input image tensor shape:", input_image.shape)
    with torch.no_grad():
        # The last output of the model is taken as the prediction; sigmoid maps it to (0, 1).
        pred = birefnet(input_image)[-1].sigmoid().cpu()[0].squeeze()
    print("Prediction tensor shape:", pred.shape)
    # Single-channel soft mask, resized back to the input resolution.
    mask = transforms.ToPILImage()(pred).resize(image.size)
    print("Mask PIL image size:", mask.size)
    foreground = refine_foreground(image.copy(), mask)

    # Blur the untouched original. Blacking out the subject before blurring (as
    # this code used to do) smears black into the surrounding background and
    # leaves a dark halo around the subject.
    background = image.filter(ImageFilter.GaussianBlur(blur_radius))
    # Image.composite needs matching modes; refine_foreground may change the mode.
    result = Image.composite(foreground.convert(background.mode), background, mask)

    print("Segmentation with background blur completed.")
    return result

# Blur Function (Rewritten)
def apply_background_blur(image: Image.Image, blur_strength: int = 20):
    """Blur ``image`` progressively according to DepthPro's predicted depth.

    The predicted depth is min-max normalised to [0, 1] and truncated into
    integer blur radii ``0..blur_strength``. Each pixel then takes its value
    from the image blurred with its radius. Pixels at the minimum predicted
    depth stay sharp, and those at the maximum get the strongest blur.
    (Presumably larger predicted depth means farther away; confirm against
    the DepthPro docs.)

    Args:
        image: input PIL image; converted to RGB first.
        blur_strength: largest blur radius, used at the maximum depth.
            Defaults to 20, the value this function previously hard-coded.

    Returns:
        A new RGB PIL image the same size as ``image``. If the depth map has
        no usable range (constant or non-finite), an unblurred RGB copy is
        returned.
    """
    image = image.convert("RGB")
    inputs = depth_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = depth_model(**inputs)
        post_processed_output = depth_processor.post_process_depth_estimation(
            outputs, target_sizes=[(image.height, image.width)],
        )
    depth = post_processed_output[0]["predicted_depth"]
    depth_np = depth.detach().cpu().numpy().squeeze()
    return _blur_by_depth(image, depth_np, blur_strength)


def _blur_by_depth(image, depth_np, blur_strength):
    """Composite Gaussian-blurred copies of ``image``, one blur radius per depth band."""
    depth_min = depth_np.min()
    depth_range = float(depth_np.max() - depth_min)
    # A flat (or non-finite) depth map would give 0/0 -> NaN and an invalid
    # int cast. No band would match anyway, so return the unblurred copy directly.
    if depth_range == 0 or not np.isfinite(depth_range):
        return image.copy()

    depth_normalized = (depth_np - depth_min) / depth_range
    blur_map = (depth_normalized * blur_strength).astype(int)

    blurred_image = image.copy()
    for radius in range(1, blur_strength + 1):
        band = blur_map == radius
        if not band.any():
            # Skip the costly full-image blur for radii that never occur.
            continue
        # filter() already returns a new image, so no extra copy is needed.
        radius_blurred = image.filter(ImageFilter.GaussianBlur(radius))
        band_mask = Image.fromarray((band * 255).astype(np.uint8))
        blurred_image = Image.composite(radius_blurred, blurred_image, band_mask)
    return blurred_image

# Process Image Function
def process_image(image, action):
    """Run the pipeline selected by ``action`` on ``image``.

    Args:
        image: PIL image to process (converted to RGB first), or ``None``
            when no image was provided (Gradio passes ``None`` for an empty
            image input).
        action: one of "Segmentation", "Blur" or "Both".

    Returns:
        A single image for "Segmentation" or "Blur". A
        ``(segmented, blurred)`` tuple for "Both". ``None`` for any other
        action. When ``image`` is ``None``, the result is ``(None, None)`` for
        "Both" and ``None`` otherwise, so callers can unpack consistently
        instead of hitting an AttributeError on ``None.convert``.
    """
    if image is None:
        return (None, None) if action == "Both" else None
    image = image.convert("RGB")
    if action == "Segmentation":
        return segment_image(image)
    if action == "Blur":
        return apply_background_blur(image)
    if action == "Both":
        return segment_image(image), apply_background_blur(image)
    return None

# Gradio Interface
def gradio_interface(image, action):
    """Gradio callback: run ``action`` on ``image`` and fill both output slots.

    Returns a pair, one value per ``gr.Image`` output. The second output is
    declared with ``visible=False``. Returning a bare value for it does not
    change its visibility, so for "Both" it is un-hidden with ``gr.update``.
    For any other action it is hidden again.
    """
    if image is None:
        # Nothing uploaded: clear the primary output and keep the second hidden.
        return None, gr.update(visible=False)
    result = process_image(image, action)
    if action == "Both":
        segmented, blurred = result
        return segmented, gr.update(value=blurred, visible=True)
    return result, gr.update(visible=False)  # Hide the second output when not needed.

# UI: an uploaded image plus an action selector in; up to two images out.
# The second output starts hidden; gradio_interface controls its visibility.
_input_components = [
    gr.Image(type="pil", label="Upload Image"),
    gr.Dropdown(["Segmentation", "Blur", "Both"], label="Select Action"),
]
_output_components = [
    gr.Image(label="Output Image 1"),
    gr.Image(label="Output Image 2", visible=False),
]
interface = gr.Interface(
    fn=gradio_interface,
    inputs=_input_components,
    outputs=_output_components,
    live=False,
)

interface.launch()