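"""Gradio demo for open-vocabulary semantic segmentation with LPOSS.

The user uploads an image and a comma-separated list of class names; the app
predicts a soft mask per class (using the MaskCLIP and DINO models, with optional
LPOSS+ refinement) and displays the result with gr.AnnotatedImage.
"""
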
import gradio as gr
import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
import torchvision.transforms as T

from models.maskclip import MaskClip
from models.dino import DINO
from lposs import lposs, lposs_plus

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    print("Using GPU")
    device = "cuda"
print(f"Using device: {device}")

maskclip = MaskClip().to(device)
dino = DINO().to(device)

# Resize the shorter image side to 448 px (capping the longer side at 2048 px), then convert to a tensor.
to_torch_tensor = T.Compose([T.Resize(size=448, max_size=2048), T.ToTensor()])


def segment_image(img: np.ndarray, classnames: str, use_lposs_plus: bool) -> tuple[np.ndarray, list[tuple[np.ndarray, str]]]:
    """Segment `img` into the given comma-separated class names.

    Returns the input image plus one soft mask per class, in the
    (image, [(mask, label), ...]) format expected by gr.AnnotatedImage.
    """
    img_tensor = to_torch_tensor(PIL.Image.fromarray(img)).unsqueeze(0).to(device)
    classnames = [c.strip() for c in classnames.split(",")]
    num_classes = len(classnames)

    preds = lposs(maskclip, dino, img_tensor, classnames)
    if use_lposs_plus:
        preds = lposs_plus(img_tensor, preds)

    # Upsample the class scores to the original image resolution and convert them to per-class soft masks.
    preds = F.interpolate(preds, size=img.shape[:-1], mode="bilinear", align_corners=False)
    preds = F.softmax(preds * 100, dim=1).cpu().numpy()
    return (img, [(preds[0, i, :, :], classnames[i]) for i in range(num_classes)])
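
# For a quick check without the UI, segment_image can be called directly; the
# file name and class names below are hypothetical:
#   arr = np.array(PIL.Image.open("example.jpg").convert("RGB"))
#   base_img, masks = segment_image(arr, "cat, dog, background", use_lposs_plus=False)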


demo = gr.Interface(
    fn=segment_image,
    inputs=["image", "text", "checkbox"],
    outputs=["annotatedimage"],
)
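# The component shortcuts resolve to gr.Image (numpy array input), gr.Textbox,
# gr.Checkbox and gr.AnnotatedImage; the latter overlays each returned
# (mask, label) pair on the base image.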

demo.launch()