# eee515 / app.py
# Author: kvinod15 — last commit aca1feb ("Update app.py", verified)
import os
import io
import numpy as np
import torch
from PIL import Image, ImageFilter
from torchvision import transforms
import gradio as gr
from transformers import AutoModelForImageSegmentation, pipeline
# ----------------------------
# Global Setup and Model Loading
# ----------------------------
# Compute device: the CUDA GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Hugging Face access token, read from the environment (on a Space, configure it
# in the Secrets panel). The segmentation model below is gated, so loading it
# may fail when no token is available.
hf_token = os.environ.get("HF_ACCESS_TOKEN")
if hf_token is None:
    print("Warning: HF_ACCESS_TOKEN environment variable is not set. Model access might fail.")
# Load the segmentation model (RMBG-2.0).
# Access to this gated model must have been granted to the account behind hf_token.
# trust_remote_code=True executes model code downloaded from the Hub repository.
# `token=` replaces the deprecated `use_auth_token=` keyword of from_pretrained.
segmentation_model = AutoModelForImageSegmentation.from_pretrained(
    "briaai/RMBG-2.0",
    trust_remote_code=True,
    token=hf_token,
)
segmentation_model.to(device)
segmentation_model.eval()  # inference mode for layers such as dropout/batch-norm
# Preprocessing for the segmentation model: resize to a fixed 512x512 input,
# convert to a float tensor, then normalize each channel with the given
# per-channel mean/std (the standard ImageNet statistics).
image_size = (512, 512)
segmentation_transform = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Load the depth estimation pipeline (Depth-Anything V2, small variant).
# Run it on the same device as the segmentation model; without `device=` the
# pipeline stays on CPU even when a GPU is available.
depth_pipeline = pipeline(
    "depth-estimation",
    model="depth-anything/Depth-Anything-V2-Small-hf",
    device=device,
)
# ----------------------------
# Processing Functions
# ----------------------------
def segment_and_blur_background(input_image: Image.Image, blur_radius: int = 15, threshold: float = 0.5) -> Image.Image:
    """
    Blur the background of an image while keeping the segmented foreground sharp.

    The RMBG-2.0 model is run on a 512x512 resize of the image. Its sigmoid output
    is thresholded into a binary foreground mask. The mask is scaled back to the
    original resolution and used to composite the sharp original over a
    Gaussian-blurred copy.

    Args:
        input_image: Source image (any PIL mode; converted to RGB).
        blur_radius: Gaussian blur radius, in pixels, applied to the background.
        threshold: Pixels whose predicted probability exceeds this value are
            treated as foreground.

    Returns:
        An RGB image with the same size as ``input_image``.
    """
    image = input_image.convert("RGB")
    orig_size = image.size  # PIL size is (width, height)

    input_tensor = segmentation_transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        # Take the last element of the model's outputs and map it to probabilities.
        preds = segmentation_model(input_tensor)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()

    # Binarize exactly once, straight into a 0/255 uint8 mask. A 2-D uint8 array
    # becomes an 8-bit "L" image, so no re-thresholding pass is needed.
    mask_array = (pred > threshold).numpy().astype(np.uint8) * 255
    mask_pil = Image.fromarray(mask_array)
    # Bilinear upscaling leaves a thin ramp of intermediate values along the mask
    # edge, which softens the seam between foreground and background.
    mask_pil = mask_pil.resize(orig_size, resample=Image.BILINEAR)

    blurred_image = image.filter(ImageFilter.GaussianBlur(blur_radius))
    # Mask 255 -> sharp original pixel, mask 0 -> blurred pixel.
    return Image.composite(image, blurred_image, mask_pil)
def _composite_depth_bands(image: Image.Image, depth_norm: np.ndarray, max_blur: float, num_bands: int) -> Image.Image:
    """
    Blur ``image`` band by band according to a normalized depth map.

    ``depth_norm`` is split into ``num_bands`` equal-width bands over [0, 1]. Each
    band's pixels are replaced with a copy of the image blurred by radius
    ``(1 - band_midpoint) * max_blur``, so larger ``depth_norm`` values get less blur.

    Args:
        image: RGB image to blur.
        depth_norm: 2-D float array in [0, 1] with the same height/width as ``image``.
        max_blur: Scale of the blur radius; the lowest band gets just under this.
        num_bands: Number of depth bands; values below 1 leave the image unchanged.

    Returns:
        A new RGB image (``image`` itself is not modified).
    """
    result = image.copy()
    if num_bands < 1:
        return result

    band_edges = np.linspace(0.0, 1.0, num_bands + 1)
    # Band index for every pixel: band i holds edges[i] <= d < edges[i + 1], except
    # that the last band is closed at 1.0. Otherwise the pixels at the top of the
    # normalized range would fall into no band and never be blurred.
    band_index = np.searchsorted(band_edges[1:-1], depth_norm, side="right")

    for i in range(num_bands):
        in_band = band_index == i
        if not in_band.any():
            # Nothing to composite, so skip the full-image blur for this band.
            continue
        mid = (band_edges[i] + band_edges[i + 1]) / 2.0
        blurred = image.filter(ImageFilter.GaussianBlur((1 - mid) * max_blur))
        # A 2-D uint8 array becomes an 8-bit "L" mask; 255 selects the blurred pixel.
        band_mask = Image.fromarray(in_band.astype(np.uint8) * 255)
        result = Image.composite(blurred, result, band_mask)
    return result


