# Author: nsathya5
# Last change: "Update app.py" (commit df258b2, verified)
import gradio as gr
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from transformers import DPTImageProcessor, DPTForDepthEstimation
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
from scipy.ndimage import gaussian_filter
import cv2
import os
import io
import time
# Load models once at import time so every request reuses the same weights
# instead of reloading them for each inference.
print("Loading models...")

# Hugging Face checkpoints used by the app.
SEGMENTATION_CHECKPOINT = "nvidia/segformer-b5-finetuned-ade-640-640"
DEPTH_CHECKPOINT = "Intel/dpt-large"

# Pre-bind the globals so they are always defined: if a download/load fails,
# the names hold None rather than being missing entirely.
seg_processor = seg_model = None
depth_processor = depth_model = None

# Load segmentation model
try:
    seg_processor = SegformerImageProcessor.from_pretrained(SEGMENTATION_CHECKPOINT)
    seg_model = SegformerForSemanticSegmentation.from_pretrained(SEGMENTATION_CHECKPOINT)
    print("✓ Segmentation model loaded successfully")
except Exception as e:
    print(f"! Error loading segmentation model: {e}")

# Load depth estimation model
try:
    depth_processor = DPTImageProcessor.from_pretrained(DEPTH_CHECKPOINT)
    depth_model = DPTForDepthEstimation.from_pretrained(DEPTH_CHECKPOINT)
    print("✓ Depth model loaded successfully")
except Exception as e:
    print(f"! Error loading depth model: {e}")
# Function for image segmentation

# Primary foreground label, plus its id in the 150-class ADE20K label map
# (used only if the model config does not expose a matching label).
_PERSON_LABEL = "person"
_PERSON_CLASS_ID = 12
# Labels tried, by name, when almost no person pixels are found. They are
# resolved through the loaded model's ``config.id2label`` so the class ids
# always match the checkpoint; names missing from the config are skipped.
_FALLBACK_FOREGROUND_LABELS = (
    "car", "bicycle", "minibike", "bus", "truck", "van",
    "boat", "ship", "airplane", "animal",
)
# Below this many person pixels the fallback labels are considered.
_MIN_PERSON_PIXELS = 100


def _label_ids(names):
    """Return the class ids whose ``seg_model.config.id2label`` entry matches *names*.

    Matching is case-insensitive and ignores surrounding whitespace; labels
    written as comma/semicolon-separated synonym lists match on any synonym.
    """
    wanted = {name.lower() for name in names}
    id2label = getattr(seg_model.config, "id2label", None) or {}
    ids = []
    for idx, label in id2label.items():
        synonyms = str(label).replace(",", ";").split(";")
        if any(synonym.strip().lower() in wanted for synonym in synonyms):
            ids.append(int(idx))
    return ids


def segment_image(image):
    """Segment *image* and return a binary foreground mask plus a visualisation.

    Every pixel is classified with the Segformer ADE20K model. Pixels labelled
    "person" form the foreground. If fewer than ``_MIN_PERSON_PIXELS`` of them
    exist, the fallback class (vehicles, boats, airplanes, animals) covering
    the most pixels is used instead, provided it covers more than the person
    mask. The mask is cleaned with morphological open/close and edge
    smoothing, then resized back to the input size.

    Args:
        image: PIL image, or an array accepted by ``Image.fromarray``.

    Returns:
        (mask_array, mask_viz): ``mask_array`` is a boolean ``(H, W)`` array at
        the input's size (True = foreground); ``mask_viz`` is a ``(H, W, 3)``
        uint8 white-on-black rendering of the same mask.
    """
    print("Running image segmentation with Segformer...")
    # Convert to PIL Image if needed
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    original_size = image.size  # PIL size is (width, height)

    # 640x640 matches the resolution named in the checkpoint id.
    model_size = (640, 640)
    model_image = image.resize(model_size, Image.LANCZOS)
    inputs = seg_processor(images=model_image, return_tensors="pt")
    with torch.no_grad():
        logits = seg_model(**inputs).logits
    # Per-pixel class ids of the single image in the batch (model output resolution).
    predicted = torch.argmax(logits, dim=1)[0].cpu().numpy()

    person_ids = _label_ids((_PERSON_LABEL,)) or [_PERSON_CLASS_ID]
    binary_mask = np.isin(predicted, person_ids)

    # Almost no person pixels: fall back to the candidate class with the most pixels.
    if binary_mask.sum() < _MIN_PERSON_PIXELS:
        for cls in _label_ids(_FALLBACK_FOREGROUND_LABELS):
            cls_mask = predicted == cls
            if cls_mask.sum() > binary_mask.sum():
                binary_mask = cls_mask

    # Clean up the mask with OpenCV: close small holes, remove isolated
    # specks, then lightly smooth the edges and re-binarise.
    mask_cv = (binary_mask * 255).astype(np.uint8)
    kernel = np.ones((5, 5), np.uint8)
    mask_cv = cv2.morphologyEx(mask_cv, cv2.MORPH_CLOSE, kernel)
    mask_cv = cv2.morphologyEx(mask_cv, cv2.MORPH_OPEN, kernel)
    mask_cv = cv2.GaussianBlur(mask_cv, (7, 7), 0)
    _, mask_cv = cv2.threshold(mask_cv, 128, 255, cv2.THRESH_BINARY)

    # Resize back to the original image size with Lanczos resampling, then
    # re-threshold because resampling produces intermediate grey values.
    mask_resized = Image.fromarray(mask_cv).resize(original_size, Image.LANCZOS)
    mask_array = np.array(mask_resized) > 128

    # White-on-black RGB visualisation: the same 0/255 value in all channels.
    mask_gray = (mask_array * 255).astype(np.uint8)
    mask_viz = np.repeat(mask_gray[:, :, np.newaxis], 3, axis=2)
    return mask_array, mask_viz
