import torch
import gradio as gr
from PIL import Image, ImageFilter
import torchvision.transforms as transforms
from transformers import AutoModelForImageSegmentation, DepthProImageProcessorFast, DepthProForDepthEstimation
import numpy as np
# Load Models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
HF_model_name = 'BiRefNet'
birefnet = AutoModelForImageSegmentation.from_pretrained(f'zhengpeng7/{HF_model_name}', trust_remote_code=True).to(device).eval()
print('BiRefNet (Segmentation) is ready to use.')
depth_processor = DepthProImageProcessorFast.from_pretrained("apple/DepthPro-hf")
depth_model = DepthProForDepthEstimation.from_pretrained("apple/DepthPro-hf").to(device).eval()
print('DepthPro (Blur) is ready to use.')
# Combined Image Transform (BiRefNet expects 1024x1024 inputs normalized with ImageNet statistics)
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# Refine Foreground (Placeholder)
def refine_foreground(image, mask):
    return image  # Implement your refinement logic here
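# A minimal illustrative sketch of one possible refinement (hypothetical, not
# part of the original app): keep original pixels where the mask is confident
# and blend slightly blurred color into soft mask edges, which reduces hard
# halos when the foreground is composited over the blurred background.
def refine_foreground_feathered(image, mask, radius=2):
    blurred = image.filter(ImageFilter.GaussianBlur(radius))
    # composite(): image where mask is white, blurred where black, blend between.
    return Image.composite(image, blurred, mask)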
# Segmentation Function
def segment_image(image):
    print("Starting segmentation with background blur...")
    input_image = transform_image(image).unsqueeze(0).to(device)
    print("Input image tensor shape:", input_image.shape)
    with torch.no_grad():
        pred = birefnet(input_image)[-1].sigmoid().cpu()[0].squeeze()
    print("Prediction tensor shape:", pred.shape)
    mask = transforms.ToPILImage()(pred).resize(image.size)
    print("Mask PIL image size:", mask.size)
    image_masked = refine_foreground(image.copy(), mask)
    image_masked.putalpha(mask)
    # Apply Gaussian blur to the background: black out the subject first so its
    # colors do not bleed into the blurred region, then paste it back on top.
    blurred_background = image.copy()
    blurred_background.paste((0, 0, 0), mask=mask)
    blurred_background = blurred_background.filter(ImageFilter.GaussianBlur(15))
    blurred_background.paste(image_masked, mask=mask)
    print("Segmentation with background blur completed.")
    return blurred_background
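# Example usage (illustrative; the file path is a placeholder):
# result = segment_image(Image.open("portrait.jpg").convert("RGB"))
# result.save("portrait_blurred_background.png")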
# Blur Function (Rewritten)
def apply_background_blur(image: Image.Image):
    image = image.convert("RGB")
    inputs = depth_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = depth_model(**inputs)
    post_processed_output = depth_processor.post_process_depth_estimation(
        outputs, target_sizes=[(image.height, image.width)],
    )
    depth = post_processed_output[0]["predicted_depth"]
    depth_np = depth.detach().cpu().numpy().squeeze()
    # Normalize depth to [0, 1]; larger values are farther from the camera.
    # The small epsilon guards against division by zero on flat depth maps.
    depth_normalized = (depth_np - depth_np.min()) / (depth_np.max() - depth_np.min() + 1e-8)
    blurred_image = image.copy()
    blur_strength = 20  # Maximum blur radius; adjust for overall blur strength
    # Quantize depth into integer blur radii so each depth band gets one blur pass.
    blur_map = (depth_normalized * blur_strength).astype(int)
    for radius in range(1, blur_strength + 1):
        band = (blur_map == radius)
        if np.any(band):
            temp_image = image.filter(ImageFilter.GaussianBlur(radius))
            blurred_image = Image.composite(temp_image, blurred_image, Image.fromarray((band * 255).astype(np.uint8)))
    return blurred_image
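# Design note: this runs one full-image Gaussian blur per depth band, i.e. up
# to blur_strength filter passes. For large images, blurring at a handful of
# fixed radii (e.g. 5, 10, 20) and compositing the same way is a cheaper
# approximation of the same depth-dependent blur.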
# Process Image Function
def process_image(image, action):
    image = image.convert("RGB")
    if action == "Segmentation":
        return segment_image(image)
    elif action == "Blur":
        return apply_background_blur(image)
    elif action == "Both":
        return segment_image(image), apply_background_blur(image)
    else:
        return None
# Gradio Interface
def gradio_interface(image, action):
    result = process_image(image, action)
    if action == "Both":
        # Reveal the second output only when two images are produced.
        return result[0], gr.Image(value=result[1], visible=True)
    else:
        return result, gr.Image(visible=False)  # Keep the unused second output hidden.
interface = gr.Interface(
    fn=gradio_interface,
    inputs=[gr.Image(type="pil", label="Upload Image"), gr.Dropdown(["Segmentation", "Blur", "Both"], label="Select Action")],
    outputs=[
        gr.Image(label="Output Image 1"),
        gr.Image(label="Output Image 2", visible=False)
    ],
    live=False,
)
interface.launch()
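# launch() also accepts options such as share=True (temporary public link) or
# server_port=... if the defaults do not suit the deployment.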