def depth_based_lens_blur(input_image: Image.Image, max_blur: float = 2, num_bands: int = 40, invert_depth: bool = False) -> Image.Image:
    """
    Apply a depth-dependent blur using a depth map from Depth-Anything.

    The depth map is min-max normalized to [0, 1], optionally inverted, and split
    into ``num_bands`` bands. Pixels with lower normalized values receive more blur,
    up to just under ``max_blur``. The image is processed at its original resolution.

    Args:
        input_image: Source image (any PIL mode; converted to RGB).
        max_blur: Upper bound on the Gaussian blur radius.
        num_bands: Number of depth bands. More bands give smoother transitions at
            the cost of more blur passes.
        invert_depth: If True, flip the normalized depth so the blur ordering reverses.

    Returns:
        An RGB image with the same size as ``input_image``.
    """
    image = input_image.convert("RGB")
    depth_map_image = depth_pipeline(image)["depth"]
    # NOTE(review): the pipeline documents "depth" as a PIL image, normally at the
    # input size. Resize defensively so the band masks always line up with the image.
    if depth_map_image.size != image.size:
        depth_map_image = depth_map_image.resize(image.size, resample=Image.BILINEAR)

    depth = np.array(depth_map_image, dtype=np.float32)
    d_min, d_max = depth.min(), depth.max()
    # The epsilon keeps a perfectly flat depth map from dividing 0 by 0.
    depth_norm = (depth - d_min) / (d_max - d_min + 1e-8)
    if invert_depth:
        depth_norm = 1.0 - depth_norm
    # NOTE(review): Depth-Anything output is presumably relative inverse depth
    # (larger = nearer), which would keep near objects sharp. Confirm for this model.
    return _composite_depth_bands(image, depth_norm, max_blur, num_bands)
def process_image(input_image: Image.Image, effect: str, threshold: float, blur_intensity: float) -> Image.Image:
    """
    Route an image to the selected blur effect.

    Args:
        input_image: Image from the Gradio input. May be ``None`` if the user
            submits without providing an image.
        effect: "Gaussian Blur Background" or "Depth-based Lens Blur". Any other
            value (including ``None``) returns the input unchanged.
        threshold: Segmentation threshold, used only by the Gaussian effect.
        blur_intensity: Background blur radius for the Gaussian effect (truncated
            to int), or the maximum blur for the depth-based effect.

    Returns:
        The processed image, the untouched input for an unrecognized effect, or
        ``None`` when no image was provided.
    """
    if input_image is None:
        # Without this guard both effects would crash calling .convert() on None.
        return None
    if effect == "Gaussian Blur Background":
        return segment_and_blur_background(input_image, blur_radius=int(blur_intensity), threshold=threshold)
    if effect == "Depth-based Lens Blur":
        return depth_based_lens_blur(input_image, max_blur=blur_intensity)
    return input_image
# ----------------------------
# Gradio Interface
# ----------------------------
# The four inputs map positionally onto process_image's parameters:
# (input_image, effect, threshold, blur_intensity).
iface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        # Preselect an effect. With no selection process_image receives an
        # unrecognized effect and silently returns the input unchanged.
        gr.Radio(
            choices=["Gaussian Blur Background", "Depth-based Lens Blur"],
            value="Gaussian Blur Background",
            label="Select Effect",
        ),
        gr.Slider(0.0, 1.0, value=0.5, label="Segmentation Threshold (for Gaussian Blur)"),
        gr.Slider(0, 30, value=15, step=1, label="Blur Intensity (for both effects)"),
    ],
    outputs=gr.Image(type="pil", label="Output Image"),
    title="EEE 515: Interactive Blur Effects Demo - by: Krishna Vinod",
    description=(
        "How to use this App: Upload an image and choose an effect. For 'Gaussian Blur Background', adjust the segmentation threshold and blur intensity. "
        "For 'Depth-based Lens Blur', the blur intensity slider sets the maximum blur based on depth. Use the camera interface for a more interactive experience ;)"
    ),
)

if __name__ == "__main__":
    iface.launch()