kvinod15 committed on
Commit
ae8d774
·
verified ·
1 Parent(s): a390814

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -60
app.py CHANGED
@@ -10,7 +10,6 @@ from transformers import AutoModelForImageSegmentation, pipeline
10
  # Global Setup and Model Loading
11
  # ----------------------------
12
 
13
- # Set device (GPU if available, else CPU)
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
 
16
  # Load the segmentation model (RMBG-2.0)
@@ -21,7 +20,7 @@ segmentation_model = AutoModelForImageSegmentation.from_pretrained(
21
  segmentation_model.to(device)
22
  segmentation_model.eval()
23
 
24
- # Define the image transformation for segmentation (resize to 512x512)
25
  image_size = (512, 512)
26
  segmentation_transform = transforms.Compose([
27
  transforms.Resize(image_size),
@@ -36,122 +35,120 @@ depth_pipeline = pipeline("depth-estimation", model="depth-anything/Depth-Anythi
36
  # Processing Functions
37
  # ----------------------------
38
 
39
- def segment_and_blur_background(input_image: Image.Image, blur_radius: int = 15, threshold: float = 0.5) -> Image.Image:
40
  """
41
- Uses the RMBG-2.0 segmentation model to create a binary mask,
42
- then composites a Gaussian-blurred background with the sharp foreground.
43
- The segmentation threshold is adjustable.
44
  """
45
- # Ensure the image is in RGB and get its original dimensions
46
  image = input_image.convert("RGB")
47
  orig_width, orig_height = image.size
48
 
49
- # Preprocess image for segmentation
50
  input_tensor = segmentation_transform(image).unsqueeze(0).to(device)
51
 
52
- # Run inference on the segmentation model
53
  with torch.no_grad():
54
  preds = segmentation_model(input_tensor)[-1].sigmoid().cpu()
55
  pred = preds[0].squeeze()
56
 
57
- # Create a binary mask using the adjustable threshold
58
  binary_mask = (pred > threshold).float()
59
  mask_pil = transforms.ToPILImage()(binary_mask).convert("L")
60
- # Convert grayscale mask to pure binary (0 or 255)
61
  mask_pil = mask_pil.point(lambda p: 255 if p > 128 else 0)
62
- # Resize mask back to the original image dimensions
63
  mask_pil = mask_pil.resize((orig_width, orig_height), resample=Image.BILINEAR)
64
 
65
- # Apply Gaussian blur to the entire image for background
66
- blurred_image = image.filter(ImageFilter.GaussianBlur(blur_radius))
67
- # Composite the original image (foreground) with the blurred background using the mask
68
  final_image = Image.composite(image, blurred_image, mask_pil)
69
  return final_image
70
 
71
-
72
  def depth_based_lens_blur(input_image: Image.Image, max_blur: float = 2, num_bands: int = 40, invert_depth: bool = False) -> Image.Image:
73
  """
74
- Applies a depth-based blur effect using a depth map from Depth-Anything.
75
- The max_blur parameter (controlled by a slider) sets the highest blur intensity.
 
76
  """
77
- # Resize the input image to 512x512 for the depth estimation model
78
- image_resized = input_image.resize((512, 512))
79
 
80
- # Run depth estimation to obtain the depth map (as a PIL image)
81
- results = depth_pipeline(image_resized)
82
  depth_map_image = results['depth']
83
 
84
- # Convert the depth map to a NumPy array and normalize to [0, 1]
85
  depth_array = np.array(depth_map_image, dtype=np.float32)
86
  d_min, d_max = depth_array.min(), depth_array.max()
87
  depth_norm = (depth_array - d_min) / (d_max - d_min + 1e-8)
88
  if invert_depth:
89
  depth_norm = 1.0 - depth_norm
90
 
91
- # Convert the resized image to RGBA for compositing
92
- orig_rgba = image_resized.convert("RGBA")
93
  final_image = orig_rgba.copy()
94
 
95
- # Divide the normalized depth range into bands and apply variable blur
96
  band_edges = np.linspace(0, 1, num_bands + 1)
97
  for i in range(num_bands):
98
  band_min = band_edges[i]
99
  band_max = band_edges[i + 1]
100
- # Use the midpoint of the band to determine the blur strength.
101
  mid = (band_min + band_max) / 2.0
102
  blur_radius_band = (1 - mid) * max_blur
103
 
104
- # Create a blurred version of the image for this band.
105
  blurred_version = orig_rgba.filter(ImageFilter.GaussianBlur(blur_radius_band))
106
-
107
- # Create a mask for pixels whose normalized depth falls within this band.
108
  band_mask = ((depth_norm >= band_min) & (depth_norm < band_max)).astype(np.uint8) * 255
109
  band_mask_pil = Image.fromarray(band_mask, mode="L")
110
-
111
- # Composite the blurred version with the current final image using the band mask.
112
  final_image = Image.composite(blurred_version, final_image, band_mask_pil)
113
 
114
- # Return the final composited image as RGB.
115
  return final_image.convert("RGB")
116
 
117
-
118
- def process_image(input_image: Image.Image, effect: str, threshold: float, blur_intensity: float) -> Image.Image:
119
  """
120
- Dispatch function to apply the selected effect:
121
- - "Gaussian Blur Background": uses segmentation with an adjustable threshold and blur radius.
122
  - "Depth-based Lens Blur": applies depth-based blur with an adjustable maximum blur.
123
- The threshold slider is used only for the segmentation effect.
124
- The blur_intensity slider controls the blur strength in both effects.
125
  """
126
  if effect == "Gaussian Blur Background":
127
- # For segmentation, use the threshold and blur_intensity (as blur_radius)
128
- return segment_and_blur_background(input_image, blur_radius=int(blur_intensity), threshold=threshold)
129
  elif effect == "Depth-based Lens Blur":
130
- # For depth-based blur, use the blur_intensity as the max blur value.
131
- return depth_based_lens_blur(input_image, max_blur=blur_intensity)
132
  else:
133
  return input_image
134
 
135
-
136
  # ----------------------------
137
- # Gradio Interface
138
  # ----------------------------
139
 
140
- iface = gr.Interface(
141
- fn=process_image,
142
- inputs=[
143
- gr.Image(type="pil", label="Input Image"),
144
- gr.Radio(choices=["Gaussian Blur Background", "Depth-based Lens Blur"], label="Select Effect"),
145
- gr.Slider(0.0, 1.0, value=0.5, label="Segmentation Threshold (for Gaussian Blur)"),
146
- gr.Slider(0, 30, value=15, step=1, label="Blur Intensity (for both effects)")
147
- ],
148
- outputs=gr.Image(type="pil", label="Output Image"),
149
- title="Interactive Blur Effects Demo",
150
- description=(
151
- "Upload an image and choose an effect. For 'Gaussian Blur Background', adjust the segmentation threshold and blur intensity. "
152
- "For 'Depth-based Lens Blur', the blur intensity slider sets the maximum blur based on depth."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  )
154
- )
155
 
156
  if __name__ == "__main__":
157
- iface.launch()
 
10
  # Global Setup and Model Loading
11
  # ----------------------------
12
 
 
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
 
15
  # Load the segmentation model (RMBG-2.0)
 
20
  segmentation_model.to(device)
21
  segmentation_model.eval()
22
 
23
+ # Transformation for segmentation (resizes to 512 for the model input)
24
  image_size = (512, 512)
25
  segmentation_transform = transforms.Compose([
26
  transforms.Resize(image_size),
 
35
  # Processing Functions
36
  # ----------------------------
37
 
38
+ def segment_and_blur_background(input_image: Image.Image, blur_strength: int = 15, threshold: float = 0.5) -> Image.Image:
39
  """
40
+ Applies segmentation using the RMBG-2.0 model and composites the original image with
41
+ a Gaussian-blurred background based on an adjustable mask sensitivity threshold.
 
42
  """
 
43
  image = input_image.convert("RGB")
44
  orig_width, orig_height = image.size
45
 
46
+ # Preprocess image for segmentation (resize only for model inference)
47
  input_tensor = segmentation_transform(image).unsqueeze(0).to(device)
48
 
 
49
  with torch.no_grad():
50
  preds = segmentation_model(input_tensor)[-1].sigmoid().cpu()
51
  pred = preds[0].squeeze()
52
 
53
+ # Create binary mask with adjustable threshold (mask sensitivity)
54
  binary_mask = (pred > threshold).float()
55
  mask_pil = transforms.ToPILImage()(binary_mask).convert("L")
 
56
  mask_pil = mask_pil.point(lambda p: 255 if p > 128 else 0)
 
57
  mask_pil = mask_pil.resize((orig_width, orig_height), resample=Image.BILINEAR)
58
 
59
+ blurred_image = image.filter(ImageFilter.GaussianBlur(blur_strength))
 
 
60
  final_image = Image.composite(image, blurred_image, mask_pil)
61
  return final_image
62
 
 
63
  def depth_based_lens_blur(input_image: Image.Image, max_blur: float = 2, num_bands: int = 40, invert_depth: bool = False) -> Image.Image:
64
  """
65
+ Applies a depth-based blur effect using a depth map produced by Depth-Anything.
66
+ The effect simulates a lens blur where the max_blur parameter controls the maximum blur.
67
+ This function uses the original input image size.
68
  """
69
+ # Use the original image for depth estimation (no resizing)
70
+ image_original = input_image.convert("RGB")
71
 
72
+ # Obtain depth map using the pipeline (assumes model accepts variable sizes)
73
+ results = depth_pipeline(image_original)
74
  depth_map_image = results['depth']
75
 
 
76
  depth_array = np.array(depth_map_image, dtype=np.float32)
77
  d_min, d_max = depth_array.min(), depth_array.max()
78
  depth_norm = (depth_array - d_min) / (d_max - d_min + 1e-8)
79
  if invert_depth:
80
  depth_norm = 1.0 - depth_norm
81
 
82
+ orig_rgba = image_original.convert("RGBA")
 
83
  final_image = orig_rgba.copy()
84
 
 
85
  band_edges = np.linspace(0, 1, num_bands + 1)
86
  for i in range(num_bands):
87
  band_min = band_edges[i]
88
  band_max = band_edges[i + 1]
 
89
  mid = (band_min + band_max) / 2.0
90
  blur_radius_band = (1 - mid) * max_blur
91
 
 
92
  blurred_version = orig_rgba.filter(ImageFilter.GaussianBlur(blur_radius_band))
 
 
93
  band_mask = ((depth_norm >= band_min) & (depth_norm < band_max)).astype(np.uint8) * 255
94
  band_mask_pil = Image.fromarray(band_mask, mode="L")
 
 
95
  final_image = Image.composite(blurred_version, final_image, band_mask_pil)
96
 
 
97
  return final_image.convert("RGB")
98
 
99
+ def process_image(input_image: Image.Image, effect: str, mask_sensitivity: float, blur_strength: float) -> Image.Image:
 
100
  """
101
+ Applies the selected effect:
102
+ - "Gaussian Blur Background": uses segmentation with adjustable mask sensitivity and blur strength.
103
  - "Depth-based Lens Blur": applies depth-based blur with an adjustable maximum blur.
 
 
104
  """
105
  if effect == "Gaussian Blur Background":
106
+ return segment_and_blur_background(input_image, blur_strength=int(blur_strength), threshold=mask_sensitivity)
 
107
  elif effect == "Depth-based Lens Blur":
108
+ return depth_based_lens_blur(input_image, max_blur=blur_strength)
 
109
  else:
110
  return input_image
111
 
 
112
  # ----------------------------
113
+ # Gradio Blocks Layout
114
  # ----------------------------
115
 
116
+ with gr.Blocks(title="Interactive Blur Effects Demo") as demo:
117
+ gr.Markdown(
118
+ """
119
+ # Interactive Blur Effects Demo
120
+ Upload an image and choose an effect below.
121
+ For **Gaussian Blur Background**, adjust the mask sensitivity (controls segmentation threshold)
122
+ and blur strength (controls Gaussian blur radius).
123
+ For **Depth-based Lens Blur**, the blur strength slider sets the maximum blur intensity.
124
+ """
125
+ )
126
+
127
+ with gr.Row():
128
+ with gr.Column():
129
+ input_image = gr.Image(type="pil", label="Input Image")
130
+ effect_choice = gr.Radio(
131
+ choices=["Gaussian Blur Background", "Depth-based Lens Blur"],
132
+ label="Select Effect",
133
+ value="Gaussian Blur Background"
134
+ )
135
+ mask_sensitivity_slider = gr.Slider(
136
+ minimum=0.0, maximum=1.0, value=0.5, step=0.01,
137
+ label="Mask Sensitivity (for segmentation)"
138
+ )
139
+ blur_strength_slider = gr.Slider(
140
+ minimum=0, maximum=30, value=15, step=1,
141
+ label="Blur Strength"
142
+ )
143
+ run_button = gr.Button("Apply Effect")
144
+ with gr.Column():
145
+ output_image = gr.Image(type="pil", label="Output Image")
146
+
147
+ run_button.click(
148
+ fn=process_image,
149
+ inputs=[input_image, effect_choice, mask_sensitivity_slider, blur_strength_slider],
150
+ outputs=output_image
151
  )
 
152
 
153
  if __name__ == "__main__":
154
+ demo.launch()