pavank007 committed on
Commit
848f3c0
·
verified ·
1 Parent(s): ab2ada6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +174 -166
app.py CHANGED
@@ -1,214 +1,222 @@
1
  import gradio as gr
2
  import torch
3
- import cv2
4
  import numpy as np
 
5
  from PIL import Image
6
- import requests
7
- from io import BytesIO
8
- from transformers import AutoFeatureExtractor, AutoModelForSemanticSegmentation
9
- from transformers import AutoImageProcessor, AutoModelForDepthEstimation
10
- import torch.nn.functional as F
11
-
12
# Select the compute device: CUDA when torch reports it, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load segmentation model and its preprocessing pipeline.
# NOTE(review): Mask2Former checkpoints are normally served by the
# universal-segmentation model classes -- confirm that this auto class accepts
# the checkpoint and that its output exposes `.logits`.
segmentation_model_name = "facebook/mask2former-swin-tiny-coco-instance"
seg_feature_extractor = AutoFeatureExtractor.from_pretrained(segmentation_model_name)
seg_model = AutoModelForSemanticSegmentation.from_pretrained(segmentation_model_name)
seg_model = seg_model.to(device)

# Load depth estimation model and its preprocessing pipeline.
depth_model_name = "intel-isl/MiDaS-small"
depth_processor = AutoImageProcessor.from_pretrained(depth_model_name)
depth_model = AutoModelForDepthEstimation.from_pretrained(depth_model_name)
depth_model = depth_model.to(device)
 
