# Spaces: Sleeping
import gradio as gr | |
from PIL import Image, ImageFilter | |
import torch | |
from transformers import DepthProImageProcessorFast, DepthProForDepthEstimation | |
import numpy as np | |
# Pick the compute device: CUDA when torch reports it available, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# The image processor and the depth-estimation model are loaded from the same
# pretrained checkpoint; the model is then moved onto the selected device.
_CHECKPOINT = "apple/DepthPro-hf"
image_processor = DepthProImageProcessorFast.from_pretrained(_CHECKPOINT)
model = DepthProForDepthEstimation.from_pretrained(_CHECKPOINT)
model = model.to(device)
def _normalize_depth(depth_np: "np.ndarray") -> "np.ndarray":
    """Linearly rescale a depth map to the range [0, 1].

    A constant map (or one whose value range is not finite) carries no usable
    depth ordering, so it is mapped to all zeros instead of dividing by zero
    and propagating NaNs into the integer blur map.
    """
    lowest = float(depth_np.min())
    span = float(depth_np.max()) - lowest
    if not np.isfinite(span) or span <= 0.0:
        return np.zeros_like(depth_np, dtype=np.float32)
    return (depth_np - lowest) / span


def _composite_variable_blur(
    image: "Image.Image",
    depth_normalized: "np.ndarray",
    blur_strength: int,
) -> "Image.Image":
    """Blend Gaussian-blurred versions of *image*, one radius per depth band.

    Each pixel is assigned the integer radius ``int(depth * blur_strength)``.
    Pixels assigned radius 0 keep their original value; every other band is
    filled from a copy of the image blurred with that band's radius.
    """
    blur_map = (depth_normalized * blur_strength).astype(int)
    blurred_image = image.copy()
    for radius in range(1, blur_strength + 1):
        band = blur_map == radius
        if not band.any():
            # Skip radii no pixel maps to; blurring the whole image is costly.
            continue
        # 8-bit mask: 255 selects the blurred pixel, 0 keeps the current one.
        band_mask = Image.fromarray(band.astype(np.uint8) * 255)
        # ``filter`` returns a new image, so the source needs no defensive copy.
        band_image = image.filter(ImageFilter.GaussianBlur(radius))
        blurred_image = Image.composite(band_image, blurred_image, band_mask)
    return blurred_image


# Function to apply background blur based on depth
def apply_background_blur(
    image: "Image.Image | None",
    blur_strength: int = 20,
) -> "Image.Image | None":
    """Blur an image with a strength that grows with its predicted depth.

    The image is run through the module-level DepthPro ``model``; the predicted
    depth is normalized to [0, 1] and scaled to an integer Gaussian-blur radius
    per pixel, so larger predicted depth values receive a stronger blur.

    Args:
        image: The uploaded PIL image, or ``None`` when no image is supplied
            (the UI callback may fire with an empty input, e.g. after the
            image is cleared).
        blur_strength: Maximum Gaussian blur radius, applied to the pixels with
            the largest predicted depth. Values below 1 leave the image
            unblurred.

    Returns:
        A new RGB PIL image with the depth-dependent blur applied, or ``None``
        when ``image`` is ``None``.
    """
    if image is None:
        return None

    # Convert the uploaded image to RGB if necessary
    image = image.convert("RGB")

    # Process the image with DepthPro model
    inputs = image_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    post_processed_output = image_processor.post_process_depth_estimation(
        outputs, target_sizes=[(image.height, image.width)],
    )

    # Get the predicted depth and normalize it
    depth = post_processed_output[0]["predicted_depth"]
    depth_np = depth.detach().cpu().numpy().squeeze()
    depth_normalized = _normalize_depth(depth_np)

    # Apply variable Gaussian blur based on depth
    return _composite_variable_blur(image, depth_normalized, blur_strength)
# Create Gradio interface
def create_interface():
    """Build the image-in / image-out Gradio interface and launch it.

    The interface wires ``apply_background_blur`` between a PIL image upload
    component and a PIL image output component, with ``live=True``.
    """
    upload_component = gr.Image(type="pil", label="Upload Image")
    result_component = gr.Image(type="pil", label="Blurred Image")
    demo = gr.Interface(
        fn=apply_background_blur,
        inputs=upload_component,
        outputs=result_component,
        live=True,
    )
    demo.launch()
# Launch the app only when this file is executed as a script (not on import).
if __name__ == "__main__":
    create_interface()