import torch
from transformers import (
    SegformerImageProcessor,
    SegformerForSemanticSegmentation,
    DPTImageProcessor,
    DPTForDepthEstimation
)
from PIL import Image, ImageFilter
import numpy as np
import gradio as gr

# Suppress specific warnings
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="transformers")

# Load pre-trained models and processors
seg_processor = SegformerImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
seg_model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
depth_processor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
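# Optional (assumption, not in the original script): move both models to a GPU
# when one is available. The input tensors inside process_image would then also
# need .to(device) before each forward pass.
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# seg_model.to(device)
# depth_model.to(device)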

def process_image(image):
    # Ensure image is in RGB
    if image.mode != "RGB":
        image = image.convert("RGB")
    
    # Resize the image to 512x512
    image = image.resize((512, 512))
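    # Note: a plain resize ignores aspect ratio, so non-square inputs are
    # stretched; acceptable for a demo, but padding/cropping would avoid it.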
    
    # ------------------ Semantic Segmentation ------------------
    seg_inputs = seg_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        seg_outputs = seg_model(**seg_inputs)
    seg_logits = seg_outputs.logits
    # SegFormer predicts at 1/4 of the input resolution, so upsample the logits
    # back to 512x512 before the per-pixel argmax; otherwise the mask is 128x128
    seg_logits = torch.nn.functional.interpolate(
        seg_logits, size=(512, 512), mode="bilinear", align_corners=False
    )
    segmentation = torch.argmax(seg_logits, dim=1)[0].numpy()
    
    # Create binary mask for 'person' class (class index 12)
    person_class_index = 12
    binary_mask = (segmentation == person_class_index).astype(np.uint8) * 255
    binary_mask_image = Image.fromarray(binary_mask)
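    # Note: the person mask is returned for visualization only; the blur effect
    # below is weighted by depth rather than by this segmentation.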
    
    # ------------------ Depth Estimation ------------------
    depth_inputs = depth_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        depth_outputs = depth_model(**depth_inputs)
    predicted_depth = depth_outputs.predicted_depth[0].cpu().numpy()
    
    # Normalize the depth map to [0, 1] for visualization (the small epsilon
    # guards against division by zero on a constant depth map)
    min_depth = predicted_depth.min()
    max_depth = predicted_depth.max()
    normalized_depth = (predicted_depth - min_depth) / (max_depth - min_depth + 1e-8)
    depth_map_image = Image.fromarray((normalized_depth * 255).astype(np.uint8))
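    # Note: DPT's processor typically resizes inputs to 384x384, so this depth
    # map may not be 512x512; it is resized to match the image further below.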
    
    # ------------------ Blurred Background Effect ------------------
    # Invert the normalized depth so distant pixels (background) get a high
    # blur weight; normalized_depth already spans [0, 1], so no further
    # re-normalization is needed after inversion
    inverted_depth = 1.0 - normalized_depth
    
    # Resize and expand dimensions to match image channels
    depth_weight_resized = Image.fromarray((inverted_depth * 255).astype(np.uint8)).resize((512, 512))
    depth_weight_resized = np.array(depth_weight_resized) / 255.0
    depth_weight_resized = np.expand_dims(depth_weight_resized, axis=-1)
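    # Optional variant (assumption, not part of the original effect): combine
    # the person mask with this depth weight so the segmented subject stays
    # fully sharp regardless of its predicted depth.
    # person_weight = np.array(binary_mask_image.resize((512, 512))) / 255.0
    # depth_weight_resized = depth_weight_resized * (1.0 - person_weight[..., None])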
    
    # Apply Gaussian blur to the entire image
    blurred_image = image.filter(ImageFilter.GaussianBlur(radius=15))
    
    # Convert images to numpy arrays
    original_np = np.array(image).astype(np.float32)
    blurred_np = np.array(blurred_image).astype(np.float32)
    
    # Blend images based on the depth weight
    composite_np = (1 - depth_weight_resized) * original_np + depth_weight_resized * blurred_np
    composite_image = Image.fromarray(np.clip(composite_np, 0, 255).astype(np.uint8))
    
    return image, binary_mask_image, depth_map_image, composite_image

# Define the Gradio interface using the current component API
# (gr.Image rather than the deprecated gr.inputs / gr.outputs modules)
interface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=[
        gr.Image(type="pil", label="Original Image"),
        gr.Image(type="pil", label="Segmentation Mask"),
        gr.Image(type="pil", label="Depth Map"),
        gr.Image(type="pil", label="Blurred Background Effect"),
    ],
    title="Semantic Segmentation and Depth Estimation",
    description="Upload an image to generate a segmentation mask, depth map, and a blurred background effect.",
    examples=[
        ["Selfie_1.jpg"],
        ["Selfie_2.jpg"]
    ]
)

# Launch the interface
if __name__ == "__main__":
    interface.launch()
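
# Launch options (assumptions, not used in the original script):
# interface.launch(share=True) creates a temporary public link, and
# interface.launch(server_name="0.0.0.0") exposes the app on the local network.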