File size: 4,598 Bytes
c3e2660
 
 
 
 
 
 
 
cdf8663
c3e2660
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cdf8663
c3e2660
 
 
 
 
d610685
cdf8663
 
 
 
 
 
 
 
 
 
 
 
c3e2660
 
 
 
 
 
 
 
 
 
 
d610685
 
 
 
 
 
c3e2660
 
 
 
cdf8663
c3e2660
 
 
 
 
d610685
cdf8663
 
 
 
 
 
 
 
 
c3e2660
 
 
4bc0050
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c3e2660
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import gradio as gr
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
import os
import uuid
from PIL import Image

# Allow reduced-precision (e.g. TF32) float32 matmuls where the hardware supports
# it: faster inference at slightly lower precision. Use "highest" instead to force
# full float32 precision.
torch.set_float32_matmul_precision("high")

# BiRefNet segmentation model from the Hugging Face Hub. trust_remote_code=True
# runs the modeling code published in that repository, so only use a trusted repo.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
# Moved to the GPU once at import time; `process` also sends its inputs to "cuda".
birefnet.to("cuda")

# Preprocessing for the model: resize to a fixed 1024x1024, convert to a
# float tensor, then normalize with the standard ImageNet channel mean/std.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

def fn(image):
    """Remove the background from ``image`` and flatten the result onto white.

    Args:
        image: Any value ``load_img`` accepts. In this file it is the value of a
            ``gr.Image`` component or the URL string typed into a ``gr.Textbox``.

    Returns:
        A ``(PIL.Image.Image, str)`` tuple: the RGB image composited on a white
        background, and the path of the JPEG file written with those same pixels.

    Raises:
        gr.Error: If no image or URL was provided.
    """
    # Gradio passes None for an empty image component and "" for an empty textbox.
    # Check explicitly rather than `if not image`: numpy arrays have no truth value.
    if image is None or (isinstance(image, str) and not image.strip()):
        raise gr.Error("Please provide an image or an image URL.")

    im = load_img(image, output_type="pil").convert("RGB")
    # Segmented image: the input with the predicted mask as its alpha channel.
    processed_image = process(im)

    # JPEG cannot store alpha, so flatten onto white using the alpha as paste mask.
    if processed_image.mode == "RGBA":
        result = Image.new("RGB", processed_image.size, (255, 255, 255))
        result.paste(processed_image, mask=processed_image.getchannel("A"))
    else:
        result = processed_image.convert("RGB")

    # Short random suffix so concurrent requests don't overwrite each other's output.
    # NOTE(review): files are written to the working directory and never deleted here.
    output_path = f"output_{str(uuid.uuid4())[:8]}.jpg"
    result.save(output_path, format="JPEG")
    # Return the saved image so the on-screen preview matches the download.
    return result, output_path

@spaces.GPU
def process(image):
    """Predict a foreground mask for ``image`` and attach it as the alpha channel.

    Args:
        image: A PIL image. Callers in this file pass RGB images; any other mode
            is converted to RGB first, because ``transform_image`` normalizes
            exactly three channels.

    Returns:
        A new RGBA image: the RGB input with the predicted mask, resized back to
        the input's size, as its alpha channel. The caller's image is not modified.
    """
    # convert() always returns a new image, even when the mode is already RGB,
    # so the caller's image is never mutated by putalpha below.
    result = image.convert("RGB")
    input_images = transform_image(result).unsqueeze(0).to("cuda")
    # Prediction: take the last model output and squash it to [0, 1] with sigmoid.
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    # assumes preds is (batch, 1, H, W) so this yields an (H, W) mask — TODO confirm
    pred = preds[0].squeeze()
    # Mask tensor -> PIL image, scaled back to the original image resolution.
    mask = transforms.ToPILImage()(pred).resize(result.size)
    result.putalpha(mask)
    return result

def process_file(f):
    """Remove the background from the image at ``f`` and return the JPEG's path.

    This runs exactly the same pipeline as ``fn``: segment the image, flatten it
    onto white, and save it as JPEG. Only the file path is returned, for the
    "File Output" tab's download component.

    Args:
        f: Image file path from ``gr.Image(type="filepath")``, or anything else
            ``load_img`` accepts.

    Returns:
        Path of the written JPEG file.

    Raises:
        gr.Error: If no image was uploaded.
    """
    # gr.Image(type="filepath") yields None when nothing has been uploaded.
    if f is None or (isinstance(f, str) and not f.strip()):
        raise gr.Error("Please upload an image.")
    # Delegate to fn so all tabs share a single compositing/saving implementation.
    _, output_path = fn(f)
    return output_path

# Using a single Blocks API instead of TabbedInterface to avoid compatibility issues.
# Layout follows statement order: each `with` opens a container, and components
# created inside it are placed there in the order they are created.
with gr.Blocks(title="Background Removal Tool") as demo:
    with gr.Tabs():
        # Tab 1: upload an image; show the white-background result and offer the JPEG.
        with gr.Tab("Image Upload"):
            with gr.Row():
                # Default gr.Image `type`; fn passes the value straight to load_img.
                image_upload = gr.Image(label="Upload an image")
            with gr.Row():
                submit_btn = gr.Button("Process Image")
            with gr.Row():
                output_image = gr.Image(label="Processed Image")
                output_file = gr.File(label="Download Processed Image")
            
            # fn returns (image, path), matching the two output components in order.
            submit_btn.click(fn=fn, inputs=image_upload, outputs=[output_image, output_file])
        
        # Tab 2: same pipeline, but the input is the textbox string (an image URL),
        # which fn hands to load_img.
        with gr.Tab("URL Input"):
            with gr.Row():
                url_input = gr.Textbox(label="Paste an image URL")
            with gr.Row():
                submit_url_btn = gr.Button("Process URL")
            with gr.Row():
                output_image_url = gr.Image(label="Processed Image")
                output_file_url = gr.File(label="Download Processed Image")
            
            submit_url_btn.click(fn=fn, inputs=url_input, outputs=[output_image_url, output_file_url])
        
        # Tab 3: download-only flow; type="filepath" makes the component pass a
        # file path, and process_file returns only the JPEG path.
        with gr.Tab("File Output"):
            with gr.Row():
                image_file_upload = gr.Image(label="Upload an image", type="filepath")
            with gr.Row():
                submit_file_btn = gr.Button("Process and Download")
            with gr.Row():
                output_file_path = gr.File(label="Download JPEG File")
            
            submit_file_btn.click(fn=process_file, inputs=image_file_upload, outputs=output_file_path)

# Launch only when run as a script; show_error=True surfaces handler exceptions in the UI.
if __name__ == "__main__":
    demo.launch(show_error=True)