# Spaces: Sleeping
import gradio as gr | |
import torch | |
import numpy as np | |
from PIL import Image, ImageFilter | |
import matplotlib.pyplot as plt | |
from torchvision import transforms | |
from transformers import AutoProcessor, AutoModelForImageSegmentation, AutoModelForDepthEstimation | |
# Loaded lazily by load_segmentation_model() and reused afterwards, so the
# weights are not re-instantiated on every request.
_SEGMENTATION_MODEL = None


def load_segmentation_model():
    """Return the BiRefNet segmentation model, loading it on first use.

    The model is created once with ``from_pretrained`` and the same object is
    returned by every later call.

    Returns:
        The ``AutoModelForImageSegmentation`` instance for
        ``"ZhengPeng7/BiRefNet"``.
    """
    global _SEGMENTATION_MODEL
    if _SEGMENTATION_MODEL is None:
        model_name = "ZhengPeng7/BiRefNet"
        # NOTE(review): trust_remote_code=True lets transformers run model code
        # shipped with the checkpoint repository; keep the repo id pinned to a
        # source you trust.
        _SEGMENTATION_MODEL = AutoModelForImageSegmentation.from_pretrained(
            model_name, trust_remote_code=True
        )
    return _SEGMENTATION_MODEL
# (processor, model) pair, loaded lazily by load_depth_model() and reused
# afterwards so each request does not re-instantiate the checkpoint.
_DEPTH_COMPONENTS = None


def load_depth_model():
    """Return the depth-estimation processor and model, loading them once.

    Returns:
        A ``(processor, model)`` tuple created from
        ``"depth-anything/Depth-Anything-V2-Metric-Indoor-Base-hf"``. The same
        two objects are returned by every call after the first.
    """
    global _DEPTH_COMPONENTS
    if _DEPTH_COMPONENTS is None:
        model_name = "depth-anything/Depth-Anything-V2-Metric-Indoor-Base-hf"
        processor = AutoProcessor.from_pretrained(model_name)
        model = AutoModelForDepthEstimation.from_pretrained(model_name)
        _DEPTH_COMPONENTS = (processor, model)
    processor, model = _DEPTH_COMPONENTS
    return processor, model
def process_segmentation_image(image):
    """Turn a PIL image into a batched input tensor for the segmentation model.

    Args:
        image: PIL image to segment. It is not modified.

    Returns:
        A tuple ``(image, input_tensor)``: the image exactly as passed in, and
        the tensor produced by resizing to 512x512, converting to a tensor,
        normalising, and adding a leading batch dimension.
    """
    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        # ImageNet mean/std, as used in BiRefNet's published inference example;
        # without it the model sees inputs on a different scale.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # Force three channels so grayscale / RGBA / palette inputs match the
    # three-channel normalisation above.
    rgb_image = image.convert("RGB")
    input_tensor = transform(rgb_image).unsqueeze(0)
    return image, input_tensor
def process_depth_image(image, processor):
    """Resize an image to 512x512 and run it through the depth processor.

    Args:
        image: Object with a PIL-style ``resize(size)`` method.
        processor: Callable accepting ``images=`` and ``return_tensors=``.

    Returns:
        A tuple ``(resized_image, model_inputs)`` where ``model_inputs`` is
        whatever ``processor`` returned for the resized image with
        ``return_tensors="pt"``.
    """
    target_size = (512, 512)
    resized = image.resize(target_size)
    model_inputs = processor(images=resized, return_tensors="pt")
    return resized, model_inputs
def segment_image(image, input_tensor, model):
    """Run the segmentation model and return a binary foreground mask.

    Args:
        image: Unused; kept so existing callers keep working.
        input_tensor: Batched input tensor passed straight to ``model``.
        model: Callable returning either a list/tuple of prediction tensors or
            an object with a ``logits`` attribute.

    Returns:
        A ``uint8`` NumPy array in which values are 255 where the sigmoid of
        the prediction exceeds 0.5 and 0 elsewhere. Singleton dimensions are
        squeezed away.
    """
    with torch.no_grad():
        outputs = model(input_tensor)
        if isinstance(outputs, (list, tuple)):
            # When several maps are returned, use the last one: BiRefNet's
            # published usage reads the final prediction via ``[-1]``.
            output_tensor = outputs[-1]
        else:
            output_tensor = outputs.logits
        probabilities = torch.sigmoid(output_tensor).squeeze().cpu().numpy()
    mask = (probabilities > 0.5).astype(np.uint8) * 255
    return mask
def estimate_depth(inputs, model):
    """Run the depth model and return its prediction as a NumPy array.

    Args:
        inputs: Mapping of keyword arguments unpacked into ``model``.
        model: Callable whose result exposes a ``predicted_depth`` tensor.

    Returns:
        ``predicted_depth`` with singleton dimensions squeezed out, moved to
        the CPU and converted to a NumPy array.
    """
    with torch.no_grad():
        prediction = model(**inputs).predicted_depth
        squeezed = prediction.squeeze()
        depth_map = squeezed.cpu().numpy()
    return depth_map
