kvinod15 commited on
Commit
e86315f
·
verified ·
1 Parent(s): ae8d774

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -57
app.py CHANGED
@@ -10,6 +10,7 @@ from transformers import AutoModelForImageSegmentation, pipeline
10
  # Global Setup and Model Loading
11
  # ----------------------------
12
 
 
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
 
15
  # Load the segmentation model (RMBG-2.0)
@@ -20,7 +21,7 @@ segmentation_model = AutoModelForImageSegmentation.from_pretrained(
20
  segmentation_model.to(device)
21
  segmentation_model.eval()
22
 
23
- # Transformation for segmentation (resizes to 512 for the model input)
24
  image_size = (512, 512)
25
  segmentation_transform = transforms.Compose([
26
  transforms.Resize(image_size),
@@ -35,120 +36,119 @@ depth_pipeline = pipeline("depth-estimation", model="depth-anything/Depth-Anythi
35
  # Processing Functions
36
  # ----------------------------
37
 
38
- def segment_and_blur_background(input_image: Image.Image, blur_strength: int = 15, threshold: float = 0.5) -> Image.Image:
39
  """
40
- Applies segmentation using the RMBG-2.0 model and composites the original image with
41
- a Gaussian-blurred background based on an adjustable mask sensitivity threshold.
 
42
  """
 
43
  image = input_image.convert("RGB")
44
  orig_width, orig_height = image.size
45
 
46
- # Preprocess image for segmentation (resize only for model inference)
47
  input_tensor = segmentation_transform(image).unsqueeze(0).to(device)
48
 
 
49
  with torch.no_grad():
50
  preds = segmentation_model(input_tensor)[-1].sigmoid().cpu()
51
  pred = preds[0].squeeze()
52
 
53
- # Create binary mask with adjustable threshold (mask sensitivity)
54
  binary_mask = (pred > threshold).float()
55
  mask_pil = transforms.ToPILImage()(binary_mask).convert("L")
 
56
  mask_pil = mask_pil.point(lambda p: 255 if p > 128 else 0)
 
57
  mask_pil = mask_pil.resize((orig_width, orig_height), resample=Image.BILINEAR)
58
 
59
- blurred_image = image.filter(ImageFilter.GaussianBlur(blur_strength))
 
 
60
  final_image = Image.composite(image, blurred_image, mask_pil)
61
  return final_image
62
 
63
  def depth_based_lens_blur(input_image: Image.Image, max_blur: float = 2, num_bands: int = 40, invert_depth: bool = False) -> Image.Image:
64
  """
65
- Applies a depth-based blur effect using a depth map produced by Depth-Anything.
66
- The effect simulates a lens blur where the max_blur parameter controls the maximum blur.
67
- This function uses the original input image size.
68
  """
69
- # Use the original image for depth estimation (no resizing)
70
- image_original = input_image.convert("RGB")
71
 
72
- # Obtain depth map using the pipeline (assumes model accepts variable sizes)
73
- results = depth_pipeline(image_original)
74
  depth_map_image = results['depth']
75
 
 
76
  depth_array = np.array(depth_map_image, dtype=np.float32)
77
  d_min, d_max = depth_array.min(), depth_array.max()
78
  depth_norm = (depth_array - d_min) / (d_max - d_min + 1e-8)
79
  if invert_depth:
80
  depth_norm = 1.0 - depth_norm
81
 
82
- orig_rgba = image_original.convert("RGBA")
 
83
  final_image = orig_rgba.copy()
84
 
 
85
  band_edges = np.linspace(0, 1, num_bands + 1)
86
  for i in range(num_bands):
87
  band_min = band_edges[i]
88
  band_max = band_edges[i + 1]
 
89
  mid = (band_min + band_max) / 2.0
90
  blur_radius_band = (1 - mid) * max_blur
91
 
 
92
  blurred_version = orig_rgba.filter(ImageFilter.GaussianBlur(blur_radius_band))
 
 
93
  band_mask = ((depth_norm >= band_min) & (depth_norm < band_max)).astype(np.uint8) * 255
94
  band_mask_pil = Image.fromarray(band_mask, mode="L")
 
 
95
  final_image = Image.composite(blurred_version, final_image, band_mask_pil)
96
 
 
97
  return final_image.convert("RGB")
98
 
99
- def process_image(input_image: Image.Image, effect: str, mask_sensitivity: float, blur_strength: float) -> Image.Image:
100
  """
101
- Applies the selected effect:
102
- - "Gaussian Blur Background": uses segmentation with adjustable mask sensitivity and blur strength.
103
  - "Depth-based Lens Blur": applies depth-based blur with an adjustable maximum blur.
 
 
104
  """
105
  if effect == "Gaussian Blur Background":
106
- return segment_and_blur_background(input_image, blur_strength=int(blur_strength), threshold=mask_sensitivity)
 
107
  elif effect == "Depth-based Lens Blur":
108
- return depth_based_lens_blur(input_image, max_blur=blur_strength)
 
109
  else:
110
  return input_image
111
 
112
  # ----------------------------
113
- # Gradio Blocks Layout
114
  # ----------------------------
115
 
116
- with gr.Blocks(title="Interactive Blur Effects Demo") as demo:
117
- gr.Markdown(
118
- """
119
- # Interactive Blur Effects Demo
120
- Upload an image and choose an effect below.
121
- For **Gaussian Blur Background**, adjust the mask sensitivity (controls segmentation threshold)
122
- and blur strength (controls Gaussian blur radius).
123
- For **Depth-based Lens Blur**, the blur strength slider sets the maximum blur intensity.
124
- """
125
- )
126
-
127
- with gr.Row():
128
- with gr.Column():
129
- input_image = gr.Image(type="pil", label="Input Image")
130
- effect_choice = gr.Radio(
131
- choices=["Gaussian Blur Background", "Depth-based Lens Blur"],
132
- label="Select Effect",
133
- value="Gaussian Blur Background"
134
- )
135
- mask_sensitivity_slider = gr.Slider(
136
- minimum=0.0, maximum=1.0, value=0.5, step=0.01,
137
- label="Mask Sensitivity (for segmentation)"
138
- )
139
- blur_strength_slider = gr.Slider(
140
- minimum=0, maximum=30, value=15, step=1,
141
- label="Blur Strength"
142
- )
143
- run_button = gr.Button("Apply Effect")
144
- with gr.Column():
145
- output_image = gr.Image(type="pil", label="Output Image")
146
-
147
- run_button.click(
148
- fn=process_image,
149
- inputs=[input_image, effect_choice, mask_sensitivity_slider, blur_strength_slider],
150
- outputs=output_image
151
  )
 
152
 
153
  if __name__ == "__main__":
154
- demo.launch()
 
10
  # Global Setup and Model Loading
11
  # ----------------------------
12
 
13
+ # Set device (GPU if available, else CPU)
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
 
16
  # Load the segmentation model (RMBG-2.0)
 
21
  segmentation_model.to(device)
22
  segmentation_model.eval()
23
 
24
+ # Define the image transformation for segmentation (resize to 512x512, then normalize)
25
  image_size = (512, 512)
26
  segmentation_transform = transforms.Compose([
27
  transforms.Resize(image_size),
 
36
  # Processing Functions
37
  # ----------------------------
38
 
39
+ def segment_and_blur_background(input_image: Image.Image, blur_radius: int = 15, threshold: float = 0.5) -> Image.Image:
40
  """
41
+ Uses the RMBG-2.0 segmentation model to create a binary mask,
42
+ then composites a Gaussian-blurred background with the sharp foreground.
43
+ The segmentation threshold is adjustable.
44
  """
45
+ # Ensure the image is in RGB and get its original dimensions
46
  image = input_image.convert("RGB")
47
  orig_width, orig_height = image.size
48
 
49
+ # Preprocess image for segmentation
50
  input_tensor = segmentation_transform(image).unsqueeze(0).to(device)
51
 
52
+ # Run inference on the segmentation model
53
  with torch.no_grad():
54
  preds = segmentation_model(input_tensor)[-1].sigmoid().cpu()
55
  pred = preds[0].squeeze()
56
 
57
+ # Create a binary mask using the adjustable threshold
58
  binary_mask = (pred > threshold).float()
59
  mask_pil = transforms.ToPILImage()(binary_mask).convert("L")
60
+ # Convert grayscale mask to pure binary (0 or 255)
61
  mask_pil = mask_pil.point(lambda p: 255 if p > 128 else 0)
62
+ # Resize mask back to the original image dimensions
63
  mask_pil = mask_pil.resize((orig_width, orig_height), resample=Image.BILINEAR)
64
 
65
+ # Apply Gaussian blur to the entire image for background
66
+ blurred_image = image.filter(ImageFilter.GaussianBlur(blur_radius))
67
+ # Composite the original image (foreground) with the blurred background using the mask
68
  final_image = Image.composite(image, blurred_image, mask_pil)
69
  return final_image
70
 
71
  def depth_based_lens_blur(input_image: Image.Image, max_blur: float = 2, num_bands: int = 40, invert_depth: bool = False) -> Image.Image:
72
  """
73
+ Applies a depth-based blur effect using a depth map from Depth-Anything.
74
+ The max_blur parameter (controlled by a slider) sets the highest blur intensity.
 
75
  """
76
+ # Resize the input image to 512x512 for the depth estimation model
77
+ image_resized = input_image.resize((512, 512))
78
 
79
+ # Run depth estimation to obtain the depth map (as a PIL image)
80
+ results = depth_pipeline(image_resized)
81
  depth_map_image = results['depth']
82
 
83
+ # Convert the depth map to a NumPy array and normalize to [0, 1]
84
  depth_array = np.array(depth_map_image, dtype=np.float32)
85
  d_min, d_max = depth_array.min(), depth_array.max()
86
  depth_norm = (depth_array - d_min) / (d_max - d_min + 1e-8)
87
  if invert_depth:
88
  depth_norm = 1.0 - depth_norm
89
 
90
+ # Convert the resized image to RGBA for compositing
91
+ orig_rgba = image_resized.convert("RGBA")
92
  final_image = orig_rgba.copy()
93
 
94
+ # Divide the normalized depth range into bands and apply variable blur
95
  band_edges = np.linspace(0, 1, num_bands + 1)
96
  for i in range(num_bands):
97
  band_min = band_edges[i]
98
  band_max = band_edges[i + 1]
99
+ # Use the midpoint of the band to determine the blur strength.
100
  mid = (band_min + band_max) / 2.0
101
  blur_radius_band = (1 - mid) * max_blur
102
 
103
+ # Create a blurred version of the image for this band.
104
  blurred_version = orig_rgba.filter(ImageFilter.GaussianBlur(blur_radius_band))
105
+
106
+ # Create a mask for pixels whose normalized depth falls within this band.
107
  band_mask = ((depth_norm >= band_min) & (depth_norm < band_max)).astype(np.uint8) * 255
108
  band_mask_pil = Image.fromarray(band_mask, mode="L")
109
+
110
+ # Composite the blurred version with the current final image using the band mask.
111
  final_image = Image.composite(blurred_version, final_image, band_mask_pil)
112
 
113
+ # Return the final composited image as RGB.
114
  return final_image.convert("RGB")
115
 
116
+ def process_image(input_image: Image.Image, effect: str, threshold: float, blur_intensity: float) -> Image.Image:
117
  """
118
+ Dispatch function to apply the selected effect:
119
+ - "Gaussian Blur Background": uses segmentation with an adjustable threshold and blur radius.
120
  - "Depth-based Lens Blur": applies depth-based blur with an adjustable maximum blur.
121
+ The threshold slider is used only for the segmentation effect.
122
+ The blur_intensity slider controls the blur strength in both effects.
123
  """
124
  if effect == "Gaussian Blur Background":
125
+ # For segmentation, use the threshold and blur_intensity (as blur_radius)
126
+ return segment_and_blur_background(input_image, blur_radius=int(blur_intensity), threshold=threshold)
127
  elif effect == "Depth-based Lens Blur":
128
+ # For depth-based blur, use the blur_intensity as the max blur value.
129
+ return depth_based_lens_blur(input_image, max_blur=blur_intensity)
130
  else:
131
  return input_image
132
 
133
  # ----------------------------
134
+ # Gradio Interface
135
  # ----------------------------
136
 
137
+ iface = gr.Interface(
138
+ fn=process_image,
139
+ inputs=[
140
+ gr.Image(type="pil", label="Input Image"),
141
+ gr.Radio(choices=["Gaussian Blur Background", "Depth-based Lens Blur"], label="Select Effect"),
142
+ gr.Slider(0.0, 1.0, value=0.5, label="Segmentation Threshold (for Gaussian Blur)"),
143
+ gr.Slider(0, 30, value=15, step=1, label="Blur Intensity (for both effects)")
144
+ ],
145
+ outputs=gr.Image(type="pil", label="Output Image"),
146
+ title="Interactive Blur Effects Demo",
147
+ description=(
148
+ "Upload an image and choose an effect. For 'Gaussian Blur Background', adjust the segmentation threshold and blur intensity. "
149
+ "For 'Depth-based Lens Blur', the blur intensity slider sets the maximum blur based on depth."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  )
151
+ )
152
 
153
  if __name__ == "__main__":
154
+ iface.launch()