|
import os |
|
import sys |
|
from importlib.util import find_spec |
|
|
|
print("Prepare demo ...")

# Fetch the pretrained TCL checkpoint once; skip if a previous run already
# downloaded it into the working directory.
if not os.path.exists("tcl.pth"):
    print("Download TCL checkpoint ...")
    # Use stdlib urllib instead of shelling out to `wget`: the binary may not
    # exist on minimal images, and `os.system` would ignore its failure,
    # leaving a missing checkpoint to crash later. urlretrieve raises on
    # HTTP errors instead.
    from urllib.request import urlretrieve

    urlretrieve(
        "https://github.com/kakaobrain/tcl/releases/download/v1.0.0/tcl.pth",
        "tcl.pth",
    )

# Install the mmcv/mmsegmentation stack only when missing, so restarts are fast.
if not (find_spec("mmcv") and find_spec("mmseg")):
    print("Install mmcv & mmseg ...")
    os.system("mim install mmcv-full==1.6.2 mmsegmentation==0.27.0")

# detectron2 is installed from source (no universal wheel for every torch build).
if not find_spec("detectron2"):
    print("Install detectron ...")
    os.system("pip install git+https://github.com/facebookresearch/detectron2.git")

# Make the vendored `tcl` package importable (provides the `predictor` module).
sys.path.insert(0, "./tcl/")

print(" -- done.")
|
|
|
import json |
|
from contextlib import ExitStack |
|
import gradio as gr |
|
import torch |
|
from torch.cuda.amp import autocast |
|
|
|
from detectron2.evaluation import inference_context |
|
|
|
from predictor import build_demo_model |
|
|
|
|
|
# Build the demo model and place it on the best available device.
model = build_demo_model()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device: {device}")
model.to(device)
|
|
|
|
|
# UI copy: page title, header links, body text, and footer article.
title = "TCL: Text-grounded Contrastive Learning"

description_head = """
<p style='text-align: center'> <a href='https://arxiv.org/abs/2212.00785' target='_blank'>Paper</a> | <a href='https://github.com/kakaobrain/tcl' target='_blank'>Code</a> </p>
"""

description_body = """
Gradio Demo for "Learning to Generate Text-grounded Mask for Open-world Semantic Segmentation from Only Image-Text Pairs".

Explore TCL's capability to perform open-world semantic segmentation **without any mask annotations**. Choose from provided examples or upload your own image. Use the query format `bg; class1; class2; ...`, with `;` as the separator, and the `bg` background query being optional (as in the third example).

This demo highlights the strengths and limitations of unsupervised open-world segmentation methods. Although TCL can handle arbitrary concepts, accurately capturing object boundaries without mask annotation remains a challenge.
"""

# Warn visitors about latency when the demo runs without a GPU.
_CPU_NOTE = "\nInference takes about 10 seconds since this demo is running on the free CPU device."
if device.type == "cpu":
    description_body = description_body + _CPU_NOTE

description = description_head + description_body

article = """
<p style='text-align: center'><a href='https://arxiv.org/abs/2212.00785' target='_blank'>Learning to Generate Text-grounded Mask for Open-world Semantic Segmentation from Only Image-Text Pairs</a> | <a href='https://github.com/kakaobrain/tcl' target='_blank'>Github Repo</a></p>
"""
|
|
|
# Pascal-VOC style examples (image path, semicolon-separated query).
voc_examples = [
    ["examples/voc_59.jpg", "bg; cat; dog"],
    ["examples/voc_97.jpg", "bg; car"],
    ["examples/voc_266.jpg", "bg; dog"],
    ["examples/voc_294.jpg", "bg; bird"],
    ["examples/voc_864.jpg", "bg; cat"],
    ["examples/voc_1029.jpg", "bg; bus"],
]

