# borrow some code from https://huggingface.co/spaces/xvjiarui/ODISE/tree/main
import os
import sys
from importlib.util import find_spec
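# One-time setup: fetch the pretrained TCL checkpoint and install any missing
# segmentation / detection dependencies before the heavy imports below.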
print("Prepare demo ...")
if not os.path.exists("tcl.pth"):
    print("Download TCL checkpoint ...")
    os.system("wget -q https://github.com/kakaobrain/tcl/releases/download/v1.0.0/tcl.pth")
if not (find_spec("mmcv") and find_spec("mmseg")):
    print("Install mmcv & mmseg ...")
    os.system("mim install mmcv-full==1.6.2 mmsegmentation==0.27.0")
if not find_spec("detectron2"):
    print("Install detectron2 ...")
    os.system("pip install git+https://github.com/facebookresearch/detectron2.git")
sys.path.insert(0, "./tcl/")
print(" -- done.")
import json
from contextlib import ExitStack
import gradio as gr
import torch
from detectron2.evaluation import inference_context
from predictor import build_demo_model
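# Build the demo model and move it to GPU when one is available.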
model = build_demo_model()
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"device: {device}")
model.to(device)
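# Gradio UI: titles, description, example inputs, and the inference callback.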
title = "TCL: Text-grounded Contrastive Learning"
title2 = "for Unsupervised Open-world Semantic Segmentation"
title = title + "
" + title2
description_head = """
Learning to Generate Text-grounded Mask for Open-world Semantic Segmentation from Only Image-Text Pairs | [Github Repo](https://github.com/kakaobrain/tcl)
""" voc_examples = [ ["examples/voc_59.jpg", "bg; cat; dog"], ["examples/voc_97.jpg", "bg; car"], ["examples/voc_266.jpg", "bg; dog"], ["examples/voc_294.jpg", "bg; bird"], ["examples/voc_864.jpg", "bg; cat"], ["examples/voc_1029.jpg", "bg; bus"], ] examples = [ [ "examples/dogs.jpg", "bg; corgi; shepherd", ], [ "examples/dogs.jpg", "bg; dog", ], [ "examples/dogs.jpg", "corgi; shepherd; lawn, trees, and fallen leaves", ], [ "examples/banana.jpg", "bg; banana", ], [ "examples/banana.jpg", "bg; red banana; green banana; yellow banana", ], [ "examples/frodo_sam_gollum.jpg", "bg; frodo; gollum; samwise", ], [ "examples/frodo_sam_gollum.jpg", "bg; rocks; monster; boys with cape" ], [ "examples/mb_mj.jpg", "bg; marlon brando; michael jackson", ], ] examples = examples + voc_examples def inference(img, query): query = query.split(";") query = [v.strip() for v in query] with ExitStack() as stack: stack.enter_context(inference_context(model)) stack.enter_context(torch.no_grad()) if device.type == "cuda": stack.enter_context(torch.autocast("cuda")) visualized_output = model.forward_vis(img, query) return visualized_output theme = gr.themes.Soft(text_size=gr.themes.sizes.text_md, primary_hue="teal") with gr.Blocks(title=title, theme=theme) as demo: gr.Markdown("