25
def apply_segmentation(input_image):
    """Segment *input_image* and return it together with a foreground mask.

    Args:
        input_image: A PIL image or an array accepted by ``Image.fromarray``.

    Returns:
        A ``(image, mask)`` tuple: the RGB image resized to 512x512 as a numpy
        array, and a ``uint8`` mask of the same height and width holding 255
        where the predicted class id is non-zero and 0 elsewhere.
    """
    if not isinstance(input_image, Image.Image):
        input_image = Image.fromarray(input_image)

    # Normalise to 3-channel RGB (uploads may be RGBA or grayscale), then
    # resize to 512x512 for consistent processing.
    input_image = input_image.convert("RGB").resize((512, 512))

    inputs = seg_feature_extractor(images=input_image, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = seg_model(**inputs)

    logits = outputs.logits

    # The logits may come back at a lower resolution than the input image;
    # upsample them so the mask lines up pixel-for-pixel with the returned image.
    target_size = (input_image.height, input_image.width)
    if tuple(logits.shape[-2:]) != target_size:
        logits = torch.nn.functional.interpolate(
            logits,
            size=target_size,
            mode="bilinear",
            align_corners=False,
        )

    class_map = torch.argmax(logits, dim=1)[0].cpu().numpy()

    # Every non-zero class id counts as foreground.
    # NOTE(review): this assumes class 0 is background for the loaded
    # checkpoint -- confirm against the model's label map.
    mask = (class_map > 0).astype(np.uint8) * 255

    return np.array(input_image), mask
51
 
52
def apply_depth_estimation(input_image):
    """Estimate a depth map for *input_image*.

    Args:
        input_image: A PIL image or an array accepted by ``Image.fromarray``.

    Returns:
        A ``(image, depth_map, depth_map_vis)`` tuple: the RGB image resized
        to 512x512 as a numpy array, the depth prediction min-max normalised
        to [0, 1] at 512x512, and an INFERNO colour-mapped ``uint8`` rendering
        of that depth map as produced by ``cv2.applyColorMap``.
    """
    if not isinstance(input_image, Image.Image):
        input_image = Image.fromarray(input_image)

    # Normalise to 3-channel RGB, then resize for consistent processing.
    input_image = input_image.convert("RGB").resize((512, 512))

    inputs = depth_processor(images=input_image, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = depth_model(**inputs)

    # Resample the prediction to the working resolution. Index the batch and
    # channel axes explicitly rather than calling squeeze().
    depth_map = torch.nn.functional.interpolate(
        outputs.predicted_depth.unsqueeze(1),
        size=(512, 512),
        mode="bicubic",
        align_corners=False,
    )[0, 0]

    # Min-max normalise to [0, 1]; a perfectly flat prediction would otherwise
    # divide by zero and fill the map with NaN.
    depth_min = torch.min(depth_map)
    depth_range = torch.max(depth_map) - depth_min
    if depth_range > 0:
        depth_map = (depth_map - depth_min) / depth_range
    else:
        depth_map = torch.zeros_like(depth_map)

    depth_map = depth_map.cpu().numpy()

    # Heat-map rendering for visualisation.
    depth_map_vis = (depth_map * 255).astype(np.uint8)
    depth_map_vis = cv2.applyColorMap(depth_map_vis, cv2.COLORMAP_INFERNO)

    return np.array(input_image), depth_map, depth_map_vis
89
-
90
def apply_gaussian_blur(image, mask, sigma=15):
    """Blur the background of *image*, keeping masked pixels sharp.

    Args:
        image: Image array, either ``(H, W)`` or ``(H, W, C)``.
        mask: ``(H, W)`` or ``(H, W, C)`` foreground mask. Values above 1 are
            treated as being on a 0-255 scale and rescaled to 0-1.
        sigma: Gaussian standard deviation passed to ``cv2.GaussianBlur``.

    Returns:
        A ``uint8`` array where each pixel is
        ``image * mask + blurred * (1 - mask)``.
    """
    weights = np.asarray(mask, dtype=np.float64)

    # Bring a 0-255 mask onto the 0-1 scale used for blending.
    if weights.max() > 1:
        weights = weights / 255.0

    # Add a trailing axis so a 2-D mask broadcasts over however many channels
    # the image has, instead of assuming exactly three.
    if weights.ndim == 2 and image.ndim == 3:
        weights = weights[:, :, np.newaxis]

    # A kernel size of (0, 0) lets OpenCV derive the kernel from sigma.
    blurred = cv2.GaussianBlur(image, (0, 0), sigma)

    # Foreground from the original image, background from the blurred copy.
    result = image * weights + blurred * (1 - weights)

    return result.astype(np.uint8)
110
-
111
def apply_depth_blur(image, depth_map, max_sigma=30):
    """Blur *image* with a strength that grows with the depth-map value.

    The depth map is min-max normalised to [0, 1] and split into ``max_sigma``
    equal-width bands; band ``k`` (1-based) is filled from a copy of the image
    blurred with sigma ``k``.

    Args:
        image: Image array whose first two axes match *depth_map*.
        depth_map: 2-D array of depth values on any scale.
        max_sigma: Largest blur sigma. Coerced to an int of at least 1 so a
            float slider value does not break the banding.

    Returns:
        An array with the same shape and dtype as *image*.
    """
    max_sigma = max(1, int(max_sigma))

    # Normalise to [0, 1]; a flat map has zero range and would otherwise
    # produce NaN (and an all-black result), so treat it as uniformly near.
    depth = np.asarray(depth_map, dtype=np.float32)
    depth_min = depth.min()
    depth_range = depth.max() - depth_min
    if depth_range > 0:
        depth = (depth - depth_min) / depth_range
    else:
        depth = np.zeros_like(depth)

    # Band k covers [(k - 1) / max_sigma, k / max_sigma); depth == 1.0 falls
    # into the last band.
    sigma_map = np.floor(depth * max_sigma).astype(np.int64) + 1
    sigma_map = np.clip(sigma_map, 1, max_sigma)

    result = np.zeros_like(image)

    # Only blur for sigmas that actually occur in the image.
    for sigma in np.unique(sigma_map):
        selected = sigma_map == sigma
        current_blur = cv2.GaussianBlur(image, (0, 0), float(sigma))
        result[selected] = current_blur[selected]

    return result
145
-
146
def process_image(input_image, blur_type, blur_strength):
    """Apply the selected blur effect to *input_image*.

    Args:
        input_image: A numpy array, a PIL image, or a URL string to download.
        blur_type: ``"Gaussian Background Blur"`` or
            ``"Depth-based Lens Blur"``.
        blur_strength: Blur sigma (Gaussian) or maximum sigma (depth-based).

    Returns:
        A ``(resized_input, result)`` tuple of 512x512 RGB numpy arrays.

    Raises:
        ValueError: If no image is supplied or *blur_type* is not recognised.
        requests.RequestException: If downloading a URL input fails.
    """
    if input_image is None:
        raise ValueError("No input image provided")

    if isinstance(input_image, str):
        # Load from URL if it's a string. Bound the wait and fail loudly on
        # HTTP errors instead of handing an error page to PIL.
        response = requests.get(input_image, timeout=30)
        response.raise_for_status()
        input_image = Image.open(BytesIO(response.content))
    elif not isinstance(input_image, Image.Image):
        # Only arrays go through fromarray; an image loaded from a URL is
        # already a PIL image.
        input_image = Image.fromarray(input_image)

    # Normalise to RGB and resize to 512x512 for consistent processing.
    input_image = input_image.convert("RGB").resize((512, 512))
    input_image_np = np.array(input_image)

    if blur_type == "Gaussian Background Blur":
        _, mask = apply_segmentation(input_image)
        result = apply_gaussian_blur(input_image_np, mask, sigma=blur_strength)
        return input_image_np, result

    if blur_type == "Depth-based Lens Blur":
        _, depth_map, _ = apply_depth_estimation(input_image)
        result = apply_depth_blur(input_image_np, depth_map, max_sigma=blur_strength)
        return input_image_np, result

    # Fail explicitly rather than returning None for an unexpected selection.
    raise ValueError(f"Unknown blur type: {blur_type!r}")
 
 
 
 
 
 
 
 
 
175
 
176
# Build the Gradio UI: input/output images, effect selector, strength slider.
with gr.Blocks(title="Image Blur Effects Demo") as app:
    gr.Markdown("# Image Blur Effects Demo")
    gr.Markdown("Upload an image to apply different blur effects using deep learning models")

    # Source image on the left, processed result on the right.
    with gr.Row():
        input_image = gr.Image(label="Input Image", type="numpy")
        output_image = gr.Image(label="Output Image", type="numpy")

    # Effect selection and strength controls.
    with gr.Row():
        blur_type = gr.Radio(
            label="Blur Effect Type",
            choices=["Gaussian Background Blur", "Depth-based Lens Blur"],
            value="Gaussian Background Blur",
        )
        blur_strength = gr.Slider(
            label="Blur Strength",
            minimum=1,
            maximum=50,
            step=1,
            value=15,
        )

    submit_button = gr.Button("Apply Effect")

    # process_image returns (input, result); the input component is therefore
    # also listed as an output.
    click_inputs = [input_image, blur_type, blur_strength]
    click_outputs = [input_image, output_image]
    submit_button.click(fn=process_image, inputs=click_inputs, outputs=click_outputs)

    gr.Markdown("""
    ## How it works

    1. **Gaussian Background Blur**: Uses a segmentation model to detect the foreground object and applies blur to the background
    2. **Depth-based Lens Blur**: Uses a depth estimation model to create a variable blur effect where objects further away are more blurred

    Both models are from Hugging Face Transformers library.
    """)

# Start the web server for the app.
app.launch()
 
1
  import gradio as gr
2
  import torch
 
3
  import numpy as np
4
+ import cv2
5
  from PIL import Image
6
+ from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation
7
+ from transformers import DPTImageProcessor, DPTForDepthEstimation
8
+ import warnings
9
+ warnings.filterwarnings("ignore")
 
 
 
 
10
 
11
# Load segmentation model.
# The checkpoint is a Mask2Former model, so it is loaded through its own
# universal-segmentation class; the processor's
# post_process_semantic_segmentation() consumes that class's outputs.
from transformers import Mask2FormerForUniversalSegmentation

seg_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-tiny-coco-instance")
seg_model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-tiny-coco-instance")

# Load depth estimation model
depth_processor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
 
18
 
19
def apply_gaussian_blur(image, mask, sigma=15):
    """Apply Gaussian blur to the background of an image based on a mask.

    Args:
        image: Image array, either ``(H, W)`` or ``(H, W, C)``.
        mask: Foreground mask. A mask whose maximum is at most 1 is scaled to
            0-255; otherwise it is taken to be on the 0-255 scale already. It
            is resized to the image when the sizes differ.
        sigma: Gaussian standard deviation passed to ``cv2.GaussianBlur``.

    Returns:
        A ``uint8`` array where each pixel is
        ``image * mask + blurred * (1 - mask)`` with the mask on a 0-1 scale.
    """
    # Bring the mask onto a 0-255 uint8 scale.
    if mask.max() <= 1.0:
        binary_mask = (mask * 255).astype(np.uint8)
    else:
        binary_mask = mask.astype(np.uint8)

    # A kernel size of (0, 0) lets OpenCV derive the kernel from sigma.
    blurred = cv2.GaussianBlur(image, (0, 0), sigma)

    # cv2.resize takes (width, height), hence the swapped shape indices.
    if binary_mask.shape[:2] != image.shape[:2]:
        binary_mask = cv2.resize(binary_mask, (image.shape[1], image.shape[0]))

    # Blend weights in [0, 1].
    weights = binary_mask / 255.0

    # Add a trailing axis so a 2-D mask broadcasts over however many channels
    # the image has (grayscale, RGB or RGBA) rather than assuming three.
    if weights.ndim == 2 and image.ndim == 3:
        weights = weights[:, :, np.newaxis]

    # Foreground from the original image, background from the blurred copy.
    result = image * weights + blurred * (1 - weights)

    return result.astype(np.uint8)
47
+
48
def apply_depth_blur(image, depth_map, max_sigma=25):
    """Apply variable Gaussian blur based on depth map.

    Depth values are quantised into bands of width ``2 / max_sigma``; the band
    starting at ``(s - 1) / max_sigma`` is filled from a copy of the image
    blurred with the odd sigma ``s`` (1, 3, 5, ... up to *max_sigma*). Each
    pixel belongs to exactly one band, so a pixel sitting on a band boundary
    is never accumulated twice.

    Args:
        image: Image array, ``(H, W)`` or ``(H, W, C)``.
        depth_map: 2-D depth array. It is divided by its maximum when that
            exceeds 1, and resized to the image when the sizes differ.
        max_sigma: Upper limit for the blur sigma; coerced to an int >= 1.

    Returns:
        An array with the same shape and dtype as *image*. Pixels whose depth
        is negative or NaN are copied through unblurred.
    """
    # NOTE(review): larger depth values receive stronger blur -- confirm this
    # matches the depth convention of whatever produces depth_map.
    max_sigma = max(1, int(max_sigma))

    depth_norm = np.asarray(depth_map, dtype=np.float32)
    peak = depth_norm.max()
    if peak > 1.0:
        depth_norm = depth_norm / peak

    # cv2.resize takes (width, height), hence the swapped shape indices.
    if depth_norm.shape[:2] != image.shape[:2]:
        depth_norm = cv2.resize(depth_norm, (image.shape[1], image.shape[0]))

    # Largest odd sigma that does not exceed max_sigma.
    last_sigma = max_sigma if max_sigma % 2 else max_sigma - 1

    # Pixels outside the banded range keep their original value; the
    # comparison is False for NaN as well as for negative depths.
    in_range = depth_norm >= 0.0

    # Band index -> odd sigma; the top of the range folds into the last band.
    band = np.floor(np.where(in_range, depth_norm, 0.0) * max_sigma / 2.0).astype(np.int64)
    sigma_map = np.minimum(2 * band + 1, last_sigma)

    result = image.copy()

    # Blur once per sigma that actually occurs, then copy those pixels across.
    for sigma in np.unique(sigma_map[in_range]):
        selected = in_range & (sigma_map == sigma)
        blurred = cv2.GaussianBlur(image, (0, 0), float(sigma))
        result[selected] = blurred[selected]

    return result
100
 
101
def get_segmentation_mask(image_pil, person_class_id=0):
    """Get segmentation mask for one class from an image.

    Args:
        image_pil: Input PIL image.
        person_class_id: Class id to extract from the predicted label map.
            Defaults to 0.
            NOTE(review): 0 is assumed to be the "person" label for the
            loaded checkpoint -- confirm against the model's id2label map.

    Returns:
        A 2-D integer array with the image's height and width, holding 1 where
        the predicted class equals *person_class_id* and 0 elsewhere.
    """
    # The model is fed 3-channel RGB; uploads may be RGBA or grayscale.
    rgb_image = image_pil.convert("RGB")

    inputs = seg_processor(images=rgb_image, return_tensors="pt")

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = seg_model(**inputs)

    # target_sizes is (height, width); PIL's size is (width, height).
    target_size = (rgb_image.height, rgb_image.width)
    predicted_mask = seg_processor.post_process_semantic_segmentation(
        outputs, target_sizes=[target_size]
    )[0]

    mask_np = predicted_mask.cpu().numpy()

    return (mask_np == person_class_id).astype(mask_np.dtype)
118
+
119
def get_depth_map(image_pil):
    """Get depth map from an image.

    Args:
        image_pil: Input PIL image.

    Returns:
        A 2-D float array with the image's height and width, min-max
        normalised to [0, 1]. A prediction with no variation yields all zeros.
    """
    # The model is fed 3-channel RGB; uploads may be RGBA or grayscale.
    rgb_image = image_pil.convert("RGB")

    inputs = depth_processor(images=rgb_image, return_tensors="pt")
    with torch.no_grad():
        outputs = depth_model(**inputs)
        predicted_depth = outputs.predicted_depth

    # Interpolate to original size; interpolate expects (height, width) while
    # PIL's size is (width, height).
    prediction = torch.nn.functional.interpolate(
        predicted_depth.unsqueeze(1),
        size=(rgb_image.height, rgb_image.width),
        mode="bicubic",
        align_corners=False,
    )

    # Index the batch and channel axes explicitly: squeeze() would also drop
    # a spatial axis of length 1.
    depth_map = prediction[0, 0].cpu().numpy()

    # Normalize depth map, guarding against a zero range (division by zero).
    depth_min = depth_map.min()
    depth_range = depth_map.max() - depth_min
    if depth_range <= 0:
        return np.zeros_like(depth_map)

    return (depth_map - depth_min) / depth_range
142
+
143
def process_image(input_image, blur_sigma=15, depth_blur_sigma=25):
    """Main function to process the input image.

    Args:
        input_image: A PIL image or numpy array, or ``None`` when nothing has
            been uploaded.
        blur_sigma: Sigma for the segmentation-based background blur.
        depth_blur_sigma: Maximum sigma for the depth-based blur.

    Returns:
        A five-element list: the RGB input array, the segmentation mask
        rendered as 0/255 ``uint8``, the background-blur result, the depth map
        as an RGB heat map, and the depth-blur result. Every element is
        ``None`` when there is no input or processing fails.
    """
    # Nothing uploaded yet: leave all outputs empty.
    if input_image is None:
        return [None, None, None, None, None]

    try:
        # Convert to PIL Image if needed
        if isinstance(input_image, np.ndarray):
            pil_image = Image.fromarray(input_image)
        else:
            pil_image = input_image

        # Work in 3-channel RGB throughout so RGBA or grayscale uploads do not
        # break the models or the blending code.
        pil_image = pil_image.convert("RGB")
        input_image = np.array(pil_image)

        seg_mask = get_segmentation_mask(pil_image)
        depth_map = get_depth_map(pil_image)

        gaussian_result = apply_gaussian_blur(input_image, seg_mask, sigma=blur_sigma)
        depth_result = apply_depth_blur(input_image, depth_map, max_sigma=depth_blur_sigma)

        # OpenCV colour maps come back in BGR order; convert to RGB so the
        # heat map is displayed with the intended colours.
        depth_visualization = (depth_map * 255).astype(np.uint8)
        depth_visualization = cv2.applyColorMap(depth_visualization, cv2.COLORMAP_INFERNO)
        depth_visualization = cv2.cvtColor(depth_visualization, cv2.COLOR_BGR2RGB)

        # Render the 0/1 mask as black/white.
        seg_visualization = (seg_mask * 255).astype(np.uint8)

        return [
            input_image,
            seg_visualization,
            gaussian_result,
            depth_visualization,
            depth_result,
        ]
    except Exception as e:
        # UI boundary: report the failure and clear the outputs instead of
        # propagating the exception to the caller.
        print(f"Error processing image: {e}")
        return [None, None, None, None, None]
182
 
183
# Build the Gradio UI: controls in the left column, one result tab per output
# in the right column.
with gr.Blocks(title="Image Blur Effects with Segmentation and Depth Estimation") as demo:
    gr.Markdown("# Image Blur Effects App")
    gr.Markdown("This app demonstrates two types of blur effects: background blur using segmentation and depth-based lens blur.")

    with gr.Row():
        # Left column: upload, blur controls and the trigger button.
        with gr.Column():
            input_image = gr.Image(label="Upload an image", type="pil")
            blur_sigma = gr.Slider(
                label="Background Blur Intensity",
                minimum=1,
                maximum=50,
                step=1,
                value=15,
            )
            depth_blur_sigma = gr.Slider(
                label="Depth Blur Max Intensity",
                minimum=1,
                maximum=50,
                step=1,
                value=25,
            )
            process_btn = gr.Button("Process Image")

        # Right column: (tab title, image label) pairs, listed in the same
        # order as the values returned by process_image.
        with gr.Column():
            result_tabs = [
                ("Original Image", "Original Image"),
                ("Segmentation Mask", "Segmentation Mask"),
                ("Background Blur", "Background Blur Result"),
                ("Depth Map", "Depth Map"),
                ("Depth-based Lens Blur", "Depth-based Lens Blur Result"),
            ]
            result_images = []
            for tab_title, image_label in result_tabs:
                with gr.Tab(tab_title):
                    result_images.append(gr.Image(label=image_label))
            (
                output_original,
                output_segmentation,
                output_gaussian,
                output_depth,
                output_depth_blur,
            ) = result_images

    # Wire the button to the processing function.
    process_btn.click(
        fn=process_image,
        inputs=[input_image, blur_sigma, depth_blur_sigma],
        outputs=[
            output_original,
            output_segmentation,
            output_gaussian,
            output_depth,
            output_depth_blur,
        ],
    )

    gr.Markdown("""
    ## How it works

    1. **Background Blur**: Uses a segmentation model to identify foreground objects and blurs only the background
    2. **Depth-based Lens Blur**: Uses a depth estimation model to apply variable blur based on estimated distance

    Try uploading a photo of a person or object against a background to see the effects!
    """)

# Start the web server for the demo.
demo.launch()