# GaussLensBlur / app.py
# Last change by Kristyyy: "Update app.py" (commit d609bee, verified)
import numpy as np
import torch
from torchvision.transforms.functional import normalize
import torch.nn.functional as F
from skimage import io
from PIL import Image, ImageFilter
from transformers import AutoModelForImageSegmentation, AutoImageProcessor, AutoModelForDepthEstimation
import streamlit as st
# Load models.
# Streamlit re-executes this entire script on every widget interaction, so the
# loading is wrapped in st.cache_resource: weights are fetched and moved to the
# device once per server process instead of on every rerun (e.g. slider moves).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


@st.cache_resource
def _load_models(device_name: str):
    """Load the RMBG-1.4 segmentation model and the Depth-Anything-V2 model/processor.

    Args:
        device_name: torch device string (e.g. ``"cpu"`` or ``"cuda:0"``) that
            both models are moved to; it is also the cache key.

    Returns:
        ``(segmentation_model, depth_model, depth_processor)``.
    """
    # NOTE(review): trust_remote_code=True executes Python code shipped in the
    # briaai/RMBG-1.4 hub repository; consider pinning a specific revision.
    seg_model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-1.4", trust_remote_code=True)
    dep_model = AutoModelForDepthEstimation.from_pretrained("depth-anything/Depth-Anything-V2-Small-hf")
    dep_processor = AutoImageProcessor.from_pretrained("depth-anything/Depth-Anything-V2-Small-hf")
    # The app only runs inference, so put both models in evaluation mode.
    seg_model.to(device_name).eval()
    dep_model.to(device_name).eval()
    return seg_model, dep_model, dep_processor


segmentation_model, depth_model, depth_processor = _load_models(str(device))
# Preprocessing function for segmentation
def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
    """Convert an image array into the normalized tensor fed to the segmentation model.

    Args:
        im: image array of shape ``(H, W)``, ``(H, W, 1)``, ``(H, W, 3)`` or
            ``(H, W, 4)`` with values in ``[0, 255]``. Grayscale input is
            replicated to three channels and an alpha channel is dropped,
            because the normalization below is defined for exactly three
            channels (a 1- or 4-channel tensor would fail to broadcast).
        model_input_size: ``[height, width]`` the image is resized to.

    Returns:
        Float tensor of shape ``(1, 3, height, width)``: pixel values scaled to
        ``[0, 1]`` and then normalized with mean 0.5 / std 1.0 per channel.
    """
    if im.ndim == 2:
        im = im[:, :, np.newaxis]
    if im.shape[2] == 1:
        im = np.repeat(im, 3, axis=2)
    elif im.shape[2] == 4:
        im = im[:, :, :3]
    # HWC -> CHW, then add a batch axis so F.interpolate sees (N, C, H, W).
    im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = F.interpolate(torch.unsqueeze(im_tensor, 0), size=model_input_size, mode='bilinear')
    image = torch.divide(im_tensor, 255.0)
    image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    return image
# Postprocessing function for segmentation mask
def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
    """Turn a raw segmentation output into an 8-bit mask at the requested size.

    Args:
        result: model output tensor of shape ``(1, C, h, w)``; the batch
            dimension must be 1.
        im_size: target ``(height, width)`` as accepted by ``F.interpolate``.
            Note that PIL's ``Image.size`` is ``(width, height)``, the reverse.

    Returns:
        ``uint8`` array of shape ``(height, width)`` when ``C == 1`` (otherwise
        ``(height, width, C)``), min-max stretched to ``[0, 255]``. A constant
        map has no range to stretch and yields all zeros instead of NaNs.
    """
    result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
    ma = torch.max(result)
    mi = torch.min(result)
    value_range = ma - mi
    if value_range > 0:
        result = (result - mi) / value_range
    else:
        # Avoid 0/0: casting NaN to uint8 is undefined.
        result = torch.zeros_like(result)
    im_array = (result * 255).permute(1, 2, 0).detach().cpu().numpy().astype(np.uint8)
    # Drop only the channel axis; a blanket np.squeeze would also collapse a
    # height or width of 1.
    if im_array.shape[-1] == 1:
        im_array = im_array[:, :, 0]
    return im_array
