# Hugging Face Spaces page header captured with the source (Space status: Sleeping).
import os | |
import torch | |
import gradio as gr | |
from PIL import Image, ImageFilter | |
import torchvision.transforms as transforms | |
from transformers import AutoModelForImageSegmentation, DepthProImageProcessorFast, DepthProForDepthEstimation | |
import numpy as np | |
import io | |
# Load Models
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Segmentation model; its output is used as the foreground mask in segment_image.
HF_model_name = 'BiRefNet'
birefnet = AutoModelForImageSegmentation.from_pretrained(
    f'zhengpeng7/{HF_model_name}',
    # NOTE(review): trust_remote_code runs Python shipped with the checkpoint
    # repository -- confirm the repository is trusted before deploying.
    trust_remote_code=True,
)
birefnet = birefnet.to(device).eval()
print('BiRefNet (Segmentation) is ready to use.')

# Depth model and its matching image processor, used by apply_background_blur.
depth_processor = DepthProImageProcessorFast.from_pretrained("apple/DepthPro-hf")
depth_model = DepthProForDepthEstimation.from_pretrained("apple/DepthPro-hf")
depth_model = depth_model.to(device).eval()
print('DepthPro (Blur) is ready to use.')
# Combined Image Transform
# Preprocessing applied before an image is fed to the segmentation model:
# resize to a fixed 1024x1024, convert to a tensor, then normalise each channel
# with the given per-channel mean and standard deviation.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Refine Foreground (Placeholder)
def refine_foreground(image, mask):
    """Return the foreground image to composite through ``mask``.

    This is a placeholder hook: ``mask`` is currently ignored and ``image`` is
    handed back unchanged (the very same object, not a copy).
    """
    refined = image  # Implement your refinement logic here
    return refined
# Segmentation Function
def segment_image(image):
    """Blur the background of ``image`` while keeping the segmented subject.

    The last BiRefNet output is passed through a sigmoid and used as a soft
    foreground mask, resized back to the input size. The masked region is
    painted black on a copy of the image, that copy is Gaussian-blurred with
    radius 15, and the original foreground is pasted back through the mask.

    Args:
        image: PIL image. It is only read; all edits happen on copies.
            NOTE(review): the 3-channel normalisation in ``transform_image``
            suggests an RGB image is expected -- confirm at call sites.

    Returns:
        A new PIL image with a blurred background and the foreground on top.
    """
    print("Starting segmentation with background blur...")
    batch = transform_image(image).unsqueeze(0).to(device)
    print("Input image tensor shape:", batch.shape)

    # Inference only: no gradients are needed for the mask prediction.
    with torch.no_grad():
        model_outputs = birefnet(batch)
        pred = model_outputs[-1].sigmoid().cpu()[0].squeeze()
    print("Prediction tensor shape:", pred.shape)

    # The prediction is at model resolution; bring it back to the input size.
    to_pil = transforms.ToPILImage()
    mask = to_pil(pred).resize(image.size)
    print("Mask PIL image size:", mask.size)

    foreground = refine_foreground(image.copy(), mask)
    foreground.putalpha(mask)

    # Apply Gaussian blur to the background
    background = image.copy()
    background.paste((0, 0, 0, 0), mask=mask)
    background = background.filter(ImageFilter.GaussianBlur(15))
    background.paste(foreground, mask=mask)

    print("Segmentation with background blur completed.")
    return background
# Blur Function (Rewritten)
def apply_background_blur(image: Image.Image, blur_strength: int = 20) -> Image.Image:
    """Blur ``image`` progressively according to its estimated depth.

    The depth model predicts a depth map at the image's resolution. The map is
    min-max normalised to [0, 1] and scaled to an integer Gaussian radius per
    pixel in ``0..blur_strength``, so larger predicted depth values receive
    stronger blur. Pixels whose radius is 0 are left sharp.

    Args:
        image: Input PIL image. It is converted to RGB first; the caller's
            object is not modified.
        blur_strength: Largest Gaussian blur radius to apply. Values below 1
            disable blurring.

    Returns:
        A new RGB PIL image. When the depth map is constant or not finite
        (no usable depth ordering), an unblurred copy is returned.
    """
    image = image.convert("RGB")
    inputs = depth_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = depth_model(**inputs)
    post_processed_output = depth_processor.post_process_depth_estimation(
        outputs, target_sizes=[(image.height, image.width)],
    )
    depth = post_processed_output[0]["predicted_depth"]
    depth_np = depth.detach().cpu().numpy().squeeze()

    blurred_image = image.copy()
    depth_min = depth_np.min()
    depth_range = depth_np.max() - depth_min
    # Guard the normalisation: a flat depth map would divide by zero and a
    # non-finite one would turn into garbage radii after the integer cast.
    if blur_strength < 1 or not np.isfinite(depth_range) or depth_range <= 0:
        return blurred_image

    depth_normalized = (depth_np - depth_min) / depth_range
    blur_map = (depth_normalized * blur_strength).astype(int)

    for radius in range(1, blur_strength + 1):
        mask = blur_map == radius
        if not np.any(mask):
            continue
        # filter() already returns a new image, so no explicit copy is needed.
        blurred_layer = image.filter(ImageFilter.GaussianBlur(radius))
        layer_mask = Image.fromarray(mask.astype(np.uint8) * 255)
        # Take this radius' pixels from the blurred layer, keep the rest.
        blurred_image = Image.composite(blurred_layer, blurred_image, layer_mask)
    return blurred_image
# Process Image Function
def process_image(image, action):
    """Run the processing selected by ``action`` on ``image``.

    Args:
        image: PIL image, or ``None`` when no image was supplied.
        action: ``"Segmentation"``, ``"Blur"`` or ``"Both"``.

    Returns:
        * ``"Segmentation"`` / ``"Blur"``: the single processed image.
        * ``"Both"``: a ``(segmented, depth_blurred)`` tuple.
        * any other action: ``None``.
        When ``image`` is ``None`` nothing is processed and the result is
        ``(None, None)`` for ``"Both"`` and ``None`` otherwise, so the shape of
        the return value still matches the action.
    """
    # An empty image input arrives as None; avoid AttributeError on .convert().
    if image is None:
        return (None, None) if action == "Both" else None

    image = image.convert("RGB")
    if action == "Segmentation":
        return segment_image(image)
    if action == "Blur":
        return apply_background_blur(image)
    if action == "Both":
        return segment_image(image), apply_background_blur(image)
    return None
# Gradio Interface
def gradio_interface(image, action):
    """Adapt ``process_image`` to the two image outputs of the Gradio UI.

    Args:
        image: The uploaded image as passed in by Gradio.
        action: The selected action string.

    Returns:
        A 2-tuple ``(first_output, second_output)``. For ``"Both"`` the second
        element is a visible image component carrying the second result; for
        every other action it is a hidden image component.
    """
    result = process_image(image, action)
    if action == "Both":
        first, second = result
        # The second output is declared hidden in the interface, so returning a
        # bare image would update its value but leave it invisible. Return a
        # component update that also switches visibility on.
        return first, gr.Image(value=second, visible=True)
    # Return a hidden image when the second output is not needed.
    return result, gr.Image(visible=False)
# Wire the UI: one image upload plus an action selector in, two image slots out.
interface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Dropdown(["Segmentation", "Blur", "Both"], label="Select Action"),
    ],
    outputs=[
        gr.Image(label="Output Image 1"),
        # The second slot starts hidden.
        gr.Image(label="Output Image 2", visible=False),
    ],
    live=False,
)

# Start the app as soon as the module is executed.
interface.launch()