# Function to get depth map
def get_depth_map(image):
    """Estimate relative depth for *image* with the DPT model.

    Intel/dpt-large predicts relative inverse depth, so after normalisation
    larger values are closer to the camera (the depth-based blur in this file
    inverts the map so that near pixels receive less blur).

    Args:
        image: PIL image, or an array accepted by ``Image.fromarray``.

    Returns:
        (depth_viz, depth_map): ``depth_viz`` is a viridis-coloured uint8 RGB
        array resized to the input size; ``depth_map`` is the raw prediction
        min-max normalised to [0, 1] at the model's output resolution (it is
        NOT resized). A constant prediction yields an all-zero map.
    """
    print("Running depth estimation...")
    # Convert to PIL Image if needed
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    original_size = image.size  # PIL size is (width, height)

    model_size = (640, 640)
    model_image = image.resize(model_size, Image.LANCZOS)
    inputs = depth_processor(images=model_image, return_tensors="pt")
    with torch.no_grad():
        outputs = depth_model(**inputs)
    depth = outputs.predicted_depth.squeeze().cpu().numpy()

    # Min-max normalise to [0, 1]. Guard the degenerate flat case, where
    # max == min would give 0/0 = NaN and poison both the colormap and the
    # blur-level binning downstream.
    depth_min = depth.min()
    depth_range = depth.max() - depth_min
    if depth_range > 0:
        depth_map = (depth - depth_min) / depth_range
    else:
        depth_map = np.zeros_like(depth)

    # Colour with viridis (drop the alpha channel) and resize for display.
    depth_map_colored = plt.cm.viridis(depth_map)[:, :, :3]
    depth_map_viz = Image.fromarray((depth_map_colored * 255).astype(np.uint8))
    depth_map_viz_resized = depth_map_viz.resize(original_size, Image.LANCZOS)
    return np.array(depth_map_viz_resized), depth_map
# Function to apply Gaussian blur to background
def apply_background_blur(image, mask, sigma=15):
    """Blur the background of *image* while keeping the masked foreground sharp.

    Args:
        image: PIL image or numpy array, either ``(H, W)`` or ``(H, W, C)``.
            All channels (including alpha) are composited.
        mask: foreground mask as a numpy array or PIL image, ``(H, W)`` or
            ``(H, W, C)`` (only the first channel is used); non-zero/True
            marks foreground.
        sigma: Gaussian standard deviation, in pixels, for the background.
            ``sigma <= 0`` returns an unblurred copy.

    Returns:
        numpy array with the same shape and dtype as the image array.

    Raises:
        ValueError: if the mask's height/width differ from the image's.
    """
    print(f"Applying background blur with sigma={sigma}...")
    # Convert to numpy if needed; OpenCV wants a C-contiguous buffer.
    if isinstance(image, Image.Image):
        image_array = np.array(image)
    else:
        image_array = np.ascontiguousarray(image)

    # Reduce the mask to a 2-D boolean foreground map.
    mask_array = np.asarray(mask)
    binary_mask = (mask_array[:, :, 0] if mask_array.ndim == 3 else mask_array) > 0
    if binary_mask.shape != image_array.shape[:2]:
        raise ValueError(
            f"mask shape {binary_mask.shape} does not match image size {image_array.shape[:2]}"
        )

    if sigma <= 0:
        # cv2.GaussianBlur with ksize=(0, 0) needs a positive sigma.
        return image_array.copy()

    # Use OpenCV for better performance on larger images.
    blurred = cv2.GaussianBlur(image_array, (0, 0), sigma)

    # Broadcast the 2-D mask over every channel so grayscale, RGB and RGBA
    # images are all composited the same way.
    if image_array.ndim == 3:
        binary_mask = binary_mask[:, :, np.newaxis]
    return np.where(binary_mask, image_array, blurred)