# Streamlit UI
def _segmentation_mask(image: Image.Image) -> Image.Image:
    """Return an "L"-mode foreground mask (each pixel 0 or 255) for *image*.

    Runs the segmentation model on a 256x256 resize, scales the soft mask back
    to the image's own resolution, then thresholds it at 170/255.
    """
    image_np = np.array(image)
    model_input_size = [256, 256]  # resolution the segmentation model is run at
    model_input = preprocess_image(image_np, model_input_size).to(device)
    with torch.no_grad():
        result = segmentation_model(model_input)
    # F.interpolate expects (height, width) while PIL's Image.size is
    # (width, height); take the size from the array shape so the mask comes
    # out correctly oriented without an extra stretch-and-resize.
    soft_mask = postprocess_image(result[0][0], list(image_np.shape[:2]))
    # Pixels above 170/255 are foreground (kept sharp); the rest is background.
    return Image.fromarray(soft_mask).point(lambda p: 255 if p > 170 else 0, '1').convert("L")


def _blur_levels_from_depth(image: Image.Image, max_blur: int) -> np.ndarray:
    """Return a per-pixel Gaussian blur radius in ``[0, max_blur]`` for *image*.

    The result is a ``uint8`` array of shape ``(height, width)``. Pixels with a
    larger depth-model output get a smaller radius (presumably nearer objects
    for Depth-Anything's relative depth — TODO confirm).
    """
    # The processor returns CPU tensors; move them to wherever the model lives,
    # otherwise inference fails with a device mismatch on CUDA.
    inputs = depth_processor(images=image, return_tensors="pt")
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    with torch.no_grad():
        outputs = depth_model(**inputs)
    depth = outputs.predicted_depth.squeeze().cpu().numpy()
    # Resize the depth map back to the original image size.
    depth = np.array(Image.fromarray(depth).resize(image.size, resample=Image.BICUBIC))

    depth_min, depth_max = np.min(depth), np.max(depth)
    # NOTE(review): these bounds mix raw model units with the normalized [0, 1]
    # range below, so how much of the image ends up sharp depends on the scale
    # of the model's output — confirm this is intended.
    manual_depth_min = depth_min / 10
    manual_depth_max = depth_max
    depth_range = depth_max - depth_min
    if depth_range > 0:
        normalized_depth = (depth - depth_min) / depth_range
    else:
        # Flat depth map: avoid 0/0 (NaN) and treat every pixel alike.
        normalized_depth = np.zeros_like(depth)
    adjusted_depth = np.clip(normalized_depth * (manual_depth_max - manual_depth_min) + manual_depth_min, 0, 1)
    blur_levels = ((1 - adjusted_depth) * max_blur).astype(np.uint8)
    return np.clip(blur_levels, 0, max_blur)


def _lens_blur(image: Image.Image, blur_levels: np.ndarray, max_blur: int) -> Image.Image:
    """Compose *image* so every pixel comes from a copy blurred by its own radius.

    Vectorized replacement for a per-pixel getpixel/putpixel loop: one boolean
    mask per blur level selects the pixels taken from that blurred copy. Only
    levels that actually occur in *blur_levels* are blurred.
    """
    composed = np.empty_like(np.asarray(image))
    for level in range(max_blur + 1):
        selected = blur_levels == level
        if selected.any():
            blurred = np.asarray(image.filter(ImageFilter.GaussianBlur(level)))
            composed[selected] = blurred[selected]
    return Image.fromarray(composed)


st.title("Blur Effects App")
st.markdown("Choose between Gaussian blur (segmentation-based) or depth-based lens blur.")
# File uploader
uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
if uploaded_file:
    # Convert to RGB so model inputs and composites always have 3 channels.
    orig_image = Image.open(uploaded_file).convert("RGB")
    st.image(orig_image, caption="Uploaded Image", use_column_width=True)
    effect_option = st.radio("Choose an effect", ("Gaussian Blur (Segmentation)", "Lens Blur (Depth Map)"))
    if effect_option == "Gaussian Blur (Segmentation)":
        st.subheader("Gaussian Blur (Segmentation)")
        binary_mask = _segmentation_mask(orig_image)
        blur_intensity = st.slider("Blur Intensity (σ)", min_value=1, max_value=30, value=15)
        blurred_background = orig_image.filter(ImageFilter.GaussianBlur(blur_intensity))
        # Keep masked (foreground) pixels sharp; take the rest from the blurred copy.
        result_image = Image.composite(orig_image, blurred_background, binary_mask)
        result_caption = "Blurred Background"
    else:  # "Lens Blur (Depth Map)" — the only other radio option
        st.subheader("Lens Blur (Depth Map)")
        max_blur = st.slider("Max Blur Intensity", min_value=1, max_value=10, value=3)
        blur_levels = _blur_levels_from_depth(orig_image, max_blur)
        result_image = _lens_blur(orig_image, blur_levels, max_blur)
        result_caption = "Lens Blur Image"
    # Display original and result side by side, one column each.
    original_col, result_col = st.columns(2)
    original_col.image(orig_image, caption="Original Image", use_column_width=True)
    result_col.image(result_image, caption=result_caption, use_column_width=True)