# lens_blur / app.py
# Last change by zxxie: "Update app.py" (commit da31e5c, verified)
import gradio as gr
import numpy as np
import random
import cv2
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
from transformers import pipeline
import requests
from PIL import Image
# Load both models once at import time and place them on the module-level
# `device` ("cuda" when torch reports CUDA available, otherwise "cpu").
# Passing a hard-coded device=0 would fail on CPU-only hosts, and leaving the
# depth pipeline without a device would keep it on CPU even when a GPU exists.
pipe = pipeline(
    "image-segmentation",
    model="briaai/RMBG-1.4",
    trust_remote_code=True,
    device=device,
)
pipe_depth = pipeline(
    task="depth-estimation",
    model="depth-anything/Depth-Anything-V2-Base-hf",
    device=device,
)
def lens_blur(
    pil_img,
    *,
    mask_threshold: int = 220,
    blur_ksize: int = 5,
    blur_sigma: float = 30.0,
):
    """Blur the background of an image with depth-dependent strength.

    The foreground subject (from the ``pipe`` segmentation pipeline) is kept
    pixel-for-pixel. Every other pixel is a per-pixel mix of a near-sharp copy
    and a Gaussian-blurred copy of the image. The mix is weighted by the
    ``pipe_depth`` depth map: brighter depth-map pixels take more of the sharp
    copy, darker ones more of the blurred copy.

    Args:
        pil_img: Input PIL image, or ``None`` (e.g. when "Run" is clicked with
            nothing uploaded). It is converted to RGB before processing.
        mask_threshold: Mask values (0-255) strictly above this count as
            foreground.
        blur_ksize: Side length of the Gaussian kernel for the blurred copy.
            OpenCV requires a positive odd number.
        blur_sigma: Gaussian sigma (x and y) for the blurred copy.

    Returns:
        A new RGB PIL image, or ``None`` when ``pil_img`` is ``None``.
    """
    if pil_img is None:
        return None

    # Normalise the mode so image, mask and depth broadcast consistently. A
    # single-channel ("L") array would mis-broadcast against the (H, W, 1)
    # depth weights below.
    pil_img = pil_img.convert("RGB")
    img = np.array(pil_img)

    # The segmentation pipeline returns a 0-255 mask; binarise it.
    mask = np.array(pipe(pil_img, return_mask=True)) > mask_threshold

    # Depth map as blend weights shaped (H, W, 1) so they broadcast over the
    # colour channels. NOTE(review): assumes the pipeline's depth image has
    # the same size as the input; the arithmetic below requires it.
    depth = np.expand_dims(np.array(pipe_depth(pil_img)["depth"]), -1)
    sharp_weight = depth / 255
    blur_weight = (255 - depth) / 255

    # sigma=0.1 on a 5x5 kernel is effectively an identity filter.
    sharp = cv2.GaussianBlur(img, ksize=(5, 5), sigmaX=0.1, sigmaY=0.1)
    blurred = cv2.GaussianBlur(
        img,
        ksize=(blur_ksize, blur_ksize),
        sigmaX=blur_sigma,
        sigmaY=blur_sigma,
    )
    background = sharp_weight * sharp + blur_weight * blurred
    background = np.clip(background, 0, 255).astype(np.uint8)

    # Original pixels on the subject, depth-blended pixels everywhere else.
    composite = np.where(np.expand_dims(mask, -1), img, background)
    return Image.fromarray(composite)
# Stylesheet for the Blocks app: centers the "col-container" column and caps
# its width at 640px.
css = """
#col-container {
margin: 0 auto;
max-width: 640px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # Lens Blur")

        # Input image and the explicit "Run" button share one row.
        with gr.Row():
            prompt = gr.Image(type="pil")
            run_button = gr.Button("Run", scale=0, variant="primary")

        result = gr.Image(label="Result", show_label=False)

    # A single event listener fires on either a button click or a new upload,
    # feeding the input image to lens_blur and showing its return value.
    gr.on(
        triggers=[run_button.click, prompt.upload],
        fn=lens_blur,
        inputs=[prompt],
        outputs=[result],
    )


if __name__ == "__main__":
    demo.launch(share=True)