# EEE515_Problem2 / app.py
# Last commit by joeWabbit: 9bcecd2 "Update app.py" (verified), 5.38 kB.
from transformers import pipeline
from PIL import Image, ImageFilter
import gradio as gr
import torch
import numpy as np
# --- Depth-Based Blur using a Pipeline ---
# Depth Anything V2 (small checkpoint) wrapped in a transformers pipeline. It is
# constructed once, at import time, and shared by every request below.
DEPTH_MODEL_ID = "depth-anything/Depth-Anything-V2-Small-hf"
depth_pipe = pipeline("depth-estimation", model=DEPTH_MODEL_ID)
def compute_depth_map_pipeline(image: Image.Image, scale_factor: float) -> np.ndarray:
    """
    Estimate a relative depth map for ``image`` with ``depth_pipe``.

    The pipeline's ``"depth"`` output is min-max normalised to [0, 1], inverted
    so that near pixels are ~0 and far pixels are ~1, and finally multiplied by
    ``scale_factor``.

    Args:
        image: Image passed directly to ``depth_pipe``.
        scale_factor: Multiplier applied after inversion; values below 1
            compress the depth range toward 0 (i.e. toward "near").

    Returns:
        A float32 array in ``[0, scale_factor]``. Its shape is presumably
        (height, width) of ``image`` because the pipeline resizes its prediction
        to the input size. Verify for your transformers version.
    """
    # A single image in yields a single dict out (no list, hence no [0]).
    result = depth_pipe(image)
    # "depth" is an 8-bit image (0..255). Converting it with a bare np.array
    # and computing 1.0 - depth would give values in roughly [-254, 1], not
    # [0, 1]. So normalise explicitly first.
    raw = np.asarray(result["depth"], dtype=np.float32)
    lo, hi = float(raw.min()), float(raw.max())
    if hi > lo:
        normalized = (raw - lo) / (hi - lo)
    else:
        # A flat prediction carries no depth cue: treat every pixel as nearest
        # so nothing ends up blurred.
        normalized = np.ones_like(raw)
    # Depth Anything predicts relative inverse depth (brighter = nearer), so
    # invert to get near ~0 and far ~1.
    return (1.0 - normalized) * scale_factor
def layered_blur(
    image: Image.Image,
    depth_map: np.ndarray,
    num_layers: int,
    max_blur: float,
    *,
    depth_ceiling: float = 1.0,
) -> Image.Image:
    """
    Blur ``image`` progressively with depth, using ``num_layers`` discrete layers.

    The depth range ``[0, upper]`` is split into ``num_layers`` equal bins, where
    ``upper = max(depth_ceiling, depth_map.max())``. Pixels in bin ``i`` take the
    image blurred with the ``i``-th radius of ``linspace(0, max_blur, num_layers)``.
    The first (nearest) bin stays sharp and the last (farthest) bin gets ``max_blur``.
    Every pixel lands in exactly one layer:

    * values below 0 go to the first layer;
    * values at or above the top edge go to the last layer.

    Args:
        image: Source image.
        depth_map: 2-D array indexed (row, column) matching ``image``. Larger
            values mean farther away, i.e. more blur.
        num_layers: Number of blur layers; must be at least 1.
        max_blur: Gaussian blur radius used for the farthest layer.
        depth_ceiling: Minimum value for the top bin edge. With the default 1.0,
            a depth map scaled into ``[0, s]`` with ``s < 1`` only reaches the
            first ~``s * num_layers`` layers, so scaling depth down reduces blur.
            Pass 0 to always stretch the bins over the data's own maximum.

    Returns:
        The composited image.

    Raises:
        ValueError: If ``num_layers`` is less than 1.
    """
    if num_layers < 1:
        raise ValueError(f"num_layers must be at least 1, got {num_layers}")

    depth = np.asarray(depth_map, dtype=np.float64)
    blur_radii = np.linspace(0.0, max_blur, num_layers)

    upper_bound = max(float(depth_ceiling), float(depth.max()))
    if upper_bound > 0:
        # Pass only the interior bin edges. np.digitize then clamps out-of-range
        # values into the first/last layer, so the top layer includes its
        # upper edge (the old strict '<' test dropped those pixels).
        edges = np.linspace(0.0, upper_bound, num_layers + 1)[1:-1]
        layer_of_pixel = np.digitize(depth, edges)
    else:
        # Degenerate map (all depth <= 0 and no positive ceiling): nothing is far.
        layer_of_pixel = np.zeros(depth.shape, dtype=np.intp)

    # Start from the sharpest layer and paste each farther layer over it.
    result = image.filter(ImageFilter.GaussianBlur(float(blur_radii[0])))
    for layer in range(1, num_layers):
        in_layer = layer_of_pixel == layer
        if not in_layer.any():
            continue  # skip the costly blur for layers no pixel uses
        blurred = image.filter(ImageFilter.GaussianBlur(float(blur_radii[layer])))
        # A 2-D uint8 array becomes an 8-bit ("L") mask: 255 where the layer applies.
        mask = Image.fromarray(in_layer.astype(np.uint8) * 255)
        result = Image.composite(blurred, result, mask)
    return result
