import gradio as gr
import torch
import numpy as np
from PIL import Image, ImageFilter
from torchvision import transforms
from transformers import AutoProcessor, AutoModelForImageSegmentation, AutoModelForDepthEstimation


def load_segmentation_model():
    """Load BiRefNet for foreground/background segmentation; returns None on failure."""
    try:
        print("Loading segmentation model...")
        model_name = "ZhengPeng7/BiRefNet"
        # trust_remote_code=True runs the custom modeling code shipped with the repo.
        model = AutoModelForImageSegmentation.from_pretrained(model_name, trust_remote_code=True)
        model.to(device)
        model.eval()  # inference mode: disable dropout, freeze batch-norm statistics
        print("Segmentation model loaded successfully.")
        return model
    except Exception as e:
        print(f"Error loading segmentation model: {e}")
        return None


def load_depth_model():
    """Load Depth Anything V2 (metric, indoor) and its processor; returns (None, None) on failure."""
    try:
        print("Loading depth estimation model...")
        model_name = "depth-anything/Depth-Anything-V2-Metric-Indoor-Base-hf"
        processor = AutoProcessor.from_pretrained(model_name)
        model = AutoModelForDepthEstimation.from_pretrained(model_name)
        model.to(device)
        model.eval()  # inference mode
        print("Depth estimation model loaded successfully.")
        return processor, model
    except Exception as e:
        print(f"Error loading depth estimation model: {e}")
        return None, None


def process_segmentation_image(image):
    """Resize and normalize a PIL image into a batched tensor for BiRefNet."""
    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        # BiRefNet's reference pipeline normalizes with ImageNet statistics.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    input_tensor = transform(image.convert("RGB")).unsqueeze(0).to(device)
    return image, input_tensor


def process_depth_image(image, processor):
    """Resize a PIL image and run the model's processor to build depth-model inputs."""
    image = image.convert("RGB").resize((512, 512))
    inputs = processor(images=image, return_tensors="pt").to(device)
    return image, inputs

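# Note: the Hugging Face processor applies its own resizing (Depth Anything V2
# resizes to a multiple of 14, typically 518), so the predicted depth map is not
# guaranteed to be 512x512; apply_depth_based_blur resizes it to match the image.
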
def segment_image(image, input_tensor, model):
    """Run BiRefNet and return a binary foreground mask (uint8, 255 = foreground)."""
    try:
        with torch.no_grad():
            outputs = model(input_tensor)
        # BiRefNet returns a list of side outputs; the last entry is the final prediction.
        output_tensor = outputs[-1] if isinstance(outputs, list) else outputs.logits
        mask = torch.sigmoid(output_tensor).squeeze().cpu().numpy()
        mask = (mask > 0.5).astype(np.uint8) * 255
        return mask
    except Exception as e:
        print(f"Error during segmentation: {e}")
        return np.zeros((512, 512), dtype=np.uint8)

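# The 0.5 threshold yields a hard binary cutout. A lower threshold keeps more of
# the soft edges the model predicts (hair, fur) at the risk of haloing when the
# mask is used for compositing.
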
def estimate_depth(inputs, model):
    """Run the depth model and return the raw predicted depth map as a NumPy array."""
    try:
        with torch.no_grad():
            outputs = model(**inputs)
        depth_map = outputs.predicted_depth.squeeze().cpu().numpy()
        return depth_map
    except Exception as e:
        print(f"Error during depth estimation: {e}")
        return np.zeros((512, 512), dtype=np.float32)


def normalize_depth_map(depth_map):
    """Scale a depth map to [0, 1]; a constant map maps to all zeros."""
    min_val = np.min(depth_map)
    max_val = np.max(depth_map)
    if max_val - min_val < 1e-8:
        # Guard against division by zero on a flat depth map.
        return np.zeros_like(depth_map)
    return (depth_map - min_val) / (max_val - min_val)

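# Worked example: normalize_depth_map([[1, 3], [5, 7]]) has min 1 and max 7,
# so it returns [[0.0, 0.333...], [0.666..., 1.0]].
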
def apply_blur(image, mask):
    """Blur the whole background while keeping the masked foreground sharp."""
    mask_pil = Image.fromarray(mask).resize(image.size, Image.BILINEAR)
    blurred_background = image.filter(ImageFilter.GaussianBlur(15))
    # Composite takes `image` pixels where the mask is white and blurred pixels elsewhere.
    final_image = Image.composite(image, blurred_background, mask_pil)
    return final_image

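# Minimal standalone usage sketch for apply_blur (hypothetical file names):
#   img = Image.open("photo.jpg").convert("RGB")
#   mask = np.zeros((img.height, img.width), dtype=np.uint8)
#   mask[100:300, 100:300] = 255  # keep this rectangle sharp
#   apply_blur(img, mask).save("portrait.jpg")
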
def apply_depth_based_blur(image, depth_map, num_levels=5, max_radius=20):
    """Approximate a lens blur by compositing a few discrete blur strengths.

    Pixels are grouped into depth bands; farther bands get a stronger Gaussian
    blur. This replaces a per-pixel crop-and-blur loop, which needed one filter
    call per pixel and produced overlapping paste artifacts.
    """
    image = image.convert("RGB").resize((512, 512))
    normalized_depth = normalize_depth_map(depth_map)
    # The model's depth map rarely matches the image resolution exactly, so
    # resize it to the image size before indexing.
    depth_img = Image.fromarray((normalized_depth * 255).astype(np.uint8)).resize(image.size, Image.BILINEAR)
    normalized_depth = np.asarray(depth_img) / 255.0
    result = image.copy()
    for level in range(1, num_levels + 1):
        lower, upper = (level - 1) / num_levels, level / num_levels
        band = (normalized_depth >= lower) & (normalized_depth <= upper)
        mask = Image.fromarray(band.astype(np.uint8) * 255, mode="L")
        radius = max_radius * (level - 0.5) / num_levels  # blur at the band's center depth
        result = Image.composite(image.filter(ImageFilter.GaussianBlur(radius)), result, mask)
    return result

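# Design note: num_levels trades smoothness for speed. Five bands give visible
# but acceptable blur steps with one Gaussian filter pass per band; more bands
# approach a continuous lens blur at linearly higher cost.
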
def process_image_pipeline(image):
    """Produce the segmentation mask, depth-based lens blur, and Gaussian blur outputs."""
    if segmentation_model is None or depth_model is None:
        # A model failed to load; return an empty mask and the unmodified image.
        return Image.fromarray(np.zeros((512, 512), dtype=np.uint8)), image, image

    _, input_tensor = process_segmentation_image(image)
    _, inputs = process_depth_image(image, depth_processor)

    segmentation_mask = segment_image(image, input_tensor, segmentation_model)
    depth_map = estimate_depth(inputs, depth_model)
    blurred_image = apply_depth_based_blur(image, depth_map)
    gaussian_blur_image = apply_blur(image, segmentation_mask)

    return Image.fromarray(segmentation_mask), blurred_image, gaussian_blur_image

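# The pipeline is also callable outside Gradio (hypothetical input path):
#   mask, lens_blur, gaussian_blur = process_image_pipeline(Image.open("photo.jpg"))
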
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load both models once at startup rather than on every request.
segmentation_model = load_segmentation_model()
depth_processor, depth_model = load_depth_model()


iface = gr.Interface(
    fn=process_image_pipeline,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(label="Segmentation Mask"),
        gr.Image(label="Lens Blur Effect"),
        gr.Image(label="Gaussian Blur Effect"),
    ],
    title="Segmentation and Depth-Based Image Processing",
    description="Upload an image to get a segmentation mask, a depth-based lens blur effect, and a Gaussian background blur effect.",
)


if __name__ == "__main__":
    # share=True asks Gradio to create a temporary public link in addition to the local server.
    iface.launch(share=True)