EEE-515-HW3Q2 / app.py
JnanaVenkataSubhash's picture
Update app.py
31ca554 verified
raw
history blame contribute delete
4.12 kB
import os
import torch
import gradio as gr
from PIL import Image, ImageFilter
import torchvision.transforms as transforms
from transformers import AutoModelForImageSegmentation, DepthProImageProcessorFast, DepthProForDepthEstimation
import numpy as np
import io
# Load both models once at import time so every request reuses them.
# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

HF_model_name = 'BiRefNet'
# NOTE: trust_remote_code=True executes Python code shipped inside the model
# repository; keep it only for repositories you trust.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    f'zhengpeng7/{HF_model_name}',
    trust_remote_code=True,
)
birefnet = birefnet.to(device)
birefnet.eval()
print('BiRefNet (Segmentation) is ready to use.')

# Processor and model for depth estimation come from the same repository.
_DEPTH_REPO = "apple/DepthPro-hf"
depth_processor = DepthProImageProcessorFast.from_pretrained(_DEPTH_REPO)
depth_model = DepthProForDepthEstimation.from_pretrained(_DEPTH_REPO)
depth_model = depth_model.to(device)
depth_model.eval()
print('DepthPro (Blur) is ready to use.')
# Preprocessing applied to the input before it is fed to BiRefNet in
# segment_image: resize to a fixed 1024x1024, convert to a tensor, then
# normalize each RGB channel with the standard ImageNet mean/std values.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Foreground refinement hook (placeholder).
def refine_foreground(image, mask):
    """Return ``image`` unchanged; foreground refinement is not implemented yet.

    ``mask`` is accepted so a real refinement step can be dropped in later
    without changing callers, but it is currently ignored. The very same
    ``image`` object is handed back (no copy is made here).
    """
    del mask  # unused until refinement logic is implemented
    return image
# Segmentation-based background blur.
def segment_image(image, blur_radius=15):
    """Blur the background of ``image`` while keeping the BiRefNet foreground sharp.

    Args:
        image: PIL image to process. It is converted to RGB first because the
            normalization in ``transform_image`` expects three channels.
        blur_radius: Radius passed to ``ImageFilter.GaussianBlur`` for the
            background. Defaults to 15, the value previously hard-coded.

    Returns:
        A new RGB PIL image the same size as ``image``. It holds the
        (refined) original pixels where the predicted mask is set and the
        Gaussian-blurred image elsewhere, with soft blending in between.
    """
    print("Starting segmentation with background blur...")
    image = image.convert("RGB")

    input_image = transform_image(image).unsqueeze(0).to(device)
    print("Input image tensor shape:", input_image.shape)
    with torch.no_grad():
        # The code uses the last element of the model output (presumably the
        # final-stage prediction); sigmoid maps it into [0, 1].
        pred = birefnet(input_image)[-1].sigmoid().cpu()[0].squeeze()
    print("Prediction tensor shape:", pred.shape)

    # Turn the prediction into a PIL mask and resize it back to the input size.
    mask = transforms.ToPILImage()(pred).resize(image.size)
    print("Mask PIL image size:", mask.size)

    foreground = refine_foreground(image.copy(), mask)

    # Blur the untouched image and paste the sharp foreground over it using
    # the mask. The previous version painted the foreground black before
    # blurring, so the blur smeared black into the surrounding background and
    # left a dark halo around the subject.
    blurred_background = image.filter(ImageFilter.GaussianBlur(blur_radius))
    blurred_background.paste(foreground, mask=mask)
    print("Segmentation with background blur completed.")
    return blurred_background
