EEE-515-HW3Q2 / temp.py
JnanaVenkataSubhash's picture
Create temp.py
ac1b947 verified
raw
history blame contribute delete
2.26 kB
import gradio as gr
from PIL import Image, ImageFilter
import torch
from transformers import DepthProImageProcessorFast, DepthProForDepthEstimation
import numpy as np
# Pick the compute device: CUDA when torch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Both the image processor and the depth model come from the same checkpoint.
_DEPTH_PRO_CHECKPOINT = "apple/DepthPro-hf"
image_processor = DepthProImageProcessorFast.from_pretrained(_DEPTH_PRO_CHECKPOINT)
model = DepthProForDepthEstimation.from_pretrained(_DEPTH_PRO_CHECKPOINT)
model = model.to(device)
# Function to apply background blur based on depth
def apply_background_blur(image: "Image.Image | None", blur_strength: int = 20) -> "Image.Image | None":
    """Blur an image with a Gaussian radius that grows with predicted depth.

    The image is run through the module-level DepthPro ``model``; the
    predicted depth map is min-max normalized to [0, 1] and scaled to an
    integer blur radius per pixel. Pixels with larger predicted depth values
    receive a larger radius; pixels whose radius is 0 are left untouched.

    Args:
        image: Input PIL image, or ``None`` when no image is supplied
            (e.g. the Gradio input was cleared).
        blur_strength: Radius applied to the pixel(s) with the largest
            predicted depth. Values <= 0 disable blurring.

    Returns:
        A new RGB PIL image with depth-dependent blur, or ``None`` if
        ``image`` is ``None``.
    """
    # Nothing to process; avoids AttributeError on ``None.convert``.
    if image is None:
        return None

    # Convert the uploaded image to RGB if necessary.
    image = image.convert("RGB")

    # Run depth estimation without tracking gradients.
    inputs = image_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    post_processed_output = image_processor.post_process_depth_estimation(
        outputs, target_sizes=[(image.height, image.width)],
    )

    depth = post_processed_output[0]["predicted_depth"]
    depth_np = depth.detach().cpu().numpy().squeeze()

    blurred_image = image.copy()

    # Guard the normalization: a constant (or non-finite) depth map would
    # divide by zero and yield NaNs whose integer cast is undefined.
    depth_min = depth_np.min()
    depth_range = depth_np.max() - depth_min
    if blur_strength <= 0 or not np.isfinite(depth_range) or depth_range <= 0:
        return blurred_image

    depth_normalized = (depth_np - depth_min) / depth_range
    # Truncation (not rounding) maps each pixel to an integer radius in
    # [0, blur_strength]; only the maximum-depth pixels reach blur_strength.
    blur_map = (depth_normalized * blur_strength).astype(int)

    # Only visit radii that actually occur; radius 0 keeps the original pixels.
    for radius in np.unique(blur_map):
        radius = int(radius)
        if radius <= 0:
            continue
        # ``filter`` returns a new image, so no explicit copy is needed.
        layer = image.filter(ImageFilter.GaussianBlur(radius))
        mask = Image.fromarray((blur_map == radius).astype(np.uint8) * 255)
        blurred_image = Image.composite(layer, blurred_image, mask)

    return blurred_image
# Create Gradio interface
def create_interface():
    """Build the upload-image / blurred-image Gradio interface and launch it."""
    upload_component = gr.Image(type="pil", label="Upload Image")
    result_component = gr.Image(type="pil", label="Blurred Image")
    # live=True is passed through unchanged from the original configuration.
    demo = gr.Interface(
        fn=apply_background_blur,
        inputs=upload_component,
        outputs=result_component,
        live=True,
    )
    demo.launch()
# Launch the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    create_interface()