nickkun commited on
Commit
5f66b26
·
verified ·
1 Parent(s): 8cfd312

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +142 -119
app.py CHANGED
@@ -4,130 +4,153 @@
4
  @author: Nikhil Kunjoor
5
  """
6
  import gradio as gr
 
 
7
  import numpy as np
8
- from PIL import Image, ImageFilter
9
- import torch
10
- from torchvision import transforms
11
- from transformers import AutoModelForImageSegmentation, AutoImageProcessor, AutoModelForDepthEstimation
12
-
13
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
14
- torch.set_float32_matmul_precision('high')
15
-
16
- rmbg_model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-2.0", trust_remote_code=True).to(device).eval()
17
- depth_processor = AutoImageProcessor.from_pretrained("depth-anything/Depth-Anything-V2-Small-hf")
18
- depth_model = AutoModelForDepthEstimation.from_pretrained("depth-anything/Depth-Anything-V2-Small-hf").to(device)
19
-
20
- def run_rmbg(image, threshold=0.5):
21
- image_size = (1024, 1024)
22
- transform_image = transforms.Compose([
23
- transforms.Resize(image_size),
24
- transforms.ToTensor(),
25
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
26
- ])
27
- input_images = transform_image(image).unsqueeze(0).to(device)
28
- with torch.no_grad():
29
- preds = rmbg_model(input_images)
30
- mask_logits = preds[-1]
31
- mask_prob = mask_logits.sigmoid().cpu()[0].squeeze()
32
- pred_pil = transforms.ToPILImage()(mask_prob)
33
- mask_pil = pred_pil.resize(image.size, resample=Image.BILINEAR)
34
- mask_np = np.array(mask_pil, dtype=np.uint8) / 255.0
35
- binary_mask = (mask_np > threshold).astype(np.uint8)
36
- return binary_mask
37
-
38
- def run_depth_estimation(image, target_size=(512, 512)):
39
- image_resized = image.resize(target_size, resample=Image.BILINEAR)
40
- inputs = depth_processor(images=image_resized, return_tensors="pt").to(device)
41
- with torch.no_grad():
42
- outputs = depth_model(**inputs)
43
- predicted_depth = outputs.predicted_depth
44
- prediction = torch.nn.functional.interpolate(
45
- predicted_depth.unsqueeze(1),
46
- size=image.size[::-1],
47
- mode="bicubic",
48
- align_corners=False,
49
- )
50
- depth_map = prediction.squeeze().cpu().numpy()
51
- depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
52
- return 1 - depth_map
53
-
54
- def apply_gaussian_blur(image, mask, sigma):
55
- blurred = image.filter(ImageFilter.GaussianBlur(radius=sigma))
56
- return Image.composite(image, blurred, Image.fromarray((mask * 255).astype(np.uint8)))
57
-
58
- def apply_lens_blur(image, depth_map, max_radius, foreground_percentile):
59
- foreground_threshold = np.percentile(depth_map.flatten(), foreground_percentile)
60
- output = np.array(image)
61
- for radius in np.linspace(0, max_radius, 10):
62
- mask = (depth_map > foreground_threshold + radius / max_radius * (depth_map.max() - foreground_threshold))
63
- blurred = image.filter(ImageFilter.GaussianBlur(radius=radius))
64
- output[mask] = np.array(blurred)[mask]
65
- return Image.fromarray(output)
66
-
67
- def process_image(image, blur_type, sigma, max_radius, foreground_percentile, mask_threshold):
68
- if image is None:
69
- return None, "Please upload an image."
70
-
71
- try:
72
- image = Image.fromarray(image).convert("RGB")
73
- except Exception as e:
74
- return None, f"Error processing image: {str(e)}"
75
-
76
- max_size = (1024, 1024)
77
- if image.size[0] > max_size[0] or image.size[1] > max_size[1]:
78
- image.thumbnail(max_size, Image.Resampling.LANCZOS)
79
-
80
- try:
81
- if blur_type == "Gaussian Blur":
82
- mask = run_rmbg(image, threshold=mask_threshold)
83
- output_image = apply_gaussian_blur(image, mask, sigma)
84
- else: # Lens Blur
85
- depth_map = run_depth_estimation(image)
86
- output_image = apply_lens_blur(image, depth_map, max_radius, foreground_percentile)
87
- except Exception as e:
88
- return None, f"Error applying blur: {str(e)}"
89
-
90
- # Generate debug info
91
- debug_info = f"Blur Type: {blur_type}\n"
92
  if blur_type == "Gaussian Blur":
93
- debug_info += f"Sigma: {sigma}\nMask Threshold: {mask_threshold}"
 
 
94
  else:
95
- debug_info += f"Max Radius: {max_radius}\nForeground Percentile: {foreground_percentile}"
96
-
97
- return output_image, debug_info
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
 
99
  with gr.Blocks() as demo:
100
- gr.Markdown("# Image Blur Effects with Gaussian and Lens Blur")
 
101
  with gr.Row():
102
- image_input = gr.Image(label="Upload Image", type="numpy")
103
  with gr.Column():
104
- blur_type = gr.Radio(choices=["Gaussian Blur", "Lens Blur"], label="Blur Type", value="Gaussian Blur")
105
- sigma = gr.Slider(minimum=0.1, maximum=50, step=0.1, value=15, label="Gaussian Blur Sigma")
106
- max_radius = gr.Slider(minimum=1, maximum=100, step=1, value=15, label="Max Lens Blur Radius")
107
- foreground_percentile = gr.Slider(minimum=1, maximum=99, step=1, value=30, label="Foreground Percentile")
108
- mask_threshold = gr.Slider(minimum=0.1, maximum=0.9, step=0.1, value=0.5, label="Mask Threshold")
109
-
110
- process_button = gr.Button("Apply Blur")
111
- with gr.Row():
112
- output_image = gr.Image(label="Output Image")
113
- debug_info = gr.Textbox(label="Debug Info", lines=4)
114
-
115
- def update_visibility(blur_type):
116
- if blur_type == "Gaussian Blur":
117
- return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
118
- else: # Lens Blur
119
- return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)
120
-
121
- blur_type.change(
122
- fn=update_visibility,
123
- inputs=blur_type,
124
- outputs=[sigma, max_radius, foreground_percentile, mask_threshold]
125
- )
126
-
127
- process_button.click(
128
- fn=process_image,
129
- inputs=[image_input, blur_type, sigma, max_radius, foreground_percentile, mask_threshold],
130
- outputs=[output_image, debug_info]
131
- )
132
-
133
  demo.launch()
 
4
  @author: Nikhil Kunjoor
5
  """
