# app.py — author: gchallar
# Last commit: c685a5d "Update app.py" (verified); file size 6.58 kB.
import gradio as gr
from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation, AutoImageProcessor, AutoModelForDepthEstimation
from PIL import Image, ImageFilter
import numpy as np
import torch
from scipy.ndimage import gaussian_filter
import cv2
def _load_pretrained_pair(processor_cls, model_cls, checkpoint, display_name):
    """Load a processor/model pair from ``checkpoint`` on a best-effort basis.

    Any exception raised while loading is printed rather than propagated, so the
    app still starts; whatever was not loaded stays ``None`` and the request
    handlers check for that before use.

    Returns:
        A ``(processor, model)`` tuple; either element may be ``None``.
    """
    processor, model = None, None
    try:
        processor = processor_cls.from_pretrained(checkpoint)
        model = model_cls.from_pretrained(checkpoint)
    except Exception as exc:
        print(f"Error loading {display_name} model: {exc}")
    return processor, model


# OneFormer checkpoint used for the semantic segmentation that finds the person.
oneformer_processor, oneformer_model = _load_pretrained_pair(
    OneFormerProcessor,
    OneFormerForUniversalSegmentation,
    "shi-labs/oneformer_coco_swin_large",
    "OneFormer",
)

# Depth estimation checkpoint used only by the "Lens" blur mode.
depth_processor, depth_model = _load_pretrained_pair(
    AutoImageProcessor,
    AutoModelForDepthEstimation,
    "depth-anything/Depth-Anything-V2-Small-hf",
    "Depth Anything",
)
def apply_gaussian_blur_background(image, mask, radius):
    """Applies Gaussian blur to the background of the image.

    Args:
        image: PIL image to process.
        mask: Foreground mask with the same width/height as ``image``, given
            either as a PIL image or as a 2-D array-like. Non-zero pixels are
            foreground and are copied unblurred from ``image``.
        radius: Radius passed to ``ImageFilter.GaussianBlur``.

    Returns:
        A new PIL image: foreground from ``image``, background from a
        Gaussian-blurred copy of it.
    """
    blurred_background = image.filter(ImageFilter.GaussianBlur(radius=radius))
    img_array = np.asarray(image)
    blurred_array = np.asarray(blurred_background)
    # Convert first: comparing a PIL Image directly with `> 0` raises TypeError,
    # while an ndarray comparison yields the element-wise boolean mask we need.
    foreground_mask = np.asarray(mask) > 0
    if img_array.ndim == 3:
        # Broadcast the 2-D mask across however many channels the image has
        # (L, RGB, RGBA, ...) instead of hard-coding three.
        foreground_mask = foreground_mask[..., np.newaxis]
    final_image_array = np.where(foreground_mask, img_array, blurred_array)
    return Image.fromarray(final_image_array.astype(np.uint8))
def _estimate_normalized_depth(pil_image):
    """Run the depth model on ``pil_image`` and return its prediction scaled to [0, 1].

    The prediction is interpolated to the image's (height, width). A constant
    prediction (zero range) returns an all-zero map instead of dividing by zero.
    """
    inputs = depth_processor(images=pil_image, return_tensors="pt")
    with torch.no_grad():
        predicted_depth = depth_model(**inputs).predicted_depth
    # PIL size is (width, height); interpolate expects (height, width).
    prediction = torch.nn.functional.interpolate(
        predicted_depth.unsqueeze(1),
        size=pil_image.size[::-1],
        mode="bicubic",
        align_corners=False,
    ).squeeze().cpu().numpy()
    depth_min = prediction.min()
    depth_range = prediction.max() - depth_min
    if depth_range <= 0:
        return np.zeros_like(prediction, dtype=np.float32)
    return (prediction - depth_min) / depth_range


def _layered_depth_blur(image_np, depth_norm, strength, num_blur_levels=5):
    """Blur ``image_np`` per pixel, choosing a blur level from ``depth_norm``.

    Precomputes ``num_blur_levels`` copies of the image, with sigma growing
    linearly with the level (level 0 is unblurred), and picks each output pixel
    from the copy selected by that pixel's normalized depth value.
    """
    blurred_layers = []
    for level in range(num_blur_levels):
        sigma = level * (strength / 5)
        if sigma == 0:
            blurred_layers.append(image_np)
        else:
            # ksize=(0, 0) lets OpenCV derive the kernel size from sigma; a fixed
            # 15x15 kernel truncates larger sigmas, capping the visible blur.
            blurred_layers.append(
                cv2.GaussianBlur(
                    image_np,
                    (0, 0),
                    sigmaX=sigma,
                    sigmaY=sigma,
                    borderType=cv2.BORDER_REPLICATE,
                )
            )
    layer_stack = np.stack(blurred_layers)  # (levels, H, W[, C])
    # The highest normalized value maps to level 0 (sharp), the lowest to the
    # strongest blur. NOTE(review): this assumes the depth model outputs larger
    # values for nearer pixels — confirm for the checkpoint in use.
    depth_indices = ((1 - depth_norm) * (num_blur_levels - 1)).astype(np.intp)
    rows, cols = np.indices(depth_indices.shape)
    # Vectorized per-pixel gather replacing a Python loop over every pixel.
    return layer_stack[depth_indices, rows, cols]


