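"""Gradio demo: background blur via semantic segmentation and depth-based lens blur.

SegFormer (fine-tuned on ADE20K) separates the foreground from the background
for a Zoom-style background blur, and DPT (MiDaS hybrid) predicts per-pixel
depth that drives a variable-strength lens blur.
"""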
import gradio as gr
import torch
import numpy as np
from PIL import Image, ImageFilter, ImageOps
import cv2
from transformers import (
    SegformerFeatureExtractor, SegformerForSemanticSegmentation,
    DPTFeatureExtractor, DPTForDepthEstimation
)
# Load models
seg_model_name = "nvidia/segformer-b1-finetuned-ade-512-512"
depth_model_name = "Intel/dpt-hybrid-midas"
seg_extractor = SegformerFeatureExtractor.from_pretrained(seg_model_name)
seg_model = SegformerForSemanticSegmentation.from_pretrained(seg_model_name)
depth_extractor = DPTFeatureExtractor.from_pretrained(depth_model_name)
depth_model = DPTForDepthEstimation.from_pretrained(depth_model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
seg_model.to(device)
depth_model.to(device)
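# from_pretrained() already returns the models in eval mode;
# inference inside process_image runs under torch.no_grad().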
def process_image(image_pil):
    image = ImageOps.exif_transpose(image_pil).resize((512, 512)).convert("RGB")
    # ---- Segmentation ----
    seg_inputs = seg_extractor(images=image, return_tensors="pt", do_resize=True, do_normalize=True)
    with torch.no_grad():
        seg_output = seg_model(**seg_inputs.to(device)).logits
    # SegFormer returns logits at 1/4 of the input resolution; upsample them to the
    # image size so the mask matches the 512x512 image when compositing below.
    seg_output = torch.nn.functional.interpolate(
        seg_output, size=image.size[::-1], mode="bilinear", align_corners=False
    )
    seg_mask = torch.argmax(seg_output, dim=1)[0].cpu().numpy()
    # Treat every non-zero ADE20K class as foreground.
    binary_mask = np.where(seg_mask > 0, 255, 0).astype(np.uint8)
    foreground_mask = Image.fromarray(binary_mask).convert("L")
    # ---- Blur Background ----
    image_rgba = image.convert("RGBA")
    blurred = image.filter(ImageFilter.GaussianBlur(15)).convert("RGBA")
    # Keep the sharp image where the mask is 255 (foreground) and the blurred
    # copy where it is 0 (background), mimicking a video-call background blur.
    composite_blur = Image.composite(image_rgba, blurred, foreground_mask)
    # ---- Depth ----
    image_np = np.array(image)
    depth_inputs = depth_extractor(images=image_np, return_tensors="pt")
    with torch.no_grad():
        depth_output = depth_model(**depth_inputs.to(device))
    predicted_depth = depth_output.predicted_depth.squeeze().cpu().numpy()
    # Rescale the raw prediction to [0, 1] for use as a blur weight.
    normalized_depth = (predicted_depth - predicted_depth.min()) / (predicted_depth.max() - predicted_depth.min())
    # ---- Depth-Based Blur ----
    # Build a small stack of progressively blurrier copies of the image and, per
    # pixel, pick one according to depth: a cheap approximation of lens blur.
    image_np = np.array(image).astype(np.float32)
    resized_depth = cv2.resize(normalized_depth, (image_np.shape[1], image_np.shape[0]))
    # MiDaS-style models predict relative inverse depth (larger = nearer), so
    # inverting the normalized map assigns the strongest blur to distant regions.
    inverted_depth = 1.0 - resized_depth
    blur_levels = 4
    blurred_variants = []
    for i in range(blur_levels):
        sigma = i * 3
        blurred = cv2.GaussianBlur(image_np, (15, 15), sigmaX=sigma, sigmaY=sigma) if sigma > 0 else image_np.copy()
        blurred_variants.append(blurred)
    blur_indices = (inverted_depth * (blur_levels - 1)).astype(np.uint8)
    final_blur = np.zeros_like(image_np)
    for i in range(blur_levels):
        mask = (blur_indices == i)
        for c in range(3):
            final_blur[:, :, c][mask] = blurred_variants[i][:, :, c][mask]
    lens_blur_pil = Image.fromarray(np.clip(final_blur, 0, 255).astype(np.uint8))
    return image, composite_blur.convert("RGB"), lens_blur_pil
# Gradio Interface
gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(label="Original Image"),
        gr.Image(label="Segmented Gaussian Blur"),
        gr.Image(label="Depth-Based Lens Blur")
    ],
    title="Visual Effects Demo: Segmentation & Depth-Based Blur",
    description="Upload an image to see it segmented with background blur (like Zoom) and depth-based lens blur.",
    examples=[],
).launch()
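# Run locally with `python app.py`; Gradio serves the demo on http://localhost:7860 by default.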