nickkun commited on
Commit
8cfd312
·
verified ·
1 Parent(s): 69db8f3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +121 -66
app.py CHANGED
@@ -4,75 +4,130 @@
4
  @author: Nikhil Kunjoor
5
  """
6
  import gradio as gr
7
- from transformers import pipeline
8
- from PIL import Image, ImageFilter
9
  import numpy as np
 
10
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- # Load models from Hugging Face
13
- segmentation_model = pipeline("image-segmentation", model="nvidia/segformer-b1-finetuned-cityscapes-1024-1024")
14
- depth_estimator = pipeline("depth-estimation", model="Intel/dpt-large")
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  def apply_gaussian_blur(image, mask, sigma):
17
- blurred = image.filter(ImageFilter.GaussianBlur(sigma))
18
- return Image.composite(image, blurred, mask)
19
-
20
- def apply_lens_blur(image, depth_map, sigma):
21
- depth_array = np.array(depth_map)
22
- normalized_depth = (depth_array - np.min(depth_array)) / (np.max(depth_array) - np.min(depth_array))
23
-
24
- blurred = image.copy()
25
- for x in range(image.width):
26
- for y in range(image.height):
27
- blur_intensity = normalized_depth[y, x] * sigma
28
- local_blur = image.crop((x-1, y-1, x+2, y+2)).filter(ImageFilter.GaussianBlur(blur_intensity))
29
- blurred.putpixel((x, y), local_blur.getpixel((1, 1)))
30
- return blurred
31
-
32
- def process_image(image, blur_type, sigma):
33
- # Perform segmentation
34
- segmentation_results = segmentation_model(image)
35
- person_mask = None
36
- for segment in segmentation_results:
37
- if segment['label'] == 'person':
38
- person_mask = Image.fromarray((segment['mask'] * 255).astype(np.uint8))
39
- break
40
-
41
- if person_mask is None:
42
- person_mask = Image.new('L', image.size, 255) # Create a white mask if no person is detected
43
-
44
- # Perform depth estimation
45
- depth_results = depth_estimator(image)
46
- depth_map = depth_results["depth"]
47
-
48
- # Normalize depth map for visualization
49
- depth_array = np.array(depth_map)
50
- normalized_depth = (depth_array - np.min(depth_array)) / (np.max(depth_array) - np.min(depth_array)) * 255
51
- depth_visualization = Image.fromarray(normalized_depth.astype(np.uint8))
52
-
53
- # Apply selected blur effect
54
  if blur_type == "Gaussian Blur":
55
- output_image = apply_gaussian_blur(image, person_mask, sigma)
56
- else: # Lens Blur
57
- output_image = apply_lens_blur(image, depth_map, sigma)
58
-
59
- return person_mask, depth_visualization, output_image
60
-
61
- # Create Gradio interface
62
- iface = gr.Interface(
63
- fn=process_image,
64
- inputs=[
65
- gr.Image(type="pil", label="Upload Image"),
66
- gr.Radio(["Gaussian Blur", "Lens Blur"], label="Blur Type", value="Gaussian Blur"),
67
- gr.Slider(0, 50, step=1, label="Blur Intensity (Sigma)", value=15)
68
- ],
69
- outputs=[
70
- gr.Image(type="pil", label="Segmentation Mask"),
71
- gr.Image(type="pil", label="Depth Map"),
72
- gr.Image(type="pil", label="Output Image")
73
- ],
74
- title="Vision Transformer Segmentation & Depth-Based Blur Effects",
75
- description="Upload an image to apply segmentation and lens blur effects. Adjust the blur type and intensity using the controls below."
76
- )
77
-
78
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  @author: Nikhil Kunjoor
5
  """
6
  import gradio as gr
 
 
7
  import numpy as np
8
+ from PIL import Image, ImageFilter
9
  import torch
10
+ from torchvision import transforms
11
+ from transformers import AutoModelForImageSegmentation, AutoImageProcessor, AutoModelForDepthEstimation
12
+
13
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
14
+ torch.set_float32_matmul_precision('high')
15
+
16
+ rmbg_model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-2.0", trust_remote_code=True).to(device).eval()
17
+ depth_processor = AutoImageProcessor.from_pretrained("depth-anything/Depth-Anything-V2-Small-hf")
18
+ depth_model = AutoModelForDepthEstimation.from_pretrained("depth-anything/Depth-Anything-V2-Small-hf").to(device)
19
+
20
+ def run_rmbg(image, threshold=0.5):
21
+ image_size = (1024, 1024)
22
+ transform_image = transforms.Compose([
23
+ transforms.Resize(image_size),
24
+ transforms.ToTensor(),
25
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
26
+ ])
27
+ input_images = transform_image(image).unsqueeze(0).to(device)
28
+ with torch.no_grad():
29
+ preds = rmbg_model(input_images)
30
+ mask_logits = preds[-1]
31
+ mask_prob = mask_logits.sigmoid().cpu()[0].squeeze()
32
+ pred_pil = transforms.ToPILImage()(mask_prob)
33
+ mask_pil = pred_pil.resize(image.size, resample=Image.BILINEAR)
34
+ mask_np = np.array(mask_pil, dtype=np.uint8) / 255.0
35
+ binary_mask = (mask_np > threshold).astype(np.uint8)
36
+ return binary_mask
37
 
