File size: 3,298 Bytes
8f8907e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f48bfcf
 
 
8f8907e
f48bfcf
8f8907e
 
 
 
f48bfcf
8f8907e
f48bfcf
 
8f8907e
f48bfcf
8f8907e
f48bfcf
8f8907e
f48bfcf
8f8907e
 
 
f48bfcf
 
8f8907e
 
f48bfcf
 
 
8f8907e
f48bfcf
 
8f8907e
f48bfcf
 
 
8f8907e
 
f48bfcf
 
 
 
8f8907e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import gradio as gr
import torch
import numpy as np
from PIL import Image, ImageFilter, ImageOps
import cv2
from transformers import (
    SegformerFeatureExtractor, SegformerForSemanticSegmentation,
    DPTFeatureExtractor, DPTForDepthEstimation
)

# Checkpoint identifiers handed to ``from_pretrained`` below.
seg_model_name = "nvidia/segformer-b1-finetuned-ade-512-512"
depth_model_name = "Intel/dpt-hybrid-midas"

# Use CUDA when torch reports it as available, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Segmentation: preprocessor + model, with the model moved onto ``device``.
seg_extractor = SegformerFeatureExtractor.from_pretrained(seg_model_name)
seg_model = SegformerForSemanticSegmentation.from_pretrained(seg_model_name).to(device)

# Depth estimation: preprocessor + model, with the model moved onto ``device``.
depth_extractor = DPTFeatureExtractor.from_pretrained(depth_model_name)
depth_model = DPTForDepthEstimation.from_pretrained(depth_model_name).to(device)

def process_image(image_pil):
    """Produce the original, a segmentation-blurred and a depth-blurred image.

    Args:
        image_pil: PIL image supplied by the Gradio input, or ``None`` when
            the user submitted without uploading anything.

    Returns:
        A 3-tuple ``(original, segmented_blur, lens_blur)`` of 512x512 RGB PIL
        images, or ``(None, None, None)`` when ``image_pil`` is ``None``.
    """
    if image_pil is None:
        # Nothing was uploaded: hand back empty outputs instead of crashing.
        return None, None, None

    # Apply the EXIF orientation, then force a fixed 512x512 RGB working size.
    image = ImageOps.exif_transpose(image_pil).resize((512, 512)).convert("RGB")

    # ---- Segmentation + blurred background ----
    foreground_mask = _foreground_mask(image)
    blurred_background = image.filter(ImageFilter.GaussianBlur(15))
    # Where the mask is 255 the sharp image is kept, elsewhere the blurred one.
    composite_blur = Image.composite(image, blurred_background, foreground_mask)

    # ---- Depth-based blur ----
    lens_blur_pil = _depth_lens_blur(image, _normalized_depth(image))

    return image, composite_blur, lens_blur_pil


def _foreground_mask(image):
    """Return an ``L``-mode mask (255 = keep sharp) with the same size as *image*."""
    seg_inputs = seg_extractor(images=image, return_tensors="pt", do_resize=True, do_normalize=True)
    with torch.no_grad():
        logits = seg_model(**seg_inputs.to(device)).logits
    # SegFormer documents its logits as (batch, labels, height/4, width/4), and
    # Image.composite needs a mask matching the image size, so upsample the
    # logits to the image's (height, width) before taking the argmax.
    logits = torch.nn.functional.interpolate(
        logits, size=image.size[::-1], mode="bilinear", align_corners=False
    )
    labels = torch.argmax(logits, dim=1)[0].cpu().numpy()
    # Every label other than index 0 counts as foreground.
    # NOTE(review): index 0 is presumably the ADE20K "wall" class - confirm that
    # this is the intended notion of background.
    binary_mask = np.where(labels > 0, 255, 0).astype(np.uint8)
    return Image.fromarray(binary_mask).convert("L")


def _normalized_depth(image):
    """Return the model's depth prediction for *image* rescaled to [0, 1]."""
    depth_inputs = depth_extractor(images=np.array(image), return_tensors="pt")
    with torch.no_grad():
        depth_output = depth_model(**depth_inputs.to(device))
    predicted_depth = depth_output.predicted_depth.squeeze().cpu().numpy()

    lowest = predicted_depth.min()
    span = predicted_depth.max() - lowest
    if not np.isfinite(span) or span <= 0:
        # A flat (or non-finite) prediction carries no depth ordering; all ones
        # maps to zero blur downstream and avoids a division by zero.
        return np.ones_like(predicted_depth)
    return (predicted_depth - lowest) / span


def _depth_lens_blur(image, normalized_depth, blur_levels=4):
    """Blur *image* progressively more where *normalized_depth* is lower."""
    pixels = np.array(image).astype(np.float32)
    height, width = pixels.shape[:2]
    # cv2.resize takes (width, height); bring the depth map to the image size.
    blur_amount = 1.0 - cv2.resize(normalized_depth, (width, height))

    # Split [0, 1] into ``blur_levels`` equal bins. The clip folds the value
    # 1.0 into the last bin so the strongest level covers a full bin rather
    # than only the pixels that are exactly 1.0.
    level_index = np.clip(
        (blur_amount * blur_levels).astype(np.int32), 0, blur_levels - 1
    )

    result = np.zeros_like(pixels)
    for level in range(blur_levels):
        sigma = level * 3
        # Level 0 is the untouched image; GaussianBlur is only run for sigma > 0.
        variant = (
            cv2.GaussianBlur(pixels, (15, 15), sigmaX=sigma, sigmaY=sigma)
            if sigma > 0
            else pixels
        )
        selected = level_index == level
        # A 2-D boolean mask selects whole pixels, i.e. all channels at once.
        result[selected] = variant[selected]

    return Image.fromarray(np.clip(result, 0, 255).astype(np.uint8))


# Gradio Interface: one PIL image in, three images out (same order as the
# tuple returned by process_image).
upload_input = gr.Image(type="pil")
result_panels = [
    gr.Image(label=caption)
    for caption in (
        "Original Image",
        "Segmented Gaussian Blur",
        "Depth-Based Lens Blur",
    )
]

demo = gr.Interface(
    fn=process_image,
    inputs=upload_input,
    outputs=result_panels,
    title="Visual Effects Demo: Segmentation & Depth-Based Blur",
    description="Upload an image to see it segmented with background blur (like Zoom) and depth-based lens blur.",
    examples=[],
)
demo.launch()