File size: 6,135 Bytes
2c0100c
e18cf09
 
a558a0b
e18cf09
 
a558a0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c0100c
f5f8361
a558a0b
2c0100c
 
f5f8361
e18cf09
 
 
 
 
 
 
 
 
6812920
 
 
 
f5f8361
e18cf09
 
 
 
f5f8361
e18cf09
 
f5f8361
e18cf09
f5f8361
 
 
 
 
 
 
e18cf09
f5f8361
 
e18cf09
 
f5f8361
 
e18cf09
f5f8361
 
e18cf09
 
 
f5f8361
 
 
 
 
 
 
e18cf09
a558a0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f5f8361
 
a558a0b
e18cf09
a558a0b
 
e18cf09
f5f8361
a558a0b
 
 
 
 
 
4bdbf01
a558a0b
 
 
e18cf09
a558a0b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import hashlib
import os
from functools import lru_cache

import gradio as gr
import numpy as np
import torch
from huggingface_hub import snapshot_download
from PIL import Image, ImageFilter
from torch import nn
from transformers import (
    DepthProForDepthEstimation,
    DepthProImageProcessorFast,
    SegformerForSemanticSegmentation,
    SegformerImageProcessor,
)

# --- Lens Blur: DepthPro model and processor, loaded once at import time ---
MODEL_REPO = "apple/DepthPro-hf"
CACHE_DIR = "./cache"  # local cache folder for downloaded model files
# Expected SHA-256 of model.safetensors; update it if the upstream weights change.
EXPECTED_SHA256 = "9c6811e3165485b9a94a204329860cb333a79877e757eb795a179a4ea34bbcf7"
_HASH_CHUNK_SIZE = 1 << 20  # read the weights file 1 MiB at a time while hashing

# Download the model repository (reuses CACHE_DIR if already present) and verify it.
snapshot_path = snapshot_download(repo_id=MODEL_REPO, cache_dir=CACHE_DIR)
model_file = os.path.join(snapshot_path, "model.safetensors")

# Hash incrementally so a large weights file is never held in memory in one piece.
_sha256 = hashlib.sha256()
with open(model_file, "rb") as f:
    while chunk := f.read(_HASH_CHUNK_SIZE):
        _sha256.update(chunk)
file_hash = _sha256.hexdigest()
if file_hash != EXPECTED_SHA256:
    raise RuntimeError(
        "Model file hash mismatch! Download may be corrupted. "
        f"Expected {EXPECTED_SHA256}, got {file_hash} for {model_file}."
    )

# Load model and processor from the verified local snapshot (no re-download).
model = DepthProForDepthEstimation.from_pretrained(snapshot_path)
processor = DepthProImageProcessorFast.from_pretrained(snapshot_path)
# Use the GPU when available; eval() disables training-only behaviour.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device).eval()

# Minimal greeting used by the "Greeting (Basic Test)" tab to smoke-test the UI.
def greet_test(name):
    """Return the string ``"Hello <name>!!"``."""
    return "".join(("Hello ", name, "!!"))

# Gaussian blur of the background, keeping the face-parsing foreground sharp.
FACE_PARSING_REPO = "jonathandinu/face-parsing"


@lru_cache(maxsize=None)  # keyed by device string, so only a handful of entries
def _load_face_parser(device):
    """Load the SegFormer face-parsing processor and model once per device."""
    image_processor = SegformerImageProcessor.from_pretrained(FACE_PARSING_REPO)
    seg_model = SegformerForSemanticSegmentation.from_pretrained(FACE_PARSING_REPO)
    seg_model.to(device).eval()
    return image_processor, seg_model


def gauss_blur(image, sigma):
    """Blur the background of an image while keeping the segmented foreground sharp.

    A SegFormer face-parsing model labels every pixel. Pixels with label 0 are
    taken from a Gaussian-blurred copy; all other pixels are copied unchanged
    from the original image.

    Args:
        image: A PIL image, or anything ``Image.fromarray`` accepts (Gradio's
            "image" input delivers a numpy array).
        sigma: Radius passed to ``ImageFilter.GaussianBlur``. ``None`` (an empty
            Gradio number field) is treated as 0, i.e. no blur.

    Returns:
        An RGB PIL image with the same size as the input.

    Raises:
        ValueError: If no image is provided.
    """
    if image is None:
        raise ValueError("No image provided.")
    if sigma is None:
        sigma = 0

    device = (
        "cuda"
        if torch.cuda.is_available()
        else "mps"
        if torch.backends.mps.is_available()
        else "cpu"
    )

    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # The processor and the 3-channel blend below both expect RGB; RGBA or
    # grayscale inputs would otherwise fail or broadcast incorrectly.
    image = image.convert("RGB")

    image_processor, seg_model = _load_face_parser(device)

    # Inference only: no_grad avoids building an autograd graph.
    inputs = image_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = seg_model(**inputs).logits

    # Upsample logits to the input resolution (PIL size is (W, H); interpolate wants (H, W)).
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=image.size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    labels = upsampled_logits.argmax(dim=1)[0].cpu().numpy()

    # Label 0 is treated as background; every other label is kept sharp.
    foreground = (labels != 0)[:, :, None]

    image_np = np.array(image)
    blurred_np = np.array(image.filter(ImageFilter.GaussianBlur(radius=sigma)))

    result_np = np.where(foreground, image_np, blurred_np)
    return Image.fromarray(result_np.astype(np.uint8))