6
  import gradio as gr
7
+ from transformers import pipeline
8
+ from PIL import Image, ImageFilter, ImageOps
9
  import numpy as np
10
+ import requests
11
+ import cv2
12
+
13
+ # Load models once
14
+ print("Loading segmentation model...")
15
+ segmentation_model = pipeline("image-segmentation", model="nvidia/segformer-b1-finetuned-cityscapes-1024-1024")
16
+ print("Loading depth estimation model...")
17
+ depth_estimator = pipeline("depth-estimation", model="Intel/zoedepth-nyu-kitti")
18
+
19
+ def lens_blur(image, radius):
20
+ """
21
+ Apply a more realistic lens blur (bokeh effect) using OpenCV.
22
+ """
23
+ if radius < 1:
24
+ return image
25
+
26
+ # Convert PIL image to OpenCV format
27
+ img_np = np.array(image)
28
+
29
+ # Create a circular kernel for the bokeh effect
30
+ kernel_size = 2 * radius + 1
31
+ kernel = np.zeros((kernel_size, kernel_size), dtype=np.float32)
32
+ center = radius
33
+ for i in range(kernel_size):
34
+ for j in range(kernel_size):
35
+ # Create circular kernel
36
+ if np.sqrt((i - center) ** 2 + (j - center) ** 2) <= radius:
37
+ kernel[i, j] = 1.0
38
+
39
+ # Normalize the kernel
40
+ if kernel.sum() != 0:
41
+ kernel = kernel / kernel.sum()
42
+
43
+ # Apply the filter to each channel separately
44
+ channels = cv2.split(img_np)
45
+ blurred_channels = []
46
+
47
+ for channel in channels:
48
+ blurred_channel = cv2.filter2D(channel, -1, kernel)
49
+ blurred_channels.append(blurred_channel)
50
+
51
+ # Merge the channels back
52
+ blurred_img = cv2.merge(blurred_channels)
53
+
54
+ # Convert back to PIL image
55
+ return Image.fromarray(blurred_img)
56
+
57
+ def process_image(input_image, method, blur_intensity, blur_type):
58
+ """
59
+ Process the input image using one of two methods:
60
+
61
+ 1. Segmented Background Blur:
62
+ - Uses segmentation to extract a foreground mask.
63
+ - Applies the selected blur (Gaussian or Lens) to the background.
64
+ - Composites the final image.
65
+
66
+ 2. Depth-based Variable Blur:
67
+ - Uses depth estimation to generate a depth map.
68
+ - Normalizes the depth map to be used as a blending mask.
69
+ - Blends a fully blurred version (using the selected blur) with the original image.
70
+
71
+ Returns:
72
+ - output_image: final composited image.
73
+ - mask_image: the mask used (binary for segmentation, normalized depth for depth-based).
74
+ """
75
+ # Ensure image is in RGB mode
76
+ input_image = input_image.convert("RGB")
77
+
78
+ # Select blur function based on blur_type
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  if blur_type == "Gaussian Blur":
80
+ blur_fn = lambda img, rad: img.filter(ImageFilter.GaussianBlur(radius=rad))
81
+ elif blur_type == "Lens Blur":
82
+ blur_fn = lens_blur
83
  else:
