File size: 4,889 Bytes
06d49db 11fd0ad 06d49db 6fe4585 be73ac3 6fe4585 be73ac3 6fe4585 11fd0ad 6fe4585 be73ac3 06d49db 6fe4585 06d49db 6fe4585 06d49db be73ac3 06d49db 6fe4585 4370bbd 6514616 4370bbd 6514616 6fe4585 38737c2 6fe4585 38737c2 6fe4585 be73ac3 121a9f3 be73ac3 6fe4585 38737c2 6fe4585 38737c2 6fe4585 be73ac3 6fe4585 be73ac3 6fe4585 06d49db |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
import gradio as gr
import PIL
import numpy as np
from models.maskclip import MaskClip
from models.dino import DINO
import torchvision.transforms as T
import torch.nn.functional as F
from lposs import lposs, lposs_plus
import torch
import spaces
# Choose the compute device once at import time: CUDA when torch reports it as
# available, otherwise the CPU.
if torch.cuda.is_available():
    print("Using GPU")
    device = "cuda"
else:
    # Apple-silicon MPS is not selected here; a check on
    # torch.backends.mps.is_available() could be added to enable it.
    device = "cpu"
print(f"Using device: {device}")
# Input preprocessing. torchvision's Resize(size=448, max_size=2048) scales the
# shorter edge to 448 unless that would push the longer edge past 2048, in which
# case the longer edge is capped at 2048. ToTensor then converts the PIL image to
# a float tensor.
to_torch_tensor = T.Compose(
    [
        T.Resize(size=448, max_size=2048),
        T.ToTensor(),
    ]
)

# Feature extractors used by lposs, built once at import time and moved to `device`.
maskclip = MaskClip().to(device)
dino = DINO().to(device)
# Default hyper-parameter values. They seed the UI sliders and are what the
# "Reset to Default Values" button restores.
DEFAULT_WSIZE: int = 224     # Window Size (square sliding window side)
DEFAULT_K: int = 400         # k: LPOSS number of graph neighbors
DEFAULT_GAMMA: float = 3.0   # γ: LPOSS graph edge tuning
DEFAULT_ALPHA: float = 0.95  # α: amount of propagation
DEFAULT_SIGMA: int = 100     # σ: LPOSS spatial affinity tuning (lposs receives 1/σ)
DEFAULT_TAU: float = 0.01    # τ: LPOSS+ appearance affinity tuning
DEFAULT_R: int = 13          # r: LPOSS+ kernel size
def reset_hyperparams():
    """Return the default hyper-parameter values for the reset button.

    The tuple is ordered (window size, k, γ, α, σ, τ, r), matching the list of
    slider components the reset button writes its outputs to.
    """
    defaults = (
        DEFAULT_WSIZE,
        DEFAULT_K,
        DEFAULT_GAMMA,
        DEFAULT_ALPHA,
        DEFAULT_SIGMA,
        DEFAULT_TAU,
        DEFAULT_R,
    )
    return defaults