def lens_blur(image: Image.Image, *, blur_radius: float = 15) -> Image.Image:
    """Apply depth-dependent lens blur to a PIL image using the DepthPro model.

    Each output pixel blends the original image with a Gaussian-blurred copy.
    The blend weight is the min-max normalised predicted depth: the smallest
    depth value keeps the original pixel, the largest takes the blurred one.

    Args:
        image: Input PIL image (converted to RGB internally).
        blur_radius: Gaussian blur radius used for the fully blurred copy.
            Defaults to 15.

    Returns:
        An RGB PIL image with the same size as ``image``.

    Raises:
        ValueError: If no image is provided.
    """
    if image is None:
        raise ValueError("No image provided.")
    # Preprocessing and the 3-channel blend below assume RGB; grayscale (2-D)
    # or RGBA inputs would otherwise fail or mis-broadcast.
    image = image.convert("RGB")

    # 1. Downscale very large inputs (aspect ratio preserved, longest side 1536).
    orig_w, orig_h = image.size
    max_dim = max(orig_w, orig_h)
    if max_dim > 1536:
        ratio = 1536.0 / max_dim
        # Clamp to >= 1 so extreme aspect ratios never produce a 0-pixel side.
        new_size = (max(1, int(orig_w * ratio)), max(1, int(orig_h * ratio)))
        image_resized = image.resize(new_size, Image.LANCZOS)
    else:
        image_resized = image
    inputs = processor(images=image_resized, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # 2. Predict depth, then post-process it back to the original resolution.
    with torch.no_grad():
        outputs = model(**inputs)
    depth_map = processor.post_process_depth_estimation(
        outputs, target_sizes=[(orig_h, orig_w)]
    )[0]["predicted_depth"]
    depth_map = depth_map.squeeze().cpu().float().numpy()  # expected H x W

    # Min-max normalise to [0, 1]; a constant depth map means no blur anywhere.
    depth_min, depth_max = depth_map.min(), depth_map.max()
    if depth_max > depth_min:
        depth_norm = (depth_map - depth_min) / (depth_max - depth_min)
    else:
        depth_norm = np.zeros_like(depth_map)

    # 3. Blend: weight 0 -> original, weight 1 -> blurred
    #    (presumably larger predicted depth = farther away).
    blurred = image.filter(ImageFilter.GaussianBlur(radius=blur_radius))
    blurred_np = np.array(blurred, dtype=np.float32)
    original_np = np.array(image, dtype=np.float32)
    # (H, W, 1) so the weight broadcasts across the colour channels.
    weight = depth_norm.astype(np.float32)[..., None]
    blended_np = original_np * (1 - weight) + blurred_np * weight
    return Image.fromarray(blended_np.clip(0, 255).astype(np.uint8))
    
# Build the Gradio app: one tab per demo function.
with gr.Blocks() as demo:
    gr.Markdown("# Gaussian Blur and Lens Blur Demo")

    with gr.Tab("Greeting (Basic Test)"):
        gr.Interface(fn=greet_test, inputs="text", outputs="text")

    with gr.Tab("Gaussian Blur on Foreground"):
        gr.Interface(
            fn=gauss_blur,
            # A default sigma avoids sending None when the field is left empty.
            inputs=["image", gr.Number(value=5, label="Sigma")],
            outputs="image",
            title="Gaussian Blur",
            description="Apply Gaussian blur to the background of the image while keeping the foreground sharp. Adjust the sigma value to control the blur intensity.",
        )

    with gr.Tab("Lens Blur"):
        gr.Interface(
            fn=lens_blur,
            inputs=gr.Image(type="pil"),
            outputs="image",
            title="Lens Blur",
            description="Apply depth-dependent lens blur to the image using the Apple DepthPro model. The blur intensity varies based on the depth of each pixel.",
        )

# Launch only when run as a script, so importing this module has no server side effect.
if __name__ == "__main__":
    # share=True additionally asks Gradio for a temporary public share link.
    demo.launch(share=True)