LiXiang12 committed on
Commit 696f3f1 · verified · 1 Parent(s): 38502ca
Files changed (3)
  1. app-gradio.py +63 -0
  2. process.py +48 -0
  3. requirements.txt +11 -0
app-gradio.py ADDED
import gradio as gr
import torch
import io
import base64
import urllib.request
from PIL import Image
from process import process

# Device detection (informational; process.py selects the device itself)
DEVICE = "GPU" if torch.cuda.is_available() else "CPU"


def load_image(image, url):
    """Load the uploaded image, or fetch one from an http(s) or data URL."""
    if image is not None:
        return image
    elif url:
        try:
            if url.startswith("http"):
                with urllib.request.urlopen(url) as response:
                    image_data = response.read()
                return Image.open(io.BytesIO(image_data))
            elif url.startswith("data:image/"):
                header, base64_data = url.split(",", 1)
                return Image.open(io.BytesIO(base64.b64decode(base64_data)))
        except Exception:
            return None
    return None


def remove_background(image):
    """Remove the background via process.py (RMBG-2.0)."""
    if image is None:
        return None, None
    mask, image_nbg = process(image)
    return mask, image_nbg


def interface(image, url):
    """Full Gradio pipeline: load the input, then strip its background."""
    image = load_image(image, url)
    if image is None:
        return None, None, "Please upload a valid image or enter a valid URL"
    mask, image_nbg = remove_background(image)
    return mask, image_nbg, "Done" if mask is not None else "Processing failed"


# Gradio UI
demo = gr.Interface(
    fn=interface,
    inputs=[
        gr.Image(type="pil", label="Upload an image"),
        gr.Textbox(label="Or enter an image URL"),
    ],
    outputs=[
        gr.Image(type="pil", label="Mask"),
        gr.Image(type="pil", label="Image with background removed"),
        gr.Textbox(label="Status"),  # third output to match interface()'s three return values
    ],
    title="AI Background Removal (RMBG 2.0)",
    description="Upload an image or provide a URL; the background is removed automatically",
    theme="default",
    flagging_mode="never",
)

demo.queue()
demo.launch()
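As written, the script launches with Gradio's defaults (a local server, typically on port 7860). For a hosted deployment such as a Space or a container, launch() accepts the usual Gradio options; a minimal sketch of one possible configuration, not part of this commit:

# Hypothetical alternative to the plain demo.launch() above: bind to all
# interfaces on a fixed port, as is common when running inside a container.
demo.queue(max_size=16)          # cap the request queue (optional)
demo.launch(
    server_name="0.0.0.0",       # listen on all interfaces
    server_port=7860,            # Gradio's conventional port
    show_error=True,             # surface Python errors in the UI
)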
process.py ADDED
import os
import torch
import streamlit as st
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation


# Streamlit's resource cache is used here purely to keep the model loaded
# between calls, even though the UI in this repo is Gradio.
@st.cache_resource
def load_model(model_id_or_path="briaai/RMBG-2.0", precision=0, device="cuda"):
    model = AutoModelForImageSegmentation.from_pretrained(
        model_id_or_path, trust_remote_code=True
    )
    # precision: 0 -> "high", 1 -> "highest" float32 matmul precision
    torch.set_float32_matmul_precision(["high", "highest"][precision])
    model.to(device)
    _ = model.eval()

    # Data settings
    image_size = (1024, 1024)
    transform_image = transforms.Compose(
        [
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )

    return model, transform_image


def process(image: Image.Image) -> tuple[Image.Image, Image.Image]:
    """Return (mask, image with background removed) for a PIL image."""
    # Download the RMBG-2.0 weights from ModelScope on first use.
    if "RMBG-2.0" not in os.listdir("."):
        os.system(
            "modelscope download --model AI-ModelScope/RMBG-2.0 --local_dir RMBG-2.0 --exclude *.onnx *.bin"
        )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    precision = 0
    model, transform = load_model("RMBG-2.0", precision=precision, device=device)
    image = image.convert("RGB")  # the transform and putalpha expect a 3-channel image
    input_images = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        preds = model(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(image.size)
    image.putalpha(mask)

    return mask, image
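For a quick check outside the Gradio UI, process() can be called directly on a PIL image; a minimal sketch, assuming a test file input.jpg in the working directory (the first call downloads the RMBG-2.0 weights via the modelscope CLI):

# Standalone usage sketch for process.py (input.jpg is a hypothetical test file).
from PIL import Image
from process import process

mask, cutout = process(Image.open("input.jpg"))
mask.save("mask.png")      # single-channel matte predicted by RMBG-2.0
cutout.save("cutout.png")  # original image with the matte applied as alpha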
requirements.txt ADDED
torch
torchvision
pillow
kornia
transformers
streamlit
huggingface
timm
modelscope
psutil
gradio