def normalize_depth_map(depth_map):
    """Linearly rescale a depth map so its values span ``[0, 1]``.

    Args:
        depth_map: NumPy array (or array-like) of depth values.

    Returns:
        An array of the same shape where the minimum input maps to 0 and the
        maximum to 1. If every value is identical there is no range to
        stretch, so an all-zero array is returned instead of dividing by zero.
    """
    depth_map = np.asarray(depth_map)
    min_val = np.min(depth_map)
    max_val = np.max(depth_map)
    shifted = depth_map - min_val
    value_range = max_val - min_val
    if value_range == 0:
        # Constant depth: the plain formula would compute 0 / 0 and fill the
        # result with NaN.
        return np.zeros_like(shifted, dtype=float)
    normalized_depth = shifted / value_range
    return normalized_depth
def apply_blur(image, mask):
    """Blur the background of an image while keeping the masked area sharp.

    Args:
        image: PIL image to process.
        mask: Array accepted by ``Image.fromarray``; it is resized to the
            image size with bilinear resampling and used as the composite
            mask.

    Returns:
        The result of ``Image.composite``: the original image where the mask
        is set, and a Gaussian-blurred copy (radius 15) elsewhere.
    """
    background = image.filter(ImageFilter.GaussianBlur(15))
    alpha = Image.fromarray(mask)
    alpha = alpha.resize(image.size, Image.BILINEAR)
    return Image.composite(image, background, alpha)
def apply_depth_based_blur(image, depth_map, max_blur_radius=20, num_levels=8):
    """Blur an image more strongly where the depth map has larger values.

    The image is resized to 512x512. The depth map is normalised to
    ``[0, 1]``, resampled to the same size, and quantised into ``num_levels``
    evenly spaced blur radii between 0 and ``max_blur_radius``. Each pixel is
    then taken from the whole-image Gaussian blur matching its level. This
    needs at most ``num_levels - 1`` filter passes instead of one crop, blur
    and paste per pixel.

    Args:
        image: PIL image to process.
        depth_map: 2-D array of depth values; its resolution does not have to
            match the image.
        max_blur_radius: Gaussian radius applied where normalised depth is 1.
        num_levels: Number of distinct blur strengths (at least 2 is used).

    Returns:
        A new 512x512 PIL image with the depth-dependent blur applied.
    """
    normalized_depth = np.asarray(normalize_depth_map(depth_map), dtype=np.float32)
    # Guard against NaN/inf (e.g. from a constant depth map) so the level
    # lookup below always receives finite values.
    normalized_depth = np.nan_to_num(
        normalized_depth, nan=0.0, posinf=1.0, neginf=0.0
    )

    image = image.resize((512, 512))
    if image.mode not in ("L", "RGB", "RGBA"):
        # Other modes (e.g. palette) cannot be filtered or rebuilt from a
        # plain array reliably.
        image = image.convert("RGB")

    # Resample the depth map to the image grid so depth[y, x] refers to the
    # same location as pixel (x, y).
    depth_image = Image.fromarray(normalized_depth).resize(
        image.size, Image.BILINEAR
    )
    depth = np.clip(np.asarray(depth_image), 0.0, 1.0)

    num_levels = max(int(num_levels), 2)
    level_index = np.rint(depth * (num_levels - 1)).astype(np.intp)
    radii = np.linspace(0.0, float(max_blur_radius), num_levels)

    result = np.array(image)
    for level, radius in enumerate(radii):
        if radius <= 0:
            # Zero radius means the pixel keeps its original value.
            continue
        selected = level_index == level
        if not selected.any():
            continue
        blurred = np.asarray(
            image.filter(ImageFilter.GaussianBlur(float(radius)))
        )
        result[selected] = blurred[selected]
    return Image.fromarray(result)
def process_image_pipeline(image):
    """Produce the three demo outputs for one uploaded image.

    Args:
        image: PIL image supplied by the Gradio input component, or ``None``
            when nothing was uploaded.

    Returns:
        A tuple ``(mask_image, depth_blur_image, background_blur_image)``: the
        segmentation mask as a PIL image, the depth-based blur result, and the
        mask-composited Gaussian blur result.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        # Fail with a readable message instead of an error from deep inside
        # the preprocessing code.
        raise gr.Error("Please upload an image before submitting.")

    segmentation_model = load_segmentation_model()
    depth_processor, depth_model = load_depth_model()

    _, input_tensor = process_segmentation_image(image)
    _, inputs = process_depth_image(image, depth_processor)

    segmentation_mask = segment_image(image, input_tensor, segmentation_model)
    depth_map = estimate_depth(inputs, depth_model)

    blurred_image = apply_depth_based_blur(image, depth_map)
    gaussian_blur_image = apply_blur(image, segmentation_mask)
    return Image.fromarray(segmentation_mask), blurred_image, gaussian_blur_image
# Gradio UI: one PIL image in, three images out (mask, depth-based blur,
# mask-based Gaussian blur), all produced by process_image_pipeline.
iface = gr.Interface(
    fn=process_image_pipeline,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(label=output_label)
        for output_label in (
            "Segmentation Mask",
            "Lens Blur Effect",
            "Gaussian Blur Effect",
        )
    ],
    title="Segmentation and Image Effect Processing",
    description="Upload an image to get segmentation mask, lens blur effect, and Gaussian blur effect.",
)

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    iface.launch(share=True)