38
+ def run_depth_estimation(image, target_size=(512, 512)):
39
+ image_resized = image.resize(target_size, resample=Image.BILINEAR)
40
+ inputs = depth_processor(images=image_resized, return_tensors="pt").to(device)
41
+ with torch.no_grad():
42
+ outputs = depth_model(**inputs)
43
+ predicted_depth = outputs.predicted_depth
44
+ prediction = torch.nn.functional.interpolate(
45
+ predicted_depth.unsqueeze(1),
46
+ size=image.size[::-1],
47
+ mode="bicubic",
48
+ align_corners=False,
49
+ )
50
+ depth_map = prediction.squeeze().cpu().numpy()
51
+ depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
52
+ return 1 - depth_map
53
 
54
  def apply_gaussian_blur(image, mask, sigma):
55
+ blurred = image.filter(ImageFilter.GaussianBlur(radius=sigma))
56
+ return Image.composite(image, blurred, Image.fromarray((mask * 255).astype(np.uint8)))
57
+
58
+ def apply_lens_blur(image, depth_map, max_radius, foreground_percentile):
59
+ foreground_threshold = np.percentile(depth_map.flatten(), foreground_percentile)
60
+ output = np.array(image)
61
+ for radius in np.linspace(0, max_radius, 10):
62
+ mask = (depth_map > foreground_threshold + radius / max_radius * (depth_map.max() - foreground_threshold))
63
+ blurred = image.filter(ImageFilter.GaussianBlur(radius=radius))
64
+ output[mask] = np.array(blurred)[mask]
65
+ return Image.fromarray(output)
66
+
67
+ def process_image(image, blur_type, sigma, max_radius, foreground_percentile, mask_threshold):
68
+ if image is None:
69
+ return None, "Please upload an image."
70
+
71
+ try:
72
+ image = Image.fromarray(image).convert("RGB")
73
+ except Exception as e:
74
+ return None, f"Error processing image: {str(e)}"
75
+
76
+ max_size = (1024, 1024)
77
+ if image.size[0] > max_size[0] or image.size[1] > max_size[1]:
78
+ image.thumbnail(max_size, Image.Resampling.LANCZOS)
79
+
80
+ try:
81
+ if blur_type == "Gaussian Blur":
82
+ mask = run_rmbg(image, threshold=mask_threshold)
83
+ output_image = apply_gaussian_blur(image, mask, sigma)
84
+ else: # Lens Blur
85
+ depth_map = run_depth_estimation(image)
86
+ output_image = apply_lens_blur(image, depth_map, max_radius, foreground_percentile)
87
+ except Exception as e:
88
+ return None, f"Error applying blur: {str(e)}"
89
+
90
+ # Generate debug info
91
+ debug_info = f"Blur Type: {blur_type}\n"
92
  if blur_type == "Gaussian Blur":
93
+ debug_info += f"Sigma: {sigma}\nMask Threshold: {mask_threshold}"
94
+ else:
95
+ debug_info += f"Max Radius: {max_radius}\nForeground Percentile: {foreground_percentile}"
96
+
97
+ return output_image, debug_info
98
+
99
+ with gr.Blocks() as demo:
100
+ gr.Markdown("# Image Blur Effects with Gaussian and Lens Blur")
101
+ with gr.Row():
102
+ image_input = gr.Image(label="Upload Image", type="numpy")
103
+ with gr.Column():
104
+ blur_type = gr.Radio(choices=["Gaussian Blur", "Lens Blur"], label="Blur Type", value="Gaussian Blur")
105
+ sigma = gr.Slider(minimum=0.1, maximum=50, step=0.1, value=15, label="Gaussian Blur Sigma")
106
+ max_radius = gr.Slider(minimum=1, maximum=100, step=1, value=15, label="Max Lens Blur Radius")
107
+ foreground_percentile = gr.Slider(minimum=1, maximum=99, step=1, value=30, label="Foreground Percentile")
108
+ mask_threshold = gr.Slider(minimum=0.1, maximum=0.9, step=0.1, value=0.5, label="Mask Threshold")
109
+
110
+ process_button = gr.Button("Apply Blur")
111
+ with gr.Row():
112
+ output_image = gr.Image(label="Output Image")
113
+ debug_info = gr.Textbox(label="Debug Info", lines=4)
114
+
115
+ def update_visibility(blur_type):
116
+ if blur_type == "Gaussian Blur":
117
+ return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
118
+ else: # Lens Blur
119
+ return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)
120
+
121
+ blur_type.change(
122
+ fn=update_visibility,
123
+ inputs=blur_type,
124
+ outputs=[sigma, max_radius, foreground_percentile, mask_threshold]
125
+ )
126
+
127
+ process_button.click(
128
+ fn=process_image,
129
+ inputs=[image_input, blur_type, sigma, max_radius, foreground_percentile, mask_threshold],
130
+ outputs=[output_image, debug_info]
131
+ )
132
+
133
+ demo.launch()