# Function for depth-based blur
def apply_depth_based_blur(image, mask, depth_map, max_sigma=15, num_levels=8):
    """Blur *image* progressively by depth while keeping the masked foreground sharp.

    The normalised depth map (high = near) is inverted into a blur amount in
    [0, 1] and quantised into ``num_levels`` equal-width bins. Pixels in bin
    ``i`` take their value from the image blurred with sigma
    ``max_sigma * (i + 1) / num_levels``. The farthest pixels (blur amount
    exactly 1.0) belong to the last bin. Foreground pixels are never blurred.

    Args:
        image: PIL image, or an array accepted by ``Image.fromarray``.
        mask: foreground mask ``(H, W)`` or ``(H, W, C)`` (first channel used),
            non-zero = foreground; ``None`` blurs purely by depth.
        depth_map: 2-D array of values in [0, 1]. It is resized (bicubic) to
            the image size if its shape differs.
        max_sigma: sigma applied to the farthest bin; ``<= 0`` disables blur.
        num_levels: number of discrete blur levels; ``< 1`` disables blur.

    Returns:
        numpy array with the same shape and dtype as ``np.array(image)``.

    Raises:
        ValueError: if the mask's height/width differ from the image's.
    """
    print(f"Applying depth-based blur with max_sigma={max_sigma}...")
    # Convert to PIL Image if needed, then to a numpy array for processing.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    img_array = np.array(image)
    height, width = img_array.shape[:2]

    if max_sigma <= 0 or num_levels < 1:
        return img_array.copy()

    depth = np.asarray(depth_map, dtype=np.float32)
    if depth.shape[:2] != (height, width):
        # cv2.resize takes the target size as (width, height).
        depth = cv2.resize(depth, (width, height), interpolation=cv2.INTER_CUBIC)

    # Invert so near pixels (high depth value) get little blur. Clip because
    # cubic interpolation can overshoot [0, 1], which would otherwise push
    # pixels outside every bin and leave them unblurred.
    blur_amount = np.clip(1.0 - depth, 0.0, 1.0)

    # Blur level per pixel; np.minimum folds blur_amount == 1.0 (the
    # farthest pixels) into the last level instead of dropping it.
    level_index = np.minimum((blur_amount * num_levels).astype(np.int64), num_levels - 1)

    if mask is not None:
        mask_array = np.asarray(mask)
        foreground = (mask_array[:, :, 0] if mask_array.ndim == 3 else mask_array) > 0
        if foreground.shape != (height, width):
            raise ValueError(
                f"mask shape {foreground.shape} does not match image size {(height, width)}"
            )
        # -1 matches no level, so foreground pixels keep their original value
        # (level 0 would still blur them with max_sigma / num_levels).
        level_index[foreground] = -1

    result = img_array.copy()
    for level in range(num_levels):
        level_mask = level_index == level
        if not level_mask.any():
            continue  # no pixel uses this level: skip the full-image blur
        level_sigma = max_sigma * (level + 1) / num_levels
        level_blurred = cv2.GaussianBlur(img_array, (0, 0), level_sigma)
        # Boolean indexing with the 2-D mask copies whole pixels, so this
        # works for any number of channels.
        result[level_mask] = level_blurred[level_mask]
    return result
