import torch
import numpy as np
import gradio as gr
import spaces
from transformers import pipeline, SamModel, SamProcessor

# Two-stage pipeline: OWLv2 localizes free-text queries as boxes, then SAM
# turns each box into a segmentation mask.
checkpoint = "google/owlv2-base-patch16-ensemble"
detector = pipeline(model=checkpoint, task="zero-shot-object-detection")
sam_model = SamModel.from_pretrained("facebook/sam-vit-base").to("cuda")
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
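
# For reference, detector(...) returns one dict per detection, e.g.
# {"score": 0.95, "label": "cat",
#  "box": {"xmin": 10, "ymin": 20, "xmax": 300, "ymax": 400}}
# (values illustrative); query() below unpacks each box to prompt SAM with it.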
@spaces.GPU
def query(image, texts, threshold):
    # The textbox supplies comma-separated candidate labels, e.g. "cat,dog".
    texts = [t.strip() for t in texts.split(",")]
    predictions = detector(
        image,
        candidate_labels=texts,
        threshold=threshold,
    )

    result_labels = []
    for pred in predictions:
        label = pred["label"]
        box = [
            round(pred["box"]["xmin"], 2),
            round(pred["box"]["ymin"], 2),
            round(pred["box"]["xmax"], 2),
            round(pred["box"]["ymax"], 2),
        ]
        # Prompt SAM with the detected box (the processor takes a nested
        # per-image list of boxes).
        inputs = sam_processor(
            image,
            input_boxes=[[[box]]],
            return_tensors="pt",
        ).to("cuda")

        with torch.no_grad():
            outputs = sam_model(**inputs)

        # Resize the predicted mask back to the original image size and keep
        # the first of SAM's three mask proposals for this box.
        mask = sam_processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu(),
        )[0][0][0].numpy()
        mask = mask[np.newaxis, ...]
        result_labels.append((mask, label))
    return image, result_labels
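
# A minimal sketch of calling query() directly, outside the UI (the image
# path here is illustrative):
#
#   from PIL import Image
#   annotated, masks = query(Image.open("cats.png"), "cat,remote", 0.1)
#   # masks is a list of (mask, label) pairs, one per detection.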

description = (
    "This Space combines OWLv2, the state-of-the-art zero-shot object detection "
    "model, with SAM, the state-of-the-art mask generation model. SAM normally "
    "doesn't accept text input; combining it with OWLv2 makes SAM text-promptable."
)
demo = gr.Interface(
    query,
    inputs=[gr.Image(type="pil"), "text", gr.Slider(0, 1, value=0.2)],
    outputs="annotatedimage",
    title="OWL 🤝 SAM",
    description=description,
    examples=[
        ["/content/cats.png", "cat", 0.1],
    ],
)
demo.launch(debug=True)
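
# Note: this script assumes a CUDA machine (the models are moved to "cuda" at
# import). On Hugging Face Spaces, the @spaces.GPU decorator requests ZeroGPU
# hardware for each call to query().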