File size: 3,766 Bytes
06d49db
 
 
 
 
 
 
 
 
11fd0ad
06d49db
 
 
 
 
 
 
 
 
 
 
 
 
 
11fd0ad
06d49db
 
 
 
 
 
 
 
 
 
 
 
 
 
1b2676b
 
 
 
 
 
 
 
 
6514616
 
fd15d48
 
 
 
6514616
 
fd15d48
 
 
 
6514616
 
1b2676b
06d49db
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import gradio as gr
import PIL
import numpy as np
from models.maskclip import MaskClip
from models.dino import DINO
import torchvision.transforms as T
import torch.nn.functional as F
from lposs import lposs, lposs_plus
import torch
import spaces

# Select the compute device once, at import time: CUDA when torch reports it
# as available, otherwise CPU.
# NOTE: an MPS branch (torch.backends.mps.is_available() -> "mps") existed
# only as commented-out code and is not active.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    print("Using GPU")

print(f"Using device: {device}")

# Backbones consumed by lposs(); both are moved to the selected device.
maskclip = MaskClip().to(device)
dino = DINO().to(device)

# Preprocessing applied to the input PIL image: resize (size=448, with the
# result capped by max_size=2048), then convert to a torch tensor.
to_torch_tensor = T.Compose(
    [
        T.Resize(size=448, max_size=2048),
        T.ToTensor(),
    ]
)

@spaces.GPU
def segment_image(img: np.ndarray, classnames: str, use_lposs_plus: bool | None) -> tuple[np.ndarray | PIL.Image.Image | str, list[tuple[np.ndarray | tuple[int, int, int, int], str]]]:
    """Segment ``img`` into the classes listed in ``classnames``.

    Args:
        img: Input image as an array. It is converted with
            ``PIL.Image.fromarray`` and its first two dimensions are used as
            the output (height, width).
        classnames: Comma-separated class names. Surrounding whitespace is
            stripped and empty entries (e.g. from a trailing comma) are
            dropped.
        use_lposs_plus: When truthy, the ``lposs`` predictions are refined
            with ``lposs_plus`` before upsampling.

    Returns:
        A ``(img, annotations)`` tuple where ``annotations`` is a list of
        ``(mask, class_name)`` pairs, one per class, each mask being a
        per-pixel softmax score map at the input image's resolution.

    Raises:
        gr.Error: If no image was provided or no non-empty class name was
            given.
    """
    if img is None:
        raise gr.Error("Please upload an image.")

    # Drop blank entries so inputs like "cat, , dog," do not produce
    # empty-string classes.
    names = [name for name in (part.strip() for part in (classnames or "").split(",")) if name]
    if not names:
        raise gr.Error("Please provide at least one class name.")

    # Use the first two dims so both (H, W, C) and (H, W) arrays work.
    height, width = img.shape[:2]
    img_tensor = to_torch_tensor(PIL.Image.fromarray(img)).unsqueeze(0).to(device)

    preds = lposs(maskclip, dino, img_tensor, names)
    if use_lposs_plus:
        preds = lposs_plus(img_tensor, preds)

    # Upsample the predictions back to the original image resolution.
    preds = F.interpolate(preds, size=(height, width), mode="bilinear", align_corners=False)
    # Scaling by 100 before the softmax sharpens the per-pixel class distribution.
    probs = F.softmax(preds * 100, dim=1).cpu().numpy()
    return (img, [(probs[0, idx, :, :], name) for idx, name in enumerate(names)])

# Footer HTML shown with the demo: icon links to the arXiv paper and the
# GitHub repository.
_ARTICLE_HTML = """<div align='center'>
        <a href='http://arxiv.org/abs/2503.19777' target='_blank' style='margin-right: 15px;'>
            <span style='display: inline-flex; align-items: center;'>
                <svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M14 2H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V8z"></path><polyline points="14 2 14 8 20 8"></polyline><line x1="16" y1="13" x2="8" y2="13"></line><line x1="16" y1="17" x2="8" y2="17"></line><polyline points="10 9 9 9 8 9"></polyline></svg>
                <span style="margin-left: 5px;">arXiv</span>
            </span>
        </a>
        <a href='https://github.com/vladan-stojnic/LPOSS' target='_blank'>
            <span style='display: inline-flex; align-items: center;'>
                <svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="currentColor"><path d="M12 0c-6.626 0-12 5.373-12 12 0 5.302 3.438 9.8 8.207 11.387.599.111.793-.261.793-.577v-2.234c-3.338.726-4.033-1.416-4.033-1.416-.546-1.387-1.333-1.756-1.333-1.756-1.089-.745.083-.729.083-.729 1.205.084 1.839 1.237 1.839 1.237 1.07 1.834 2.807 1.304 3.492.997.107-.775.418-1.305.762-1.604-2.665-.305-5.467-1.334-5.467-5.931 0-1.311.469-2.381 1.236-3.221-.124-.303-.535-1.524.117-3.176 0 0 1.008-.322 3.301 1.23.957-.266 1.983-.399 3.003-.404 1.02.005 2.047.138 3.006.404 2.291-1.552 3.297-1.23 3.297-1.23.653 1.653.242 2.874.118 3.176.77.84 1.235 1.911 1.235 3.221 0 4.609-2.807 5.624-5.479 5.921.43.372.823 1.102.823 2.222v3.293c0 .319.192.694.801.576 4.765-1.589 8.199-6.086 8.199-11.386 0-6.627-5.373-12-12-12z"/></svg>
                <span style="margin-left: 5px;">GitHub</span>
            </span>
        </a>
    </div>"""

# Widgets are created in the same order as the positional parameters of
# segment_image(img, classnames, use_lposs_plus), followed by the output.
_image_input = gr.Image(label="Input Image")
_classnames_input = gr.Textbox(label="Class Names", info="Separate class names with commas")
_lposs_plus_toggle = gr.Checkbox(label="Use LPOSS+", info="Enable pixel-level refinement using LPOSS+")
_segmentation_output = gr.AnnotatedImage(label="Segmentation Results")

# Single-function UI wiring the three inputs to segment_image and rendering
# its (image, annotations) result.
demo = gr.Interface(
    fn=segment_image,
    inputs=[_image_input, _classnames_input, _lposs_plus_toggle],
    outputs=[_segmentation_output],
    title="LPOSS: Label Propagation Over Patches and Pixels for Open-vocabulary Semantic Segmentation",
    description="Upload an image and specify the objects you want to segment by listing their names separated by commas.",
    article=_ARTICLE_HTML,
)

# Launch the app when this module is executed.
demo.launch()