84
+ blur_fn = lambda img, rad: img.filter(ImageFilter.GaussianBlur(radius=rad))
85
+
86
+ if method == "Segmented Background Blur":
87
+ # Use segmentation to obtain a foreground mask.
88
+ results = segmentation_model(input_image)
89
+ # Assume the last result is the main foreground object.
90
+ foreground_mask = results[-1]["mask"]
91
+ # Ensure the mask is grayscale.
92
+ foreground_mask = foreground_mask.convert("L")
93
+ # Threshold to create a binary mask.
94
+ binary_mask = foreground_mask.point(lambda p: 255 if p > 128 else 0)
95
+
96
+ # Blur the background using the selected blur function.
97
+ blurred_background = blur_fn(input_image, blur_intensity)
98
+
99
+ # Composite the final image: keep foreground and use blurred background elsewhere.
100
+ output_image = Image.composite(input_image, blurred_background, binary_mask)
101
+ mask_image = binary_mask
102
+
103
+ elif method == "Depth-based Variable Blur":
104
+ # Generate depth map.
105
+ depth_results = depth_estimator(input_image)
106
+ depth_map = depth_results["depth"]
107
+
108
+ # Convert depth map to numpy array and normalize to [0, 255]
109
+ depth_array = np.array(depth_map).astype(np.float32)
110
+ norm = (depth_array - depth_array.min()) / (depth_array.max() - depth_array.min() + 1e-8)
111
+ normalized_depth = (norm * 255).astype(np.uint8)
112
+ mask_image = Image.fromarray(normalized_depth)
113
+
114
+ # Create fully blurred version using the selected blur function.
115
+ blurred_image = blur_fn(input_image, blur_intensity)
116
+
117
+ # Convert images to arrays for blending.
118
+ orig_np = np.array(input_image).astype(np.float32)
119
+ blur_np = np.array(blurred_image).astype(np.float32)
120
+ # Reshape mask for broadcasting.
121
+ alpha = normalized_depth[..., np.newaxis] / 255.0
122
+
123
+ # Blend pixels: 0 = original; 1 = fully blurred.
124
+ blended_np = (1 - alpha) * orig_np + alpha * blur_np
125
+ blended_np = np.clip(blended_np, 0, 255).astype(np.uint8)
126
+ output_image = Image.fromarray(blended_np)
127
+
128
+ else:
129
+ output_image = input_image
130
+ mask_image = input_image.convert("L")
131
+
132
+ return output_image, mask_image
133
 
134
+ # Build a Gradio interface
135
  with gr.Blocks() as demo:
136
+ gr.Markdown("## Image Processing App: Segmentation & Depth-based Blur")
137
+
138
  with gr.Row():
 
139
  with gr.Column():
140
+ input_image = gr.Image(label="Input Image", type="pil")
141
+ method = gr.Radio(label="Processing Method",
142
+ choices=["Segmented Background Blur", "Depth-based Variable Blur"],
143
+ value="Segmented Background Blur")
144
+ blur_intensity = gr.Slider(label="Blur Intensity (Maximum Blur Radius)", minimum=1, maximum=30, step=1, value=15)
145
+ blur_type = gr.Dropdown(label="Blur Type", choices=["Gaussian Blur", "Lens Blur"], value="Gaussian Blur")
146
+ run_button = gr.Button("Process Image")
147
+ with gr.Column():
148
+ output_image = gr.Image(label="Output Image")
149
+ mask_output = gr.Image(label="Mask")
150
+
151
+ run_button.click(fn=process_image,
152
+ inputs=[input_image, method, blur_intensity, blur_type],
153
+ outputs=[output_image, mask_output])
154
+
155
+ # Launch the app
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  demo.launch()