nickkun's picture
Update app.py
8cfd312 verified
raw
history blame
5.75 kB
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@author: Nikhil Kunjoor
"""
import gradio as gr
import numpy as np
from PIL import Image, ImageFilter
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation, AutoImageProcessor, AutoModelForDepthEstimation
# Run inference on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_float32_matmul_precision("high")

# Segmentation model used to build the foreground mask for Gaussian blur.
# trust_remote_code=True executes modeling code shipped with the checkpoint.
rmbg_model = AutoModelForImageSegmentation.from_pretrained(
    "briaai/RMBG-2.0",
    trust_remote_code=True,
)
rmbg_model = rmbg_model.to(device).eval()

# Depth model (and its matching pre-processor) used for the lens blur.
_DEPTH_CHECKPOINT = "depth-anything/Depth-Anything-V2-Small-hf"
depth_processor = AutoImageProcessor.from_pretrained(_DEPTH_CHECKPOINT)
depth_model = AutoModelForDepthEstimation.from_pretrained(_DEPTH_CHECKPOINT).to(device)
def run_rmbg(image, threshold=0.5):
    """Return a binary foreground mask for *image* using the RMBG model.

    Args:
        image: PIL image to segment (its ``size`` sets the mask resolution).
        threshold: Probability above which a pixel counts as foreground.

    Returns:
        A ``uint8`` NumPy array of 0/1 values at the input image's resolution.
    """
    preprocess = transforms.Compose(
        [
            transforms.Resize((1024, 1024)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    batch = preprocess(image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = rmbg_model(batch)

    # The last element of the model output is taken as the mask logits.
    probabilities = outputs[-1].sigmoid().cpu()[0].squeeze()

    # Round-trip through an 8-bit PIL image to resize back to the input size.
    mask_image = transforms.ToPILImage()(probabilities)
    mask_image = mask_image.resize(image.size, resample=Image.BILINEAR)
    mask_fraction = np.array(mask_image, dtype=np.uint8) / 255.0

    return (mask_fraction > threshold).astype(np.uint8)
def run_depth_estimation(image, target_size=(512, 512)):
    """Estimate a normalised, inverted depth map for *image*.

    The image is resized to ``target_size`` for inference, the model's
    ``predicted_depth`` is resampled back to the original resolution, min-max
    normalised to [0, 1] and then inverted (``1 - value``), so pixels with the
    largest raw model output map to 0.

    Args:
        image: PIL image to analyse.
        target_size: ``(width, height)`` the image is resized to before
            being passed to the depth processor.

    Returns:
        A NumPy array at the input image's resolution with values in [0, 1].
        If the prediction has no usable range (constant or non-finite), an
        all-zero map is returned instead of dividing by zero.
    """
    image_resized = image.resize(target_size, resample=Image.BILINEAR)
    inputs = depth_processor(images=image_resized, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = depth_model(**inputs)

    # PIL reports (width, height); interpolate expects (height, width).
    prediction = torch.nn.functional.interpolate(
        outputs.predicted_depth.unsqueeze(1),
        size=image.size[::-1],
        mode="bicubic",
        align_corners=False,
    )
    depth_map = prediction.squeeze().cpu().numpy()

    depth_min = depth_map.min()
    depth_range = depth_map.max() - depth_min
    # `not (> 0)` also catches a NaN range, which `== 0` would let through.
    if not depth_range > 0:
        return np.zeros_like(depth_map)

    depth_map = (depth_map - depth_min) / depth_range
    return 1 - depth_map
def apply_gaussian_blur(image, mask, sigma):
    """Blur everything outside *mask* with a single Gaussian blur.

    Args:
        image: PIL image to process.
        mask: Array of 0/1 values; 1 keeps the original pixel, 0 takes the
            blurred pixel.
        sigma: Radius passed to ``ImageFilter.GaussianBlur``.

    Returns:
        A PIL image compositing the sharp masked region over the blurred image.
    """
    # Scale the 0/1 mask to 0/255 so it can serve as the compositing mask.
    keep_mask = Image.fromarray((mask * 255).astype(np.uint8))
    background = image.filter(ImageFilter.GaussianBlur(radius=sigma))
    return Image.composite(image, background, keep_mask)
def apply_lens_blur(image, depth_map, max_radius, foreground_percentile):
    """Apply a depth-dependent blur that grows with the depth-map value.

    Pixels at or below the ``foreground_percentile`` of ``depth_map`` stay
    sharp. Above that threshold, ten evenly spaced blur radii between 0 and
    ``max_radius`` are applied to progressively higher depth bands, each pass
    overwriting the previous one where its mask is set.

    NOTE: the last level's mask is ``depth_map > depth_map.max()``, which is
    always empty, so the largest radius actually applied is the ninth of the
    ten levels (8/9 of ``max_radius``). That behaviour is preserved here.

    Args:
        image: PIL image to process.
        depth_map: 2-D array at the image's resolution; higher values receive
            more blur.
        max_radius: Upper end of the blur-radius ramp.
        foreground_percentile: Percentile of ``depth_map`` below which pixels
            are left untouched.

    Returns:
        A new PIL image with the layered blur applied.
    """
    output = np.array(image)

    # A non-positive radius means "no blur"; also avoids a 0/0 threshold below.
    if max_radius <= 0:
        return Image.fromarray(output)

    foreground_threshold = np.percentile(depth_map.flatten(), foreground_percentile)
    # Loop-invariant: depth range between the threshold and the farthest pixel.
    depth_span = depth_map.max() - foreground_threshold

    for radius in np.linspace(0, max_radius, 10):
        mask = depth_map > foreground_threshold + radius / max_radius * depth_span
        # Skip passes that cannot change the output: a zero radius leaves the
        # pixels as they are, and an empty mask selects nothing, so the
        # (expensive) full-image blur would be discarded.
        if radius == 0 or not mask.any():
            continue
        blurred = image.filter(ImageFilter.GaussianBlur(radius=radius))
        output[mask] = np.array(blurred)[mask]

    return Image.fromarray(output)
def process_image(image, blur_type, sigma, max_radius, foreground_percentile, mask_threshold):
    """Run the selected blur on an uploaded image and describe the settings.

    Args:
        image: Image array from the UI, or ``None`` when nothing was uploaded.
        blur_type: ``"Gaussian Blur"`` for the mask-based blur; any other
            value selects the depth-based lens blur.
        sigma: Blur radius for the Gaussian blur.
        max_radius: Maximum radius for the lens blur.
        foreground_percentile: Depth percentile kept sharp by the lens blur.
        mask_threshold: Foreground probability threshold for the mask.

    Returns:
        ``(output_image, message)``. On failure ``output_image`` is ``None``
        and ``message`` holds the error text; on success ``message`` lists the
        parameters that were used.
    """
    if image is None:
        return None, "Please upload an image."

    try:
        image = Image.fromarray(image).convert("RGB")
    except Exception as e:
        return None, f"Error processing image: {str(e)}"

    # Cap the working resolution so inference and blurring stay fast.
    max_size = (1024, 1024)
    width, height = image.size
    if width > max_size[0] or height > max_size[1]:
        image.thumbnail(max_size, Image.Resampling.LANCZOS)

    use_gaussian = blur_type == "Gaussian Blur"

    try:
        if use_gaussian:
            foreground_mask = run_rmbg(image, threshold=mask_threshold)
            output_image = apply_gaussian_blur(image, foreground_mask, sigma)
        else:  # Lens Blur
            depth = run_depth_estimation(image)
            output_image = apply_lens_blur(image, depth, max_radius, foreground_percentile)
    except Exception as e:
        return None, f"Error applying blur: {str(e)}"

    # Summarise only the parameters relevant to the chosen blur.
    if use_gaussian:
        details = f"Sigma: {sigma}\nMask Threshold: {mask_threshold}"
    else:
        details = f"Max Radius: {max_radius}\nForeground Percentile: {foreground_percentile}"

    return output_image, f"Blur Type: {blur_type}\n" + details
# Gradio UI: one input image, a blur selector with its parameters, and the
# processed image plus a short text summary as outputs.
# NOTE(review): the row/column nesting below is a conventional reconstruction;
# confirm it matches the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("# Image Blur Effects with Gaussian and Lens Blur")

    with gr.Row():
        image_input = gr.Image(label="Upload Image", type="numpy")
        with gr.Column():
            blur_type = gr.Radio(
                choices=["Gaussian Blur", "Lens Blur"],
                label="Blur Type",
                value="Gaussian Blur",
            )
            sigma = gr.Slider(minimum=0.1, maximum=50, step=0.1, value=15, label="Gaussian Blur Sigma")
            # The lens-blur sliders start hidden so the initial layout matches
            # the default "Gaussian Blur" selection; previously all four
            # sliders were shown until the radio was changed for the first time.
            max_radius = gr.Slider(
                minimum=1, maximum=100, step=1, value=15,
                label="Max Lens Blur Radius", visible=False,
            )
            foreground_percentile = gr.Slider(
                minimum=1, maximum=99, step=1, value=30,
                label="Foreground Percentile", visible=False,
            )
            mask_threshold = gr.Slider(minimum=0.1, maximum=0.9, step=0.1, value=0.5, label="Mask Threshold")

    process_button = gr.Button("Apply Blur")

    with gr.Row():
        output_image = gr.Image(label="Output Image")
        debug_info = gr.Textbox(label="Debug Info", lines=4)

    def update_visibility(blur_type):
        """Show only the sliders that apply to the selected blur type.

        Returns updates in the order
        ``(sigma, max_radius, foreground_percentile, mask_threshold)``.
        """
        is_gaussian = blur_type == "Gaussian Blur"
        return (
            gr.update(visible=is_gaussian),
            gr.update(visible=not is_gaussian),
            gr.update(visible=not is_gaussian),
            gr.update(visible=is_gaussian),
        )

    blur_type.change(
        fn=update_visibility,
        inputs=blur_type,
        outputs=[sigma, max_radius, foreground_percentile, mask_threshold],
    )
    process_button.click(
        fn=process_image,
        inputs=[image_input, blur_type, sigma, max_radius, foreground_percentile, mask_threshold],
        outputs=[output_image, debug_info],
    )

demo.launch()