"""Gradio demo: foreground segmentation plus two background-blur effects.

For an uploaded image the app produces:

* a binary foreground mask from the BiRefNet segmentation model,
* a "lens blur" whose strength grows with the estimated depth, and
* a plain Gaussian background blur composited through the mask.
"""
import functools

import gradio as gr
import torch
import numpy as np
from PIL import Image, ImageFilter
import matplotlib.pyplot as plt
from torchvision import transforms
from transformers import (
    AutoProcessor,
    AutoModelForImageSegmentation,
    AutoModelForDepthEstimation,
)

SEGMENTATION_MODEL_NAME = "ZhengPeng7/BiRefNet"
DEPTH_MODEL_NAME = "depth-anything/Depth-Anything-V2-Metric-Indoor-Base-hf"

# Working resolution (width, height) used for both models and the lens blur.
MODEL_INPUT_SIZE = (512, 512)
# Radius of the uniform background blur used by ``apply_blur``.
BACKGROUND_BLUR_RADIUS = 15
# Blur radius applied to the farthest pixels by ``apply_depth_based_blur``.
MAX_BLUR_RADIUS = 20
# Number of pre-blurred layers the lens blur interpolates between.
BLUR_LEVELS = 8
# Channel statistics used in the BiRefNet model card's preprocessing.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


@functools.lru_cache(maxsize=1)
def _cached_segmentation_model():
    """Load the segmentation model once; exceptions propagate and are not cached."""
    print("Loading segmentation model...")
    # NOTE(security): trust_remote_code=True executes Python code downloaded
    # from the model repository. Consider pinning a reviewed `revision`.
    model = AutoModelForImageSegmentation.from_pretrained(
        SEGMENTATION_MODEL_NAME, trust_remote_code=True
    )
    model.to(device)
    model.eval()
    print("Segmentation model loaded successfully.")
    return model


@functools.lru_cache(maxsize=1)
def _cached_depth_model():
    """Load the depth processor and model once; exceptions are not cached."""
    print("Loading depth estimation model...")
    processor = AutoProcessor.from_pretrained(DEPTH_MODEL_NAME)
    model = AutoModelForDepthEstimation.from_pretrained(DEPTH_MODEL_NAME)
    model.to(device)
    model.eval()
    print("Depth estimation model loaded successfully.")
    return processor, model


def load_segmentation_model():
    """Return the segmentation model, or ``None`` if it could not be loaded.

    The model is cached after the first successful load so that it is not
    re-created on every request. A failed load is reported on stdout and
    retried on the next call.
    """
    try:
        return _cached_segmentation_model()
    except Exception as e:  # best-effort: the UI falls back to placeholders
        print(f"Error loading segmentation model: {e}")
        return None


def load_depth_model():
    """Return ``(processor, model)`` for depth estimation.

    Returns ``(None, None)`` if loading fails. Successful loads are cached;
    failures are reported on stdout and retried on the next call.
    """
    try:
        return _cached_depth_model()
    except Exception as e:  # best-effort: the UI falls back to placeholders
        print(f"Error loading depth estimation model: {e}")
        return None, None


def process_segmentation_image(image):
    """Prepare a PIL image for the segmentation model.

    Args:
        image: Input PIL image.

    Returns:
        Tuple ``(image, input_tensor)``: the image exactly as passed in, and
        a batched, normalized tensor on ``device``.
    """
    transform = transforms.Compose([
        transforms.Resize((MODEL_INPUT_SIZE[1], MODEL_INPUT_SIZE[0])),
        transforms.ToTensor(),
        # The BiRefNet model card normalizes inputs with ImageNet statistics.
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ])
    # Force three channels so Normalize also works for RGBA/greyscale input.
    input_tensor = transform(image.convert("RGB")).unsqueeze(0).to(device)
    return image, input_tensor


def process_depth_image(image, processor):
    """Prepare a PIL image for the depth model.

    Args:
        image: Input PIL image.
        processor: Image processor returned by ``load_depth_model``.

    Returns:
        Tuple ``(resized_image, inputs)`` where ``inputs`` is the processor
        output moved to ``device``.
    """
    image = image.convert("RGB").resize(MODEL_INPUT_SIZE)
    inputs = processor(images=image, return_tensors="pt").to(device)
    return image, inputs


def segment_image(image, input_tensor, model):
    """Run segmentation and return a binary ``uint8`` mask (values 0 or 255).

    Args:
        image: Unused; kept for interface compatibility.
        input_tensor: Batched tensor from ``process_segmentation_image``.
        model: Segmentation model.

    Returns:
        2-D ``numpy.uint8`` mask. On any inference error an all-zero mask of
        the working resolution is returned and the error is printed.
    """
    try:
        with torch.no_grad():
            outputs = model(input_tensor)
        if isinstance(outputs, (list, tuple)):
            # The model card reads the final prediction from the last entry.
            output_tensor = outputs[-1]
        else:
            output_tensor = outputs.logits
        mask = torch.sigmoid(output_tensor).squeeze().cpu().numpy()
        return (mask > 0.5).astype(np.uint8) * 255
    except Exception as e:  # best-effort: keep the demo responsive
        print(f"Error during segmentation: {e}")
        return np.zeros((MODEL_INPUT_SIZE[1], MODEL_INPUT_SIZE[0]), dtype=np.uint8)