def apply_depth_based_blur_background(image, mask, strength):
    """Applies lens blur to the background of the image based on depth estimation.

    Args:
        image: PIL image to process.
        mask: PIL foreground mask; non-zero pixels keep the original image.
        strength: Scales the per-level blur sigma (``level * strength / 5``).

    Returns:
        The composited PIL image, or an ``"Error: ..."`` string when the depth
        model is not loaded.
    """
    if depth_processor is None or depth_model is None:
        return "Error: Depth Anything model not loaded."
    # Depth estimation and layered blur run at a fixed 512x512 working size;
    # the blurred result is resized back to the input size afterwards.
    resized_image = image.resize((512, 512))
    image_np = np.array(resized_image)
    depth_norm = _estimate_normalized_depth(resized_image)
    blurred_resized = _layered_depth_blur(image_np, depth_norm, strength)

    final_blurred_array = np.asarray(
        Image.fromarray(blurred_resized.astype(np.uint8)).resize(image.size)
    )
    original_array = np.asarray(image)
    # NEAREST keeps the mask binary if it actually needs resizing.
    mask_array = np.asarray(mask.resize(image.size, Image.NEAREST)) > 0
    if original_array.ndim == 3:
        mask_array = mask_array[..., np.newaxis]  # broadcast over channels
    # Original foreground over the depth-blurred background.
    final_output_array = np.where(mask_array, original_array, final_blurred_array)
    return Image.fromarray(final_output_array.astype(np.uint8))
def segment_and_blur(input_image, blur_type, gaussian_radius=15, lens_strength=5):
    """Segments the input image and applies the selected blur.

    Uses OneFormer semantic segmentation to build a mask of the ``'person'``
    class, keeps those pixels sharp and blurs the rest.

    Args:
        input_image: PIL image or NumPy image array. Gradio's ``gr.Image``
            delivers a NumPy array unless ``type="pil"`` is set.
        blur_type: ``"Gaussian"`` or ``"Lens"``.
        gaussian_radius: Blur radius for the ``"Gaussian"`` mode.
        lens_strength: Blur strength for the depth-based ``"Lens"`` mode.

    Returns:
        The processed PIL image, or an ``"Error: ..."`` string on failure.
    """
    if oneformer_processor is None or oneformer_model is None:
        return "Error: OneFormer model not loaded."
    if input_image is None:
        return "Error: No input image provided."
    # Reject an unknown mode before running the expensive segmentation model.
    if blur_type not in ("Gaussian", "Lens"):
        return "Error: Invalid blur type selected."
    if isinstance(input_image, np.ndarray):
        input_image = Image.fromarray(input_image)
    image = input_image.convert("RGB")
    # NOTE(review): every input is rotated 90 degrees clockwise and the output
    # stays rotated; the original author was unsure this is still needed
    # (possibly an EXIF-orientation workaround) — confirm.
    image = image.rotate(-90, expand=True)

    inputs = oneformer_processor(images=image, task_inputs=["semantic"], return_tensors="pt")
    with torch.no_grad():
        outputs = oneformer_model(**inputs)
    # target_sizes takes (height, width); PIL's size is (width, height).
    predicted_semantic_map = oneformer_processor.post_process_semantic_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0]
    segmentation_mask = predicted_semantic_map.cpu().numpy()

    foreground_label = "person"
    foreground_class_id = next(
        (
            class_id
            for class_id, label in oneformer_model.config.id2label.items()
            if label == foreground_label
        ),
        None,
    )
    if foreground_class_id is None:
        return f"Error: Could not find the label '{foreground_label}' in the model's class mapping."

    # 255 (white) for foreground pixels, 0 (black) for everything else.
    output_mask_array = np.zeros(segmentation_mask.shape, dtype=np.uint8)
    output_mask_array[segmentation_mask == foreground_class_id] = 255

    if blur_type == "Gaussian":
        # The Gaussian helper thresholds the mask with `> 0`, which needs an array.
        return apply_gaussian_blur_background(image, output_mask_array, gaussian_radius)
    # Lens: the depth helper calls mask.resize(), so it needs a PIL image.
    # A 2-D uint8 array is inferred as mode "L".
    mask_pil = Image.fromarray(output_mask_array)
    return apply_depth_based_blur_background(image, mask_pil, lens_strength)
# Gradio UI: the four inputs map positionally onto segment_and_blur's parameters.
iface = gr.Interface(
    fn=segment_and_blur,
    inputs=[
        # type="pil" hands the callback a PIL image (Gradio's default is a NumPy
        # array, which has no .convert()).
        gr.Image(type="pil", label="Input Image"),
        gr.Radio(["Gaussian", "Lens"], label="Blur Type", value="Gaussian"),
        # `value=` sets the initial slider position; `default=` is not a
        # gr.Slider argument in current Gradio.
        gr.Slider(0, 30, step=1, value=15, label="Gaussian Blur Radius"),
        gr.Slider(0, 10, step=1, value=5, label="Lens Blur Strength"),
    ],
    outputs=gr.Image(label="Output Image"),
    title="Image Background Blur App",
    description="Upload an image, select a blur type (Gaussian or Lens), and adjust the blur parameters to blur the background while keeping the person in focus."
)

if __name__ == "__main__":
    iface.launch()