# Depth-based progressive blur.
def apply_background_blur(image: Image.Image, blur_strength: int = 20):
    """Blur ``image`` progressively according to DepthPro depth estimates.

    The predicted depth is normalized to [0, 1] and bucketed into integer
    blur radii ``0..blur_strength``. Pixels in bucket 0 (the smallest
    predicted depth values, presumably the nearest surfaces) keep their
    original values. Every other bucket takes its pixels from a copy of the
    image blurred with that Gaussian radius.

    Args:
        image: PIL image; converted to RGB before processing.
        blur_strength: Largest Gaussian blur radius, used at the largest
            predicted depth. Defaults to 20, the value previously hard-coded.

    Returns:
        A new RGB PIL image the same size as ``image``.
    """
    image = image.convert("RGB")
    inputs = depth_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = depth_model(**inputs)
    post_processed_output = depth_processor.post_process_depth_estimation(
        outputs, target_sizes=[(image.height, image.width)],
    )
    depth = post_processed_output[0]["predicted_depth"]
    depth_np = depth.detach().cpu().numpy().squeeze()

    # Normalize to [0, 1]. A constant depth map has zero range: the old code
    # divided by zero there, producing NaNs that astype(int) turned into
    # garbage radii. Treat that case as "no depth variation" (no blur).
    depth_min = depth_np.min()
    depth_range = depth_np.max() - depth_min
    if depth_range > 0:
        depth_normalized = (depth_np - depth_min) / depth_range
    else:
        depth_normalized = np.zeros_like(depth_np, dtype=float)

    blur_map = (depth_normalized * blur_strength).astype(int)

    blurred_image = image.copy()
    # Only blur for radii that actually occur. Radius 0 keeps the original
    # pixels, and the bucket masks are disjoint, so order does not matter.
    for radius in np.unique(blur_map):
        radius = int(radius)
        if radius <= 0:
            continue
        layer = image.filter(ImageFilter.GaussianBlur(radius))
        bucket_mask = Image.fromarray(
            np.where(blur_map == radius, 255, 0).astype(np.uint8)
        )
        blurred_image = Image.composite(layer, blurred_image, bucket_mask)
    return blurred_image
# Dispatch an uploaded image to the selected pipeline.
def process_image(image, action):
    """Run the pipeline selected by ``action`` on ``image``.

    Args:
        image: PIL image from the upload widget, or None when nothing was
            uploaded.
        action: One of "Segmentation", "Blur" or "Both".

    Returns:
        - For "Segmentation" or "Blur": a single PIL image.
        - For "Both": a ``(segmented, depth_blurred)`` pair.
        - For any other action: None.
        - When ``image`` is None, nothing is run: "Both" yields
          ``(None, None)`` so callers can still unpack it, and every other
          action yields None.
    """
    if image is None:
        # Gradio's Image input yields None when the user submits without
        # uploading; the old code crashed on None.convert.
        return (None, None) if action == "Both" else None

    image = image.convert("RGB")
    if action == "Segmentation":
        return segment_image(image)
    if action == "Blur":
        return apply_background_blur(image)
    if action == "Both":
        return segment_image(image), apply_background_blur(image)
    return None
# Gradio callback
def gradio_interface(image, action):
    """Map the selected action's result onto the two ``gr.Image`` outputs.

    Returns a pair of values for the two output slots:
    - "Both": the segmentation result goes to slot 1, and slot 2 is made
      visible and filled with the depth-based blur.
    - Any other action: the single result (None for an unknown action) goes
      to slot 1, and slot 2 is hidden.
    - No image uploaded: slot 1 is empty and slot 2 is hidden.
    """
    if image is None:
        return None, gr.update(visible=False)

    result = process_image(image, action)
    if action == "Both":
        segmented, depth_blurred = result
        # Output 2 is declared with visible=False, so it must be switched on
        # explicitly; otherwise the depth-blur result is never shown.
        return segmented, gr.update(value=depth_blurred, visible=True)
    return result, gr.update(visible=False)  # Hide the unused second output.
# Build the UI. The dropdown defaults to a valid action so an untouched
# selector does not silently produce an empty output.
interface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Dropdown(
            ["Segmentation", "Blur", "Both"],
            value="Segmentation",
            label="Select Action",
        ),
    ],
    outputs=[
        gr.Image(label="Output Image 1"),
        # Declared hidden; its content comes from gradio_interface's second
        # return value.
        gr.Image(label="Output Image 2", visible=False),
    ],
    live=False,
)

# Only start the server when run as a script (e.g. `python app.py`), not when
# this module is imported.
if __name__ == "__main__":
    interface.launch()