# Main processing function
def process_image(input_image, blur_type="Gaussian Blur", blur_intensity=15):
    """Apply the selected blur effect and return the three UI outputs.

    Args:
        input_image: uploaded image (a numpy array from
            ``gr.Image(type="numpy")``, or anything ``np.array`` accepts).
            Grayscale, single-channel, RGB and RGBA inputs are converted to RGB.
        blur_type: "Gaussian Blur" for a uniform background blur; any other
            value selects the depth-based lens blur.
        blur_intensity: sigma for the Gaussian blur, or the maximum sigma for
            the depth-based blur.

    Returns:
        (result, mask_viz, depth_viz) images. If processing raises, the input
        image is returned in all three slots so the UI still shows something.

    Raises:
        gr.Error: when no image has been uploaded.
    """
    if input_image is None:
        # Gradio passes None when "Apply Effect" is clicked before an upload.
        raise gr.Error("Please upload an image first.")
    try:
        # np.array always copies, so the caller's array is never modified.
        img = np.array(input_image)
        # Ensure 3-channel RGB format
        if img.ndim == 2:  # Grayscale
            img = np.stack([img] * 3, axis=2)
        elif img.shape[2] == 1:  # Single-channel grayscale
            img = np.repeat(img, 3, axis=2)
        elif img.shape[2] == 4:  # RGBA
            img = img[:, :, :3]  # Drop alpha channel

        pil_img = Image.fromarray(img)
        # Step 1: Get segmentation mask
        mask_array, mask_viz = segment_image(pil_img)
        # Step 2: Always get the depth map; it is displayed for both blur types.
        depth_viz, depth_map = get_depth_map(pil_img)
        # Step 3: Apply the selected blur effect
        if blur_type == "Gaussian Blur":
            result = apply_background_blur(pil_img, mask_array, sigma=blur_intensity)
        else:  # "Depth-based Lens Blur"
            result = apply_depth_based_blur(
                pil_img, mask_array, depth_map, max_sigma=blur_intensity
            )
        return result, mask_viz, depth_viz
    except Exception as e:
        print(f"Error processing image: {e}")
        import traceback
        traceback.print_exc()
        # Best-effort fallback: show the original image in every output.
        fallback = input_image if isinstance(input_image, np.ndarray) else np.array(input_image)
        return fallback, fallback, fallback
# Create Gradio interface: inputs and controls on top, the intermediate
# mask/depth visualisations below, followed by usage notes.
with gr.Blocks(title="Image Blur Effects") as demo:
    gr.Markdown("# Image Blur Effects App")
    gr.Markdown("Upload an image to apply two types of blur effects:")
    gr.Markdown("1. **Gaussian Blur**: Blurs the background while keeping the foreground sharp")
    gr.Markdown("2. **Depth-based Lens Blur**: Applies varying blur intensities based on estimated depth")

    with gr.Row():
        input_image = gr.Image(label="Input Image", type="numpy")
        output_image = gr.Image(label="Output Image")

    with gr.Row():
        blur_effect_type = gr.Radio(
            ["Gaussian Blur", "Depth-based Lens Blur"],
            label="Blur Effect Type",
            value="Gaussian Blur"
        )
        blur_intensity = gr.Slider(
            minimum=1,
            maximum=30,
            value=15,
            step=1,
            label="Blur Intensity"
        )

    with gr.Row():
        apply_button = gr.Button("Apply Effect")

    with gr.Row():
        foreground_mask = gr.Image(label="Foreground Mask")
        depth_map = gr.Image(label="Depth Map")

    # Set up the click event; output order matches process_image's return tuple.
    apply_button.click(
        process_image,
        inputs=[input_image, blur_effect_type, blur_intensity],
        outputs=[output_image, foreground_mask, depth_map]
    )

    gr.Markdown("## How to Use")
    gr.Markdown("1. Upload your image")
    gr.Markdown("2. Select blur type (Gaussian or Depth-based)")
    gr.Markdown("3. Adjust blur intensity")
    gr.Markdown("4. Click 'Apply Effect'")
    gr.Markdown("")
    gr.Markdown("### Notes")
    gr.Markdown("- The white areas in the Foreground Mask show what will remain sharp")
    # DPT predicts relative inverse depth (high = near) and viridis maps high
    # values to yellow, so yellow marks the closest regions.
    gr.Markdown("- The Depth Map shows estimated relative distance (yellow=close, purple/blue=far)")
    gr.Markdown("- Gaussian Blur applies uniform blur to the background")
    gr.Markdown("- Depth-based Blur varies blur intensity based on distance")
    gr.Markdown("- Created for EEE 515 Assignment (Problem 2, Part 6)")

# Launch the demo
demo.launch()