def process_depth_blur_pipeline(uploaded_image, max_blur_value, scale_factor, num_layers):
    """
    Gradio handler: apply depth-based layered blur to an uploaded image.

    The image is converted to RGB and resized to 512x512. Its depth is then
    estimated with ``compute_depth_map_pipeline`` and ``layered_blur`` is applied.

    Args:
        uploaded_image: A PIL image, or anything ``Image.open`` accepts
            (path or file object). ``None`` means nothing was uploaded.
        max_blur_value: Blur radius for the farthest layer.
        scale_factor: Depth scale factor forwarded to the depth estimator.
        num_layers: Number of blur layers (may arrive as a float from a slider).

    Returns:
        The blurred 512x512 PIL image.

    Raises:
        gr.Error: If no image was provided; Gradio shows this to the user.
    """
    if uploaded_image is None:
        raise gr.Error("Please upload an image first.")
    if isinstance(uploaded_image, Image.Image):
        image = uploaded_image.convert("RGB").resize((512, 512))
    else:
        # Open inside a context manager so the underlying file handle is closed.
        # convert/resize load the pixels, so `image` stays valid afterwards.
        with Image.open(uploaded_image) as opened:
            image = opened.convert("RGB").resize((512, 512))
    depth_map = compute_depth_map_pipeline(image, scale_factor)
    return layered_blur(image, depth_map, int(num_layers), max_blur_value)
# --- Segmentation-Based Blur using BEN2 ---
def load_segmentation_model():
    """
    Load the BEN2 segmentation model once and cache it in module globals.

    The first call imports ``ben2``, loads ``PramaLLC/BEN2``, moves it to CUDA
    when available (CPU otherwise) and switches it to eval mode. Later calls
    return the cached pair.

    The globals are published only after initialisation fully succeeds. A
    failed first attempt (e.g. during device placement) therefore cannot leave
    a half-initialised cache behind; the next call simply retries.

    NOTE(review): there is no lock, so two concurrent first calls could both
    load the model.

    Returns:
        Tuple ``(seg_model, seg_device)``.
    """
    global seg_model, seg_device
    if "seg_model" not in globals():
        from ben2 import BEN_Base  # imported lazily; only needed for this tab
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = BEN_Base.from_pretrained("PramaLLC/BEN2")
        model = model.to(device).eval()
        # Publish both names together, only once everything above succeeded.
        seg_model, seg_device = model, device
    return seg_model, seg_device
def process_segmentation_blur(uploaded_image, seg_blur_radius: float):
    """
    Gradio handler: keep the segmented foreground sharp and blur the background.

    Steps:

    1. Convert the image to RGB and resize it to 512x512.
    2. Blur a copy with a Gaussian of radius ``seg_blur_radius``.
    3. Take the alpha channel of BEN2's foreground output and threshold it
       at > 128 into a binary mask.
    4. Composite the sharp image over the blurred one through that mask.

    Args:
        uploaded_image: A PIL image, or anything ``Image.open`` accepts
            (path or file object). ``None`` means nothing was uploaded.
        seg_blur_radius: Gaussian blur radius for the background.

    Returns:
        The composited 512x512 PIL image.

    Raises:
        gr.Error: If no image was provided; Gradio shows this to the user.
    """
    if uploaded_image is None:
        raise gr.Error("Please upload an image first.")
    if isinstance(uploaded_image, Image.Image):
        image = uploaded_image.convert("RGB").resize((512, 512))
    else:
        # Context manager closes the file handle; the converted copy is independent.
        with Image.open(uploaded_image) as opened:
            image = opened.convert("RGB").resize((512, 512))

    seg_model, _ = load_segmentation_model()
    blurred_image = image.filter(ImageFilter.GaussianBlur(seg_blur_radius))

    # Foreground segmentation; its alpha channel marks the subject.
    foreground = seg_model.inference(image, refine_foreground=False)
    alpha = foreground.convert("RGBA").getchannel("A")
    if alpha.size != image.size:
        # Defensive: Image.composite requires the mask to match the image size.
        alpha = alpha.resize(image.size)
    binary_mask = alpha.point(lambda value: 255 if value > 128 else 0)
    return Image.composite(image, blurred_image, binary_mask)
# --- Merged Gradio Interface ---
# One Blocks app with a tab per technique. Gradio lays components out in the
# order they are created, so the statement order here is the on-screen order.
with gr.Blocks() as demo:
    gr.Markdown(value="# Depth-Based vs Segmentation-Based Blur")
    with gr.Tabs():
        # Tab 1: blur strength grows with estimated distance.
        with gr.Tab(label="Depth-Based Blur (Pipeline)"):
            depth_img = gr.Image(label="Upload Image", type="pil")
            depth_max_blur = gr.Slider(
                minimum=1.0,
                maximum=5.0,
                value=3.0,
                step=0.1,
                label="Maximum Blur Radius",
            )
            depth_scale = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.5,
                step=0.1,
                label="Depth Scale Factor",
            )
            depth_layers = gr.Slider(
                minimum=2,
                maximum=20,
                value=8,
                step=1,
                label="Number of Layers",
            )
            depth_out = gr.Image(label="Depth-Based Blurred Image")
            depth_button = gr.Button(value="Process Depth Blur")
            depth_button.click(
                fn=process_depth_blur_pipeline,
                inputs=[depth_img, depth_max_blur, depth_scale, depth_layers],
                outputs=depth_out,
            )

        # Tab 2: sharp segmented foreground over a uniformly blurred background.
        with gr.Tab(label="Segmentation-Based Blur (BEN2)"):
            seg_img = gr.Image(label="Upload Image", type="pil")
            seg_blur = gr.Slider(
                minimum=5,
                maximum=30,
                value=15,
                step=1,
                label="Segmentation Blur Radius",
            )
            seg_out = gr.Image(label="Segmentation-Based Blurred Image")
            seg_button = gr.Button(value="Process Segmentation Blur")
            seg_button.click(
                fn=process_segmentation_blur,
                inputs=[seg_img, seg_blur],
                outputs=seg_out,
            )

if __name__ == "__main__":
    # share=True additionally requests a public link to the app; remove it to
    # serve on the local address only.
    demo.launch(share=True)