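"""Gradio demo for LPOSS open-vocabulary semantic segmentation.

Given an image and a comma-separated list of class names, the app runs
LPOSS (label propagation over MaskCLIP and DINO features) and displays
one soft segmentation mask per class, with optional LPOSS+ pixel-level
refinement.
"""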
import gradio as gr
import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
import torchvision.transforms as T

from models.maskclip import MaskClip
from models.dino import DINO
from lposs import lposs, lposs_plus

device = "cpu"
if torch.cuda.is_available():
    print("Using GPU")
    device = "cuda"
# elif torch.backends.mps.is_available():
#     device = "mps"

print(f"Using device: {device}")

# Instantiate both backbones once at startup and freeze them for inference.
maskclip = MaskClip().to(device).eval()
dino = DINO().to(device).eval()

# Resize the short side to 448 px (long side capped at 2048 px), then convert to a tensor.
to_torch_tensor = T.Compose([T.Resize(size=448, max_size=2048), T.ToTensor()])

def segment_image(
    img: np.ndarray, classnames: str, use_lposs_plus: bool
) -> tuple[np.ndarray, list[tuple[np.ndarray, str]]]:
    """Segment `img` against a comma-separated list of class names.

    Returns the input image together with one soft mask per class, in the
    format expected by gr.AnnotatedImage.
    """
    img_tensor = to_torch_tensor(PIL.Image.fromarray(img)).unsqueeze(0).to(device)
    class_list = [c.strip() for c in classnames.split(",")]

    with torch.no_grad():
        preds = lposs(maskclip, dino, img_tensor, class_list)
        if use_lposs_plus:
            # LPOSS+ additionally refines the predictions at the pixel level.
            preds = lposs_plus(img_tensor, preds)
    # Upsample the predictions back to the input resolution (img is H x W x C).
    preds = F.interpolate(preds, size=img.shape[:-1], mode="bilinear", align_corners=False)
    # A high softmax temperature sharpens the per-pixel class scores into near-one-hot masks.
    preds = F.softmax(preds * 100, dim=1).cpu().numpy()
    return (img, [(preds[0, i, :, :], class_list[i]) for i in range(len(class_list))])

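# Wire up the UI: an image upload, a comma-separated class list, and a
# checkbox toggling LPOSS+ refinement; the output is an AnnotatedImage overlay.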
demo = gr.Interface(
    fn=segment_image,
    inputs=["image", "text", "checkbox"],
    outputs=["annotatedimage"],
)

demo.launch()