Spaces:
Running
Running
import base64
import io
import logging
import urllib.request

import gradio as gr
import torch
from PIL import Image

from process import process
# Device detection: record which compute backend torch reports as available.
# NOTE(review): DEVICE is not referenced anywhere else in this script.
if torch.cuda.is_available():
    DEVICE = "GPU"
else:
    DEVICE = "CPU"
def load_image(image, url, *, timeout=10.0):
    """Return the image to process: the uploaded one, or one loaded from *url*.

    Args:
        image: Image from the upload widget, or ``None`` when nothing was
            uploaded. It is returned unchanged when present.
        url: Optional text from the URL box. URLs starting with ``http`` are
            downloaded; ``data:image/...`` URLs have their payload after the
            first comma base64-decoded. Surrounding whitespace is ignored.
        timeout: Seconds to wait on the network before giving up, so an
            unresponsive host cannot block the request indefinitely.

    Returns:
        The uploaded image if present; otherwise a fully decoded PIL image
        built from *url*; or ``None`` if no usable image could be obtained
        (empty/unsupported URL, network error, undecodable data).
    """
    if image is not None:
        return image
    if not url:
        return None
    url = url.strip()
    try:
        if url.startswith("http"):
            # NOTE(review): this fetches a user-supplied URL server-side
            # (SSRF risk); consider restricting hosts if deployed privately.
            with urllib.request.urlopen(url, timeout=timeout) as response:
                raw = response.read()
        elif url.startswith("data:image/"):
            _header, encoded = url.split(",", 1)
            raw = base64.b64decode(encoded)
        else:
            return None
        img = Image.open(io.BytesIO(raw))
        # Image.open only parses the header; decode the pixel data now so a
        # truncated/corrupt payload is reported here as "no image" instead
        # of raising later inside process().
        img.load()
        return img
    except Exception:
        # Deliberate best-effort boundary for the UI: any fetch/decode
        # failure maps to None (the caller shows a friendly message), but
        # keep the traceback in the logs instead of swallowing it.
        logging.getLogger(__name__).exception(
            "Failed to load image from URL %.100s", url
        )
        return None
def remove_background(image):
    """Strip the background from *image* using ``process``.

    Args:
        image: The image to process, or ``None``.

    Returns:
        A ``(mask, image_nbg)`` pair as produced by ``process``, or
        ``(None, None)`` when *image* is ``None``.
    """
    if image is not None:
        mask_img, cutout = process(image)
        return mask_img, cutout
    return None, None
def interface(image, url):
    """Run the full Gradio pipeline: load the input, then remove its background.

    Args:
        image: Uploaded image from the upload widget, or ``None``.
        url: Optional image URL text (``http...`` or ``data:image/...``).

    Returns:
        A ``(mask, image_nbg, status)`` triple. ``status`` is a user-facing
        message; ``mask`` and ``image_nbg`` are ``None`` when no usable
        image was provided.
    """
    image = load_image(image, url)
    if image is None:
        return None, None, "请上传有效图片或输入正确的URL"
    mask, image_nbg = remove_background(image)
    # Test for None explicitly: an image object's truthiness is not a
    # success signal (objects without __bool__/__len__ are always truthy,
    # and a NumPy array would raise "truth value ... is ambiguous").
    status = "处理完成" if mask is not None else "处理失败"
    return mask, image_nbg, status
# Gradio UI
# `interface` returns (mask, image_nbg, status); one output component is
# declared per returned value so the status message is actually displayed
# rather than dropped as an extra, unmapped return value.
demo = gr.Interface(
    fn=interface,
    inputs=[gr.Image(type="pil", label="上传图片"), gr.Textbox(label="或输入图片URL")],
    outputs=[
        gr.Image(type="pil", label="掩码"),
        gr.Image(type="pil", label="去除背景的图片"),
        gr.Textbox(label="状态"),
    ],
    title="AI 抠图 (RMBG 2.0)",
    description="上传图片或提供URL,自动去除背景",
    theme="default",
    flagging_mode="never",
)
demo.queue()
demo.launch()