|
|
|
|
|
""" |
|
@author: Nikhil Kunjoor |
|
""" |
|
import gradio as gr |
|
import numpy as np |
|
from PIL import Image, ImageFilter |
|
import torch |
|
from torchvision import transforms |
|
from transformers import AutoModelForImageSegmentation, AutoImageProcessor, AutoModelForDepthEstimation |
|
|
|
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Permit faster reduced-precision internal float32 matmul kernels (e.g. TF32 where
# supported), per torch.set_float32_matmul_precision("high").
torch.set_float32_matmul_precision("high")

# Hugging Face Hub checkpoints backing the two blur modes.
RMBG_MODEL_ID = "briaai/RMBG-2.0"
DEPTH_MODEL_ID = "depth-anything/Depth-Anything-V2-Small-hf"

# Foreground segmentation model used by the Gaussian mode; it ships custom code,
# hence trust_remote_code=True. Moved to `device` and put in inference mode.
rmbg_model = (
    AutoModelForImageSegmentation.from_pretrained(RMBG_MODEL_ID, trust_remote_code=True)
    .to(device)
    .eval()
)

# Depth model and its matching preprocessor, used by the lens-blur mode.
depth_processor = AutoImageProcessor.from_pretrained(DEPTH_MODEL_ID)
depth_model = AutoModelForDepthEstimation.from_pretrained(DEPTH_MODEL_ID).to(device)
|
|
|
# Preprocessing for the RMBG model: fixed 1024x1024 input, normalized with the
# standard ImageNet per-channel mean/std. Built once instead of on every call.
_RMBG_INPUT_SIZE = (1024, 1024)
_RMBG_TRANSFORM = transforms.Compose([
    transforms.Resize(_RMBG_INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def run_rmbg(image, threshold=0.5):
    """Compute a binary foreground mask for ``image`` using ``rmbg_model``.

    Args:
        image: PIL image. Non-RGB images (e.g. RGBA or grayscale) are converted
            to RGB first, because the normalization step expects exactly three
            channels.
        threshold: Cut-off compared against the sigmoid foreground probability
            (in [0, 1]); pixels strictly above it are marked as foreground.

    Returns:
        ``np.uint8`` array with the same height/width as ``image``: 1 for
        foreground, 0 for background.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")

    input_images = _RMBG_TRANSFORM(image).unsqueeze(0).to(device)
    with torch.no_grad():
        preds = rmbg_model(input_images)

    # The model returns a sequence of outputs and the last one is used.
    # NOTE(review): assumed to be the final full-resolution logits; confirm
    # against the model card.
    mask_prob = preds[-1].sigmoid().cpu()[0].squeeze()

    # Round-trip through an 8-bit PIL image to resize the mask back to the
    # input resolution, then rescale to [0, 1] for thresholding.
    mask_pil = transforms.ToPILImage()(mask_prob).resize(
        image.size, resample=Image.Resampling.BILINEAR
    )
    mask_np = np.asarray(mask_pil, dtype=np.uint8) / 255.0
    return (mask_np > threshold).astype(np.uint8)
|
|
|
def run_depth_estimation(image, target_size=(512, 512)):
    """Estimate a normalized per-pixel depth map for ``image``.

    Args:
        image: PIL image.
        target_size: ``(width, height)`` the image is resized to before being
            passed to ``depth_processor``.

    Returns:
        Float ``np.ndarray`` of shape ``(height, width)`` matching ``image.size``,
        with values in [0, 1], computed as ``1 - min_max_normalize(prediction)``.
        If the prediction is flat (max == min), an all-ones map is returned
        instead of dividing by zero and producing NaNs.
    """
    image_resized = image.resize(target_size, resample=Image.Resampling.BILINEAR)
    inputs = depth_processor(images=image_resized, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = depth_model(**inputs)
    predicted_depth = outputs.predicted_depth

    # Upsample back to the original resolution. PIL sizes are (width, height),
    # interpolate wants (height, width).
    prediction = torch.nn.functional.interpolate(
        predicted_depth.unsqueeze(1),
        size=image.size[::-1],
        mode="bicubic",
        align_corners=False,
    )
    # Take the single (batch, channel) slice explicitly. squeeze() would also
    # drop a spatial axis for 1-pixel-wide or 1-pixel-tall images.
    depth_map = prediction[0, 0].cpu().numpy()

    # Min-max normalize to [0, 1], guarding the flat case.
    depth_min = depth_map.min()
    depth_range = depth_map.max() - depth_min
    if depth_range > 0:
        depth_map = (depth_map - depth_min) / depth_range
    else:
        depth_map = np.zeros_like(depth_map)

    # NOTE(review): inverted so larger values mean "farther", assuming the model
    # predicts relative inverse depth (larger = closer). Confirm against the
    # model card.
    return 1 - depth_map
|
|
|
def apply_gaussian_blur(image, mask, sigma):
    """Keep the masked region of ``image`` sharp and Gaussian-blur the rest.

    Args:
        image: PIL image to process.
        mask: Array with the same height/width as ``image``. It is scaled by
            255 into an 8-bit mask, so values are presumably in {0, 1} (as
            produced by run_rmbg): 1 keeps the original pixel, 0 takes the
            blurred pixel.
        sigma: Radius passed to ``ImageFilter.GaussianBlur``.

    Returns:
        A new PIL image composited from ``image`` and its blurred copy.
    """
    alpha = Image.fromarray((mask * 255).astype(np.uint8))
    softened = image.filter(ImageFilter.GaussianBlur(radius=sigma))
    # Image.composite picks from the first image where alpha is 255 and from
    # the second where it is 0.
    return Image.composite(image, softened, alpha)
|
|
|
def apply_lens_blur(image, depth_map, max_radius, foreground_percentile, num_levels=10):
    """Simulate lens blur: sharp foreground, progressively blurred background.

    Pixels whose ``depth_map`` value is at or below the
    ``foreground_percentile``-th percentile stay sharp. The range from that
    cut-off up to ``depth_map.max()`` is split into ``num_levels - 1`` equal
    bands. Band ``k`` (1-based, nearest first) is blurred with radius
    ``max_radius * k / (num_levels - 1)``, so the farthest band receives exactly
    ``max_radius``.

    Args:
        image: PIL image to blur.
        depth_map: 2-D array with the same height/width as ``image``. Larger
            values are treated as farther away and blurred more.
        max_radius: Gaussian radius applied to the farthest band. A value of
            0 or less disables blurring.
        foreground_percentile: Percentile (0-100) of ``depth_map`` used as the
            sharp/blurred cut-off.
        num_levels: Number of quantized blur levels, including the unblurred
            level 0. Defaults to 10.

    Returns:
        A new PIL image built from ``np.array(image)``.

    Raises:
        ValueError: If ``num_levels`` is less than 2.
    """
    if num_levels < 2:
        raise ValueError(f"num_levels must be at least 2, got {num_levels}")

    output = np.array(image)
    if max_radius <= 0:
        return Image.fromarray(output)

    foreground_threshold = np.percentile(depth_map, foreground_percentile)
    depth_span = depth_map.max() - foreground_threshold
    steps = num_levels - 1

    for level in range(1, num_levels):
        # Everything beyond this band's near edge gets this level's radius.
        # Farther bands are overwritten with stronger blurs on later
        # iterations, so each band ends up with the radius of its far edge.
        near_edge = foreground_threshold + (level - 1) / steps * depth_span
        mask = depth_map > near_edge
        if not mask.any():
            # Near edges only move farther out, so later levels are empty too.
            break
        blurred = image.filter(ImageFilter.GaussianBlur(radius=max_radius * level / steps))
        output[mask] = np.asarray(blurred)[mask]

    return Image.fromarray(output)
|
|
|
def process_image(image, blur_type, sigma, max_radius, foreground_percentile, mask_threshold):
    """Gradio callback: apply the selected blur to an uploaded image.

    Args:
        image: Uploaded image as a numpy array (from ``gr.Image(type="numpy")``),
            or None when nothing was uploaded.
        blur_type: "Gaussian Blur" blurs around the RMBG foreground mask; any
            other value uses depth-based lens blur.
        sigma: Gaussian blur radius (Gaussian mode only).
        max_radius: Maximum lens-blur radius (lens mode only).
        foreground_percentile: Depth percentile kept sharp (lens mode only).
        mask_threshold: Foreground-probability cut-off (Gaussian mode only).

    Returns:
        ``(output_image, message)``. ``output_image`` is None on failure, and
        ``message`` then holds the error text instead of the debug info.
    """
    if image is None:
        return None, "Please upload an image."

    try:
        pil_image = Image.fromarray(image).convert("RGB")
    except Exception as exc:
        return None, f"Error processing image: {exc}"

    # Shrink (aspect ratio preserved) anything larger than 1024 px on either
    # side to bound inference cost.
    limit_w, limit_h = 1024, 1024
    width, height = pil_image.size
    if width > limit_w or height > limit_h:
        pil_image.thumbnail((limit_w, limit_h), Image.Resampling.LANCZOS)

    use_gaussian = blur_type == "Gaussian Blur"
    try:
        if use_gaussian:
            fg_mask = run_rmbg(pil_image, threshold=mask_threshold)
            result = apply_gaussian_blur(pil_image, fg_mask, sigma)
        else:
            depth = run_depth_estimation(pil_image)
            result = apply_lens_blur(pil_image, depth, max_radius, foreground_percentile)
    except Exception as exc:
        return None, f"Error applying blur: {exc}"

    if use_gaussian:
        details = f"Sigma: {sigma}\nMask Threshold: {mask_threshold}"
    else:
        details = f"Max Radius: {max_radius}\nForeground Percentile: {foreground_percentile}"
    return result, f"Blur Type: {blur_type}\n" + details
|
|
|
with gr.Blocks() as demo:
    gr.Markdown("# Image Blur Effects with Gaussian and Lens Blur")

    with gr.Row():
        image_input = gr.Image(label="Upload Image", type="numpy")
        with gr.Column():
            blur_type = gr.Radio(choices=["Gaussian Blur", "Lens Blur"], label="Blur Type", value="Gaussian Blur")
            # Initial visibility matches the default "Gaussian Blur" selection
            # (the same state update_visibility produces for it). Lens-only
            # sliders start hidden; otherwise all four show until the radio is
            # first toggled.
            sigma = gr.Slider(minimum=0.1, maximum=50, step=0.1, value=15, label="Gaussian Blur Sigma")
            max_radius = gr.Slider(
                minimum=1, maximum=100, step=1, value=15, label="Max Lens Blur Radius", visible=False
            )
            foreground_percentile = gr.Slider(
                minimum=1, maximum=99, step=1, value=30, label="Foreground Percentile", visible=False
            )
            mask_threshold = gr.Slider(minimum=0.1, maximum=0.9, step=0.1, value=0.5, label="Mask Threshold")

    process_button = gr.Button("Apply Blur")

    with gr.Row():
        output_image = gr.Image(label="Output Image")
        debug_info = gr.Textbox(label="Debug Info", lines=4)

    def update_visibility(selected_blur):
        """Show only the sliders relevant to the selected blur type.

        Returns visibility updates for (sigma, max_radius,
        foreground_percentile, mask_threshold), in that order.
        """
        is_gaussian = selected_blur == "Gaussian Blur"
        return (
            gr.update(visible=is_gaussian),
            gr.update(visible=not is_gaussian),
            gr.update(visible=not is_gaussian),
            gr.update(visible=is_gaussian),
        )

    blur_type.change(
        fn=update_visibility,
        inputs=blur_type,
        outputs=[sigma, max_radius, foreground_percentile, mask_threshold],
    )

    # All slider values are passed in; process_image uses only those relevant
    # to the selected blur type.
    process_button.click(
        fn=process_image,
        inputs=[image_input, blur_type, sigma, max_radius, foreground_percentile, mask_threshold],
        outputs=[output_image, debug_info],
    )

demo.launch()