import gradio as gr
import PIL.Image
import numpy as np
from models.maskclip import MaskClip
from models.dino import DINO
import torchvision.transforms as T
import torch.nn.functional as F
from lposs import lposs, lposs_plus
import torch
import spaces

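# Pick the compute device once at import time; fall back to CPU when no GPU is present.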
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

maskclip = MaskClip().to(device)
dino = DINO().to(device)
# Resize the shorter image side to 448 px (longer side capped at 2048 px) before converting to a tensor.
to_torch_tensor = T.Compose([T.Resize(size=448, max_size=2048), T.ToTensor()])


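# On Hugging Face ZeroGPU Spaces, the spaces.GPU decorator attaches a GPU to this call for its duration.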
@spaces.GPU
def segment_image(img: np.ndarray, classnames: str, use_lposs_plus: bool | None) -> tuple[np.ndarray, list[tuple[np.ndarray, str]]]:
    # Gradio passes the uploaded image as a NumPy array; convert it to a batched tensor on the selected device.
    img_tensor = to_torch_tensor(PIL.Image.fromarray(img)).unsqueeze(0).to(device)
    classnames = [c.strip() for c in classnames.split(",")]
    num_classes = len(classnames)

    preds = lposs(maskclip, dino, img_tensor, classnames)
    if use_lposs_plus:
        preds = lposs_plus(img_tensor, preds)
    # Upsample the per-class scores to the input resolution and convert them into soft masks.
    preds = F.interpolate(preds, size=img.shape[:-1], mode="bilinear", align_corners=False)
    preds = F.softmax(preds * 100, dim=1).cpu().numpy()
    return (img, [(preds[0, i, :, :], classnames[i]) for i in range(num_classes)])


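# Wire the model into a simple Gradio interface: an input image, a comma-separated class list, and an LPOSS+ toggle.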
demo = gr.Interface(
    fn=segment_image,
    inputs=[
        gr.Image(label="Input Image"),
        gr.Textbox(label="Class Names", info="Separate class names with commas"),
        gr.Checkbox(label="Use LPOSS+", info="Enable pixel-level refinement using LPOSS+"),
    ],
    outputs=[
        gr.AnnotatedImage(label="Segmentation Results"),
    ],
    title="LPOSS: Label Propagation Over Patches and Pixels for Open-vocabulary Semantic Segmentation",
    description="Upload an image and specify the objects you want to segment by listing their names separated by commas.",
)

demo.launch()