@spaces.GPU
def segment_image(img: np.ndarray, classnames: str, use_lposs_plus: bool | None,
                  winodw_size: int, k: int, gamma: float, alpha: float, sigma: float, tau: float, r: int) -> tuple[np.ndarray | PIL.Image.Image | str, list[tuple[np.ndarray | tuple[int, int, int, int], str]]]:
    """Segment ``img`` into the comma-separated ``classnames`` with LPOSS, optionally refined by LPOSS+.

    Args:
        img: Input image as a NumPy array (it is passed to ``PIL.Image.fromarray``
            and its ``shape[:-1]`` is used as the output size, so a trailing
            channel axis is assumed). ``None`` when nothing was uploaded.
        classnames: Comma-separated class names. Surrounding whitespace is
            trimmed and blank entries (e.g. from a trailing comma) are ignored.
        use_lposs_plus: When truthy, refine the LPOSS predictions with ``lposs_plus``.
        winodw_size: Side of the square sliding window. The stride is half of it,
            so neighbouring windows overlap by half. The misspelled name is kept
            for keyword callers.
        k: Passed to lposs as ``lp_k_image`` (number of graph neighbors).
        gamma: Passed to lposs as ``lp_gamma``.
        alpha: Passed to lposs as ``lp_alpha`` and to lposs_plus as ``alpha``.
        sigma: Passed to lposs as its reciprocal, ``sigma=1 / sigma``.
        tau: Passed to lposs_plus as ``tau``.
        r: Passed to lposs_plus as ``r`` (kernel size).

    Returns:
        ``(img, annotations)`` for ``gr.AnnotatedImage``. ``annotations`` holds one
        ``(mask, class_name)`` pair per class, where ``mask`` is that class's channel
        of the per-pixel softmax over classes at the input image's resolution.

    Raises:
        gr.Error: If no image was provided or no non-blank class name was given.
    """
    # Fail with a readable UI message instead of an exception deep inside PIL/lposs.
    if img is None:
        raise gr.Error("Please upload an input image.")
    names = [name.strip() for name in (classnames or "").split(",") if name.strip()]
    if not names:
        raise gr.Error("Please enter at least one class name (separate multiple names with commas).")

    # Annotated as ints; coerce in case the UI delivers them as floats (e.g. 224.0).
    side = int(winodw_size)
    window = (side, side)
    stride = (side // 2, side // 2)

    img_tensor = to_torch_tensor(PIL.Image.fromarray(img)).unsqueeze(0).to(device)
    preds = lposs(maskclip, dino, img_tensor, names, window_size=window, window_stride=stride,
                  sigma=1 / sigma, lp_k_image=int(k), lp_gamma=gamma, lp_alpha=alpha)
    if use_lposs_plus:
        preds = lposs_plus(img_tensor, preds, tau=tau, alpha=alpha, r=int(r))

    # Upsample back to the original (H, W). F.interpolate with a 2-element size
    # expects a 4-D (N, C, h, w) tensor; img.shape[:-1] drops the channel axis.
    preds = F.interpolate(preds, size=img.shape[:-1], mode="bilinear", align_corners=False)
    # Scaling by 100 before the softmax sharpens the per-pixel class distribution.
    probs = F.softmax(preds * 100, dim=1).cpu().numpy()
    return img, [(probs[0, i], name) for i, name in enumerate(names)]
with gr.Blocks() as demo:
    # Header: title plus links to the paper and the code repository.
    gr.Markdown("# LPOSS: Label Propagation Over Patches and Pixels for Open-vocabulary Semantic Segmentation")
    gr.Markdown("""<div align='center' style='margin: 1em 0;'>
<a href='http://arxiv.org/abs/2503.19777' target='_blank' style='margin-right: 2em; text-decoration: none; font-weight: bold;'>
📄 arXiv
</a>
<a href='https://github.com/vladan-stojnic/LPOSS' target='_blank' style='text-decoration: none; font-weight: bold;'>
💻 GitHub
</a>
</div>""")
    gr.Markdown("Upload an image and specify the objects you want to segment by listing their names separated by commas.")

    # Hyper-parameter sliders, collapsed by default and seeded from the DEFAULT_* constants.
    with gr.Accordion("Hyper-parameters", open=False):
        with gr.Column(scale=1):
            with gr.Row():
                window_size = gr.Slider(minimum=112, maximum=448, value=DEFAULT_WSIZE, step=16, label="Window Size")
                k = gr.Slider(minimum=50, maximum=800, value=DEFAULT_K, step=50, label="k (LPOSS number of graph neighbors)")
                gamma = gr.Slider(minimum=0.0, maximum=10.0, value=DEFAULT_GAMMA, step=0.5, label="γ (LPOSS graph edge tuning)")
                sigma = gr.Slider(minimum=50, maximum=400, value=DEFAULT_SIGMA, step=10, label="σ (LPOSS spatial affinity tuning)")
                tau = gr.Slider(minimum=0.0, maximum=1.0, value=DEFAULT_TAU, step=0.01, label="τ (LPOSS+ appearance affinity tuning)")
                r = gr.Slider(minimum=3, maximum=15, value=DEFAULT_R, step=2, label="r (LPOSS+ kernel size)")
                alpha = gr.Slider(minimum=0.0, maximum=1.0, value=DEFAULT_ALPHA, step=0.05, label="α (amount of propagation)")
            with gr.Row():
                reset_btn = gr.Button("Reset to Default Values")

    with gr.Row():
        class_names = gr.Textbox(label="Class Names", info="Separate class names with commas")
        use_lposs_plus = gr.Checkbox(label="Use LPOSS+", info="Enable pixel-level refinement using LPOSS+")
    with gr.Row():
        segment_btn = gr.Button("Segment Image")
    with gr.Row():
        with gr.Column(scale=2):
            input_image = gr.Image(label="Input Image")
        with gr.Column(scale=3):
            output_image = gr.AnnotatedImage(label="Segmentation Results")

    # Single source of truth for the slider order. It must match both the tuple
    # returned by reset_hyperparams and the trailing parameters of segment_image
    # (window size, k, γ, α, σ, τ, r).
    hyperparam_sliders = [window_size, k, gamma, alpha, sigma, tau, r]
    reset_btn.click(fn=reset_hyperparams, outputs=hyperparam_sliders)
    segment_btn.click(
        fn=segment_image,
        inputs=[input_image, class_names, use_lposs_plus, *hyperparam_sliders],
        outputs=[output_image],
    )

demo.launch()