pavank007 commited on
Commit
a4deed9
·
verified ·
1 Parent(s): 9d1baca

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +190 -0
app.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+ from PIL import Image, ImageFilter
5
+ import cv2
6
+ from transformers import pipeline
7
+
8
+ # Set device to GPU if available
9
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
+ print(f"Using device: {device}")
11
+
12
+ # Load models only once at startup to improve performance
13
+ segmentation_model = "facebook/sam-vit-huge"
14
+ depth_model = "depth-anything/Depth-Anything-V2-Small-hf"
15
+
16
+ # Initialize pipelines
17
+ segmentation_pipe = pipeline("image-segmentation", model=segmentation_model)
18
+ depth_pipe = pipeline("depth-estimation", model=depth_model)
19
+
20
+ def get_segmentation_mask(input_image):
21
+ """Get segmentation mask using the pre-loaded segmentation pipeline"""
22
+ # Resize image to 512x512 for consistent processing
23
+ input_image = input_image.resize((512, 512)).convert('RGB')
24
+
25
+ # Get the segmentation result
26
+ result = segmentation_pipe(input_image)
27
+
28
+ # Extract the first mask (assuming it's the most prominent object)
29
+ if len(result) > 0:
30
+ # For SAM-like models that return multiple masks
31
+ mask = result[0]['mask']
32
+ mask = np.array(mask) * 255 # Scale to [0, 255]
33
+ else:
34
+ # Fallback - create empty mask
35
+ mask = np.zeros((512, 512), dtype=np.uint8)
36
+
37
+ # Convert to PIL Image
38
+ mask_img = Image.fromarray(mask.astype(np.uint8))
39
+
40
+ return mask_img, input_image
41
+
42
+ def apply_background_blur(original_image, mask_image, sigma=15):
43
+ """Apply Gaussian blur to the background using a segmentation mask"""
44
+ # Ensure mask is binary (0 for background, 255 for foreground)
45
+ mask_array = np.array(mask_image)
46
+ _, binary_mask = cv2.threshold(mask_array, 127, 255, cv2.THRESH_BINARY)
47
+ mask_img = Image.fromarray(binary_mask)
48
+
49
+ # Create a blurred version of the original image
50
+ blurred_img = original_image.filter(ImageFilter.GaussianBlur(radius=sigma))
51
+
52
+ # Convert images to numpy arrays for easier manipulation
53
+ original_array = np.array(original_image)
54
+ blurred_array = np.array(blurred_img)
55
+ mask_array = np.array(mask_img)
56
+
57
+ # Create the composite image: foreground from original, background from blurred
58
+ result_array = np.zeros_like(original_array)
59
+
60
+ # Where mask is white (255), use original image; where mask is black (0), use blurred image
61
+ for c in range(3): # For each color channel (RGB)
62
+ result_array[:, :, c] = np.where(mask_array == 255,
63
+ original_array[:, :, c],
64
+ blurred_array[:, :, c])
65
+
66
+ # Convert back to PIL Image
67
+ result_img = Image.fromarray(result_array)
68
+
69
+ return result_img
70
+
71
+ def get_depth_map(input_image):
72
+ """Get depth map using the pre-loaded depth estimation pipeline"""
73
+ # Ensure image is in RGB format and resized to 512x512
74
+ input_image = input_image.resize((512, 512)).convert('RGB')
75
+
76
+ # Get the depth map
77
+ result = depth_pipe(input_image)
78
+ depth_map = result["depth"]
79
+
80
+ # Convert to numpy array for further processing
81
+ depth_array = np.array(depth_map)
82
+
83
+ return depth_map, depth_array
84
+
85
+ def apply_depth_based_blur(original_image, depth_array, max_blur=30):
86
+ """Apply variable Gaussian blur based on depth"""
87
+ # Convert depth array to proper format if needed
88
+ if len(depth_array.shape) == 3 and depth_array.shape[2] > 1:
89
+ # If depth map has multiple channels, convert to grayscale
90
+ depth_array = np.mean(depth_array, axis=2)
91
+
92
+ # Normalize depth values to range [0, 1]
93
+ depth_min = depth_array.min()
94
+ depth_max = depth_array.max()
95
+ normalized_depth = (depth_array - depth_min) / (depth_max - depth_min)
96
+
97
+ # Create a series of increasingly blurred versions of the image
98
+ blurred_images = []
99
+ for blur_amount in range(max_blur + 1):
100
+ blurred_images.append(original_image.filter(ImageFilter.GaussianBlur(radius=blur_amount)))
101
+
102
+ # Convert to numpy arrays for easier processing
103
+ original_array = np.array(original_image)
104
+ result_array = np.zeros_like(original_array)
105
+
106
+ # For each pixel, determine the blur level based on depth
107
+ height, width = normalized_depth.shape
108
+ for y in range(height):
109
+ for x in range(width):
110
+ # Calculate blur radius proportional to depth
111
+ # Higher normalized_depth = farther object = more blur
112
+ blur_radius = int(normalized_depth[y, x] * max_blur)
113
+ result_array[y, x] = np.array(blurred_images[blur_radius])[y, x]
114
+
115
+ return Image.fromarray(result_array)
116
+
117
+ def process_image(input_image, blur_sigma=15, max_depth_blur=30):
118
+ """Main function to process the image through all effects"""
119
+ if input_image is None:
120
+ return None, None, None, None
121
+
122
+ # Resize input image for consistent processing
123
+ input_image = Image.fromarray(input_image).convert('RGB')
124
+ input_image = input_image.resize((512, 512))
125
+
126
+ # Step 1: Get segmentation mask
127
+ mask, _ = get_segmentation_mask(input_image)
128
+
129
+ # Step 2: Apply background blur
130
+ blurred_background = apply_background_blur(input_image, mask, sigma=blur_sigma)
131
+
132
+ # Step 3: Get depth map
133
+ depth_map, depth_array = get_depth_map(input_image)
134
+
135
+ # Step 4: Apply depth-based blur
136
+ depth_blur = apply_depth_based_blur(input_image, depth_array, max_blur=max_depth_blur)
137
+
138
+ # Convert all PIL images to numpy arrays for Gradio
139
+ input_np = np.array(input_image)
140
+ mask_np = np.array(mask)
141
+ blurred_np = np.array(blurred_background)
142
+ depth_map_np = np.array(depth_map)
143
+ depth_blur_np = np.array(depth_blur)
144
+
145
+ return input_np, mask_np, blurred_np, depth_map_np, depth_blur_np
146
+
147
+ # Create Gradio Interface
148
+ with gr.Blocks(title="Image Blur Effects - EEE 515 Assignment 3") as demo:
149
+ gr.Markdown("# Image Blur Effects App")
150
+ gr.Markdown("Upload an image to apply segmentation-based blur and depth-based lens blur effects")
151
+
152
+ with gr.Row():
153
+ input_image = gr.Image(label="Upload Image", type="numpy")
154
+
155
+ with gr.Row():
156
+ blur_sigma = gr.Slider(minimum=1, maximum=30, value=15, step=1, label="Background Blur Strength (σ)")
157
+ depth_blur_max = gr.Slider(minimum=1, maximum=50, value=30, step=1, label="Max Depth Blur Strength")
158
+
159
+ with gr.Row():
160
+ process_btn = gr.Button("Process Image")
161
+
162
+ with gr.Tab("Segmentation Results"):
163
+ with gr.Row():
164
+ original_output = gr.Image(label="Original Image", type="numpy")
165
+ mask_output = gr.Image(label="Segmentation Mask", type="numpy")
166
+ with gr.Row():
167
+ blurred_output = gr.Image(label="Background Blur Effect", type="numpy")
168
+
169
+ with gr.Tab("Depth Results"):
170
+ with gr.Row():
171
+ depth_map_output = gr.Image(label="Depth Map", type="numpy")
172
+ depth_blur_output = gr.Image(label="Depth-Based Lens Blur", type="numpy")
173
+
174
+ process_btn.click(
175
+ fn=process_image,
176
+ inputs=[input_image, blur_sigma, depth_blur_max],
177
+ outputs=[original_output, mask_output, blurred_output, depth_map_output, depth_blur_output]
178
+ )
179
+
180
+ gr.Markdown("## How it works")
181
+ gr.Markdown("""
182
+ 1. **Segmentation-Based Blur**: Uses a segmentation model to identify the foreground object,
183
+ then applies Gaussian blur only to the background.
184
+
185
+ 2. **Depth-Based Lens Blur**: Uses a monocular depth estimation model to create a depth map,
186
+ then applies varying levels of blur based on the estimated depth.
187
+ """)
188
+
189
+ # Launch the app
190
+ demo.launch()