nsathya5 committed (verified)
Commit d460634 · Parent: 1cfe6b2

Update app.py

Files changed (1):
  1. app.py +177 -135
app.py CHANGED
@@ -1,83 +1,108 @@
 import gradio as gr
-import torch
 import numpy as np
 import cv2
 from PIL import Image
 import matplotlib.pyplot as plt
-from transformers import AutoFeatureExtractor, SegformerForSemanticSegmentation
-from transformers import DPTFeatureExtractor, DPTForDepthEstimation
-
-# Load a smaller segmentation model
-try:
-    seg_processor = AutoFeatureExtractor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
-    seg_model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
-    print("✓ Segmentation model loaded successfully")
-except Exception as e:
-    print(f"! Error loading segmentation model: {e}")
-    # Fallback implementation will be used
-
-# Load a smaller depth estimation model
-try:
-    depth_processor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
-    depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas")
-    print("✓ Depth model loaded successfully")
-except Exception as e:
-    print(f"! Error loading depth model: {e}")
-    # Fallback implementation will be used
 
 def apply_gaussian_blur(image, sigma=15):
     """Apply Gaussian blur with specified sigma value."""
     return cv2.GaussianBlur(image, (0, 0), sigma)
 
-def get_foreground_mask(image):
-    """Get foreground mask through simple methods if model fails."""
+def create_foreground_mask(image):
+    """Create a foreground mask using simple computer vision techniques.
+    Returns a mask where 1 is foreground and 0 is background."""
+    # Convert to RGB if needed
+    if len(image.shape) == 2:
+        img_rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
+    elif image.shape[2] == 4:
+        img_rgb = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
+    else:
+        img_rgb = image.copy()
+
+    # Try to use GrabCut algorithm for foreground extraction
     try:
-        # Try using the model first
-        if seg_model is not None and seg_processor is not None:
-            # Convert to RGB if needed
-            if isinstance(image, np.ndarray):
-                if len(image.shape) == 2:
-                    image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
-                elif image.shape[2] == 4:
-                    image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
-                pil_image = Image.fromarray(image)
-            else:
-                pil_image = image.convert('RGB')
-
-            # Prepare image for the model
-            inputs = seg_processor(images=pil_image, return_tensors="pt")
-
-            # Run inference
-            with torch.no_grad():
-                outputs = seg_model(**inputs)
-
-            # Process logits
-            logits = outputs.logits
-            upsampled_logits = torch.nn.functional.interpolate(
-                logits,
-                size=(image.shape[0], image.shape[1]),
-                mode="bilinear",
-                align_corners=False,
-            )
-
-            # Get mask (consider classes that are typically foreground, e.g., person)
-            # In ADE20K dataset, person is class 12
-            pred_seg = upsampled_logits.argmax(dim=1)[0]
-            mask = (pred_seg == 12).float().cpu().numpy() # Person class
-
-            # If person isn't detected, try other common foreground classes
-            if mask.sum() < 100: # If almost no pixels were classified as person
-                for cls in [13, 14, 15]: # Try other classes like vehicle, animal, etc.
-                    cls_mask = (pred_seg == cls).float().cpu().numpy()
-                    if cls_mask.sum() > mask.sum():
-                        mask = cls_mask
+        # Create initial mask
+        mask = np.zeros(image.shape[:2], np.uint8)
+
+        # Set a rectangle focusing on the center as probable foreground
+        h, w = image.shape[:2]
+        # Adjust these values to focus more on the central figure
+        rect_w = int(w * 0.4)
+        rect_h = int(h * 0.7)
+        rect_x = int((w - rect_w) / 2)
+        rect_y = int((h - rect_h) / 2)
+        rect = (rect_x, rect_y, rect_w, rect_h)
+
+        # Background and foreground models
+        bgd_model = np.zeros((1, 65), np.float64)
+        fgd_model = np.zeros((1, 65), np.float64)
+
+        # Apply GrabCut
+        cv2.grabCut(img_rgb, mask, rect, bgd_model, fgd_model, 5, cv2.GC_INIT_WITH_RECT)
+
+        # Create mask where certain (foreground) or probable (foreground) is 1
+        # 0 and 2 are background, 1 and 3 are foreground
+        mask2 = np.where((mask == 1) | (mask == 3), 1, 0).astype('float32')
+
+        # If the mask is almost empty, fallback to a simpler method
+        if mask2.sum() < (h * w * 0.05):
+            raise Exception("GrabCut produced an empty mask")
 
-        return mask
+        # Smooth the mask
+        mask2 = cv2.GaussianBlur(mask2, (21, 21), 7)
+
+        return mask2
 
     except Exception as e:
-        print(f"Error in segmentation: {e}")
+        print(f"GrabCut failed: {e}, using fallback method")
+
+        # Try color-based segmentation as a fallback
+        try:
+            # Convert to HSV color space
+            hsv = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2HSV)
+
+            # Calculate the histogram of the central area
+            # This assumes the foreground object is in the center
+            center_y, center_x = h // 2, w // 2
+            center_size = min(h, w) // 4
+            center_y1 = max(0, center_y - center_size)
+            center_y2 = min(h, center_y + center_size)
+            center_x1 = max(0, center_x - center_size)
+            center_x2 = min(w, center_x + center_size)
+
+            center_hsv = hsv[center_y1:center_y2, center_x1:center_x2]
+
+            # Calculate histogram of central region (assumed to be foreground)
+            hist_hue = cv2.calcHist([center_hsv], [0], None, [180], [0, 180])
+            hist_sat = cv2.calcHist([center_hsv], [1], None, [256], [0, 256])
+
+            # Normalize histograms
+            cv2.normalize(hist_hue, hist_hue, 0, 255, cv2.NORM_MINMAX)
+            cv2.normalize(hist_sat, hist_sat, 0, 255, cv2.NORM_MINMAX)
+
+            # Create a mask based on color similarity to center region
+            h_channel = hsv[:,:,0]
+            s_channel = hsv[:,:,1]
+
+            h_score = hist_hue[h_channel.astype(int)]
+            s_score = hist_sat[s_channel.astype(int)]
+
+            # Combine scores
+            total_score = (h_score.reshape(h, w) + s_score.reshape(h, w)) / 2
+
+            # Normalize and threshold
+            total_score = total_score / total_score.max()
+            mask = (total_score > 0.4).astype(np.float32)
+
+            # Smooth the mask
+            mask = cv2.GaussianBlur(mask, (21, 21), 5)
+
+            return mask
+
+        except Exception as e:
+            print(f"Color segmentation failed: {e}, using simple mask")
 
-    # Fallback: Use a simple method - assume center of image is foreground
+    # Final fallback: Create a simple radial mask focusing on center
     h, w = image.shape[:2]
     y, x = np.ogrid[:h, :w]
     center_y, center_x = h / 2, w / 2
@@ -91,107 +116,120 @@ def get_foreground_mask(image):
 
     return mask
 
-def get_depth_map(image):
-    """Get depth map from the image using model or fallback."""
-    try:
-        # Try using the model first
-        if depth_model is not None and depth_processor is not None:
-            # Convert to RGB if needed
-            if isinstance(image, np.ndarray):
-                if len(image.shape) == 2:
-                    image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
-                elif image.shape[2] == 4:
-                    image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
-                pil_image = Image.fromarray(image)
-            else:
-                pil_image = image.convert('RGB')
-
-            # Prepare image for the model
-            inputs = depth_processor(images=pil_image, return_tensors="pt")
-
-            # Run inference
-            with torch.no_grad():
-                outputs = depth_model(**inputs)
-
-            predicted_depth = outputs.predicted_depth
-
-            # Interpolate to original size if needed
-            depth_map = torch.nn.functional.interpolate(
-                predicted_depth.unsqueeze(1),
-                size=(image.shape[0], image.shape[1]),
-                mode="bicubic",
-                align_corners=False,
-            ).squeeze().cpu().numpy()
-
-            return depth_map
-
-    except Exception as e:
-        print(f"Error in depth estimation: {e}")
+def create_depth_map(image):
+    """Create a depth map using simple techniques.
+    Lower values in the depth map mean closer to camera."""
+    # Convert to grayscale for processing
+    if len(image.shape) > 2:
+        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
+    else:
+        gray = image.copy()
 
-    # Fallback: Create a simple depth map based on distance from center
+    # 1. Use Sobel for edge detection (sharp edges = closer)
+    sobelx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
+    sobely = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
+    gradient_magnitude = np.sqrt(sobelx**2 + sobely**2)
+    gradient_norm = gradient_magnitude / gradient_magnitude.max()
+
+    # 2. Create radial gradient from center (center = closer)
     h, w = image.shape[:2]
     y, x = np.ogrid[:h, :w]
     center_y, center_x = h / 2, w / 2
+    radial = ((x - center_x)**2 / (w/2)**2 + (y - center_y)**2 / (h/2)**2)
+    radial = np.clip(radial, 0, 1)
+
+    # 3. Combine (higher value = further from camera)
+    # Edges and center have lower values (closer)
+    depth = 0.7 * radial + 0.3 * (1 - gradient_norm)
 
-    # Create a radial gradient (closer to center = closer distance)
-    depth = ((x - center_x)**2 / (w/2)**2 + (y - center_y)**2 / (h/2)**2)
-    depth = np.clip(depth, 0, 1)
+    # 4. Smooth the depth map
+    depth = cv2.GaussianBlur(depth, (21, 21), 5)
 
     return depth
 
-def process_image(input_image, blur_type="gaussian", blur_sigma=15):
-    """Process the input image and return the results."""
+def process_image(input_image, blur_type="Gaussian Blur", blur_intensity=15):
+    """Process the input image and return the results.
+    Ensures the foreground is kept sharp while background is blurred."""
     try:
-        # Convert from Gradio format
-        img = np.array(input_image)
+        # Convert from Gradio format to numpy
+        if isinstance(input_image, np.ndarray):
+            img = input_image.copy()
+        else:
+            img = np.array(input_image)
+
+        # Handle grayscale or RGBA images
        if img.ndim == 2: # Grayscale
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4: # RGBA
            img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
 
-        # 1. Get segmentation mask
-        mask = get_foreground_mask(img)
-        mask_vis = (mask * 255).astype(np.uint8)
-        mask_color = cv2.applyColorMap(mask_vis, cv2.COLORMAP_JET)
+        # 1. Get foreground mask (1 = foreground, 0 = background)
+        mask = create_foreground_mask(img)
 
-        # 2. Get depth map
-        depth_map = get_depth_map(img)
+        # Visualize the mask (foreground in red)
+        mask_vis = np.zeros_like(img)
+        mask_vis[:,:,0] = mask * 255 # Red channel
+
+        # 2. Get depth map (lower value = closer to camera)
+        depth_map = create_depth_map(img)
+
+        # Normalize depth map for visualization
         depth_norm = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min() + 1e-8)
         depth_vis = plt.cm.viridis(depth_norm)[:, :, :3]
         depth_vis = (depth_vis * 255).astype(np.uint8)
 
         # Apply appropriate blur effect
-        if blur_type == "gaussian":
-            # Apply regular Gaussian blur
-            blurred_img = apply_gaussian_blur(img, sigma=blur_sigma)
+        result = img.copy()
+
+        if blur_type == "Gaussian Blur":
+            # Apply standard Gaussian blur to the entire image
+            blurred_img = apply_gaussian_blur(img, sigma=blur_intensity)
 
-            # Combine original foreground with blurred background
-            result = img.copy()
-            for c in range(3): # For each color channel
+            # Keep the foreground sharp by using the mask:
+            # result = foreground + blurred background
+            #        = mask * original + (1-mask) * blurred
+            for c in range(3):
                 result[:,:,c] = mask * img[:,:,c] + (1-mask) * blurred_img[:,:,c]
 
-        else: # depth-based blur
-            # Apply depth-based blur
-            result = img.copy()
+        else: # "Depth-based Lens Blur"
+            # For depth-based blur, we apply blur based on depth values
+            # First, combine mask with depth (ensure foreground stays sharp)
+            combined_depth = depth_norm.copy()
+            # Set foreground areas to 0 in depth map (closest)
+            combined_depth = combined_depth * (1 - mask)
 
-            # Apply varying levels of blur based on depth
-            # For simplicity, we'll use 5 levels of blur
-            for i in range(1, 6):
-                sigma = blur_sigma * i / 5 # Increasing sigma value
+            # Create multiple blur levels based on depth
+            num_levels = 5
+            for i in range(num_levels):
+                # Calculate sigma for this level (further = more blur)
+                sigma = blur_intensity * (i + 1) / num_levels
                 level_blurred = apply_gaussian_blur(img, sigma=sigma)
 
                 # Calculate weight for this blur level
-                weight = (depth_norm > (i-1)/5) & (depth_norm <= i/5)
-                weight = weight.astype(np.float32)
+                depth_min = i / num_levels
+                depth_max = (i + 1) / num_levels
+                weight = (combined_depth >= depth_min) & (combined_depth < depth_max)
 
                 # Apply this blur level where applicable
                 for c in range(3):
                     result[:,:,c] = np.where(weight, level_blurred[:,:,c], result[:,:,c])
+
+            # Handle max depth level
+            max_sigma = blur_intensity
+            max_blurred = apply_gaussian_blur(img, sigma=max_sigma)
+            max_weight = (combined_depth >= ((num_levels-1) / num_levels))
+            for c in range(3):
+                result[:,:,c] = np.where(max_weight, max_blurred[:,:,c], result[:,:,c])
+
+            # Ensure foreground remains completely sharp
+            for c in range(3):
+                result[:,:,c] = mask * img[:,:,c] + (1-mask) * result[:,:,c]
 
-        # Convert to uint8
+        # Convert to uint8 for display
         result = result.astype(np.uint8)
+        mask_vis = mask_vis.astype(np.uint8)
 
-        return result, mask_color, depth_vis
+        return result, mask_vis, depth_vis
 
     except Exception as e:
         print(f"Error processing image: {e}")
@@ -231,7 +269,7 @@ with gr.Blocks(title="Image Blur Effects") as demo:
             apply_button = gr.Button("Apply Effect")
 
     with gr.Row():
-        segmentation_mask = gr.Image(label="Segmentation Mask")
+        segmentation_mask = gr.Image(label="Foreground Mask")
         depth_map = gr.Image(label="Depth Map")
 
     # Set up the click event
@@ -241,12 +279,16 @@ with gr.Blocks(title="Image Blur Effects") as demo:
         outputs=[output_image, segmentation_mask, depth_map]
     )
 
-    # Examples section
-    gr.Markdown("## How to use")
+    gr.Markdown("## How to Use")
    gr.Markdown("1. Upload your image")
    gr.Markdown("2. Select blur type (Gaussian or Depth-based)")
    gr.Markdown("3. Adjust blur intensity")
    gr.Markdown("4. Click 'Apply Effect'")
+    gr.Markdown("")
+    gr.Markdown("### Notes")
+    gr.Markdown("- The red areas in the Foreground Mask show what will remain sharp")
+    gr.Markdown("- Depth-based blur creates a more realistic effect with blur increasing with distance")
+    gr.Markdown("- For best results, ensure your subject is clearly visible in the center of the image")
 
 # Launch the demo
 demo.launch()
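
For reference, the core of this change — GrabCut seeded with a centered rectangle, a feathered float mask, and compositing the sharp foreground over a Gaussian-blurred background — can be tried outside the Gradio app. The sketch below is illustrative only and is not part of the commit; it assumes a test image at the hypothetical path test.jpg and mirrors the rectangle proportions (40% x 70%, centered) and the foreground mask values used in create_foreground_mask.

# Illustrative sketch (not from the commit): GrabCut mask + background-blur compositing.
import cv2
import numpy as np

img = cv2.cvtColor(cv2.imread("test.jpg"), cv2.COLOR_BGR2RGB)  # hypothetical input path
h, w = img.shape[:2]

# Centered probable-foreground rectangle: 40% of the width, 70% of the height
rect = (int(w * 0.3), int(h * 0.15), int(w * 0.4), int(h * 0.7))

mask = np.zeros((h, w), np.uint8)
bgd_model = np.zeros((1, 65), np.float64)
fgd_model = np.zeros((1, 65), np.float64)
cv2.grabCut(img, mask, rect, bgd_model, fgd_model, 5, cv2.GC_INIT_WITH_RECT)

# Mask values 1 (GC_FGD) and 3 (GC_PR_FGD) are foreground; feather the edge for a soft transition
fg = np.where((mask == 1) | (mask == 3), 1, 0).astype(np.float32)
fg = cv2.GaussianBlur(fg, (21, 21), 7)

# Composite: sharp foreground over a Gaussian-blurred background
blurred = cv2.GaussianBlur(img, (0, 0), 15)
out = (fg[..., None] * img + (1 - fg[..., None]) * blurred).astype(np.uint8)
cv2.imwrite("portrait_blur.png", cv2.cvtColor(out, cv2.COLOR_RGB2BGR))

If GrabCut produces an almost empty mask, the committed code falls back to color-histogram scoring of the central region and finally to a radial mask; those fallback paths are omitted from this sketch.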