def estimate_depth(inputs, model):
    """Run depth estimation and return the predicted depth as a NumPy array.

    Args:
        inputs: Processor output from ``process_depth_image``.
        model: Depth estimation model.

    Returns:
        ``outputs.predicted_depth`` squeezed and moved to the CPU. On any
        inference error an all-zero ``float32`` map of the working
        resolution is returned and the error is printed.
    """
    try:
        with torch.no_grad():
            outputs = model(**inputs)
        return outputs.predicted_depth.squeeze().cpu().numpy()
    except Exception as e:  # best-effort: keep the demo responsive
        print(f"Error during depth estimation: {e}")
        return np.zeros((MODEL_INPUT_SIZE[1], MODEL_INPUT_SIZE[0]), dtype=np.float32)


def normalize_depth_map(depth_map):
    """Rescale a depth map linearly to the range ``[0, 1]``.

    Args:
        depth_map: Array-like of depth values.

    Returns:
        Array of the same shape. A constant, empty or non-finite-range map
        yields all zeros instead of dividing by zero.
    """
    depth = np.asarray(depth_map)
    if not np.issubdtype(depth.dtype, np.floating):
        depth = depth.astype(np.float64)
    if depth.size == 0:
        return depth
    min_val = np.min(depth)
    span = np.max(depth) - min_val
    if not np.isfinite(span) or span <= 0:
        # Flat (or NaN/inf) depth carries no ordering information.
        return np.zeros_like(depth)
    return (depth - min_val) / span


def apply_blur(image, mask):
    """Blur the background of ``image`` while keeping the masked foreground.

    Args:
        image: Input PIL image.
        mask: 2-D ``uint8`` array; 255 marks pixels to keep sharp.

    Returns:
        PIL image of the same size as ``image``.
    """
    mask_pil = Image.fromarray(mask).resize(image.size, Image.BILINEAR)
    blurred_background = image.filter(ImageFilter.GaussianBlur(BACKGROUND_BLUR_RADIUS))
    return Image.composite(image, blurred_background, mask_pil)


def apply_depth_based_blur(image, depth_map):
    """Blur ``image`` more strongly where ``depth_map`` is larger.

    A small stack of progressively blurred copies of the image is built and
    each output pixel is taken from the copy matching its normalized depth.
    This replaces a per-pixel crop/filter/paste loop with a handful of
    whole-image filters.

    Args:
        image: Input PIL image.
        depth_map: 2-D array of depth values; it is resized to the working
            resolution, so its shape need not match the image.

    Returns:
        RGB PIL image at the working resolution.
    """
    image = image.convert("RGB").resize(MODEL_INPUT_SIZE)
    normalized_depth = np.asarray(normalize_depth_map(depth_map), dtype=np.float32)
    if normalized_depth.ndim != 2 or normalized_depth.size == 0:
        # No usable depth information: leave the image unblurred.
        return image

    # The depth model's output size is not guaranteed to equal the image
    # size, so align the two before indexing per pixel.
    depth_image = Image.fromarray(normalized_depth).resize(image.size, Image.BILINEAR)
    depth = np.clip(np.asarray(depth_image, dtype=np.float32), 0.0, 1.0)

    radii = np.linspace(0.0, MAX_BLUR_RADIUS, BLUR_LEVELS)
    layers = np.stack([
        np.asarray(image if radius == 0 else image.filter(ImageFilter.GaussianBlur(float(radius))))
        for radius in radii
    ])

    level = np.rint(depth * (len(radii) - 1)).astype(np.intp)
    rows, cols = np.indices(level.shape)
    return Image.fromarray(layers[level, rows, cols])


def process_image_pipeline(image):
    """Gradio callback: compute the mask and both blur effects for ``image``.

    Args:
        image: Uploaded PIL image, or ``None`` if nothing was uploaded.

    Returns:
        Tuple ``(mask_image, lens_blur_image, gaussian_blur_image)``. If a
        model cannot be loaded, a black mask and the unmodified image (twice)
        are returned.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    segmentation_model = load_segmentation_model()
    depth_processor, depth_model = load_depth_model()
    if segmentation_model is None or depth_model is None:
        placeholder = np.zeros((MODEL_INPUT_SIZE[1], MODEL_INPUT_SIZE[0]), dtype=np.uint8)
        return Image.fromarray(placeholder), image, image

    _, input_tensor = process_segmentation_image(image)
    _, inputs = process_depth_image(image, depth_processor)

    segmentation_mask = segment_image(image, input_tensor, segmentation_model)
    depth_map = estimate_depth(inputs, depth_model)

    blurred_image = apply_depth_based_blur(image, depth_map)
    gaussian_blur_image = apply_blur(image, segmentation_mask)
    return Image.fromarray(segmentation_mask), blurred_image, gaussian_blur_image


iface = gr.Interface(
    fn=process_image_pipeline,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(label="Segmentation Mask"),
        gr.Image(label="Lens Blur Effect"),
        gr.Image(label="Gaussian Blur Effect"),
    ],
    title="Segmentation and Depth-Based Image Processing",
    description="Upload an image to get segmentation mask, depth-based blur effect, and Gaussian blur effect.",
)

if __name__ == "__main__":
    iface.launch(share=True)