# Free-form examples first, then the VOC set appended at the end.
examples = [
    ["examples/dogs.jpg", "bg; corgi; shepherd"],
    ["examples/dogs.jpg", "bg; dog"],
    ["examples/dogs.jpg", "corgi; shepherd; lawn, trees, and fallen leaves"],
    ["examples/banana.jpg", "bg; banana"],
    ["examples/banana.jpg", "bg; red banana; green banana; yellow banana"],
    ["examples/frodo_sam_gollum.jpg", "bg; frodo; gollum; samwise"],
    ["examples/frodo_sam_gollum.jpg", "bg; rocks; monster; boys with cape"],
    ["examples/mb_mj.jpg", "bg; marlon brando; michael jackson"],
] + voc_examples
|
|
|
|
|
def inference(img, query):
    """Run TCL open-world segmentation on a single image.

    Args:
        img: input image from the gradio widget (PIL image, per the widget's
            ``type="pil"`` — TODO confirm ``model.forward_vis`` accepts it as-is).
        query: ``;``-separated class names, e.g. ``"bg; cat; dog"``.

    Returns:
        The visualization produced by ``model.forward_vis``.
    """
    # Split on ';', trim whitespace, and drop empty entries so queries with a
    # trailing or doubled separator ("bg; cat;") do not yield blank classes.
    classes = [part.strip() for part in query.split(";") if part.strip()]

    # inference_context puts the model in eval mode and restores it on exit;
    # no ExitStack is needed for two statically-known context managers.
    with inference_context(model), torch.no_grad():
        # torch.cuda.amp.autocast is CUDA-only: on CPU it warns and disables
        # itself, so gate it explicitly to keep the logs clean.
        with autocast(enabled=torch.cuda.is_available()):
            visualized_output = model.forward_vis(img, classes)

    return visualized_output
|
|
|
|
|
# ---- Gradio UI wiring. Uses the gradio 3.x API (gr.inputs/gr.outputs and
# .style() are removed in gradio 4+) — keep the pinned version in mind. ----
theme = gr.themes.Soft(text_size=gr.themes.sizes.text_md, primary_hue="teal")
with gr.Blocks(title=title, theme=theme) as demo:
    gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>" + title + "</h1>")
    gr.Markdown(description)
    # Interactive widgets are collected into these lists so the submit/clear
    # handlers and the Examples widget can address them uniformly.
    input_components = []
    output_components = []

    with gr.Row():
        with gr.Column(scale=4, variant="panel"):
            output_image_gr = gr.outputs.Image(label="Segmentation", type="pil").style(height=300)
            output_components.append(output_image_gr)

            with gr.Row():
                input_gr = gr.inputs.Image(type="pil")
                query_gr = gr.inputs.Textbox(default="", label="Query")
                input_components.extend([input_gr, query_gr])

            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit", variant="primary")

        # gr.State components hold session state, not user-visible values,
        # so they are excluded from the Examples widget's input/output wiring.
        inputs = [c for c in input_components if not isinstance(c, gr.State)]
        outputs = [c for c in output_components if not isinstance(c, gr.State)]
        with gr.Column(scale=2):
            # cache_examples=True runs `inference` on every example at startup
            # and stores the results, so clicking an example is instant.
            examples_handler = gr.Examples(
                examples=examples,
                inputs=inputs,
                outputs=outputs,
                fn=inference,
                cache_examples=True,
                examples_per_page=7,
            )

    gr.Markdown(article)

    # Run segmentation on submit and scroll the page down to the result.
    submit_btn.click(
        inference,
        input_components,
        output_components,
        scroll_to_output=True,
    )

    # Client-side clear: the `_js` arrow function is generated at build time
    # (json.dumps of each component's cleared value, or None) so clearing
    # needs no server round-trip.
    # NOTE(review): the array also appends two gr.Column visibility updates
    # with no matching entries in the outputs list — looks like a leftover
    # from a template; verify against the pinned gradio version.
    clear_btn.click(
        None,
        [],
        (input_components + output_components),
        _js=f"""() => {json.dumps(
            [component.cleared_value if hasattr(component, "cleared_value") else None
             for component in input_components + output_components] + (
                [gr.Column.update(visible=True)]
            )
            + ([gr.Column.update(visible=False)])
        )}
        """,
    )

# Start the gradio server (blocking call).
demo.launch()
|
|
|
|