# File size: 2,198 Bytes
# 8a9739b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
from transformers import OwlViTProcessor, OwlViTForObjectDetection
from PIL import Image, ImageDraw
import gradio as gr

# Hub checkpoint shared by the detection model and its matching processor.
_CHECKPOINT = "google/owlvit-base-patch32"

# Instantiate the Owl-ViT detector and its processor once, at import time,
# so every request reuses the same objects.
model = OwlViTForObjectDetection.from_pretrained(_CHECKPOINT)
processor = OwlViTProcessor.from_pretrained(_CHECKPOINT)

def detect_objects(image: "Image.Image | None", texts: str, threshold: float = 0.1):
    """Detect objects in ``image`` that match comma-separated text queries.

    Args:
        image: PIL image to search. ``None`` (nothing supplied) yields no
            detections instead of raising.
        texts: Comma-separated text queries, e.g. ``"cat, remote control"``.
            Whitespace around each query is stripped and blank entries
            (empty string, doubled or trailing commas) are dropped.
        threshold: Confidence a detection must exceed (strictly) to be kept.
            Defaults to 0.1.

    Returns:
        A list of ``(box, label, score)`` tuples: ``box`` is the box tensor
        produced by the processor's post-processing for the image size,
        ``label`` is the matching query string and ``score`` is a float.
        The list is empty when there is no image or no usable query.
    """
    # Parse the queries first so blank entries never reach the tokenizer.
    stripped = [part.strip() for part in (texts or "").split(",")]
    text_queries = [query for query in stripped if query]

    if image is None or not text_queries:
        return []

    # The image is fed to the model as-is otherwise; normalising to RGB keeps
    # palette / alpha / grayscale inputs from reaching it with an unexpected
    # channel count. The size is unchanged, so box coordinates still line up.
    if image.mode != "RGB":
        image = image.convert("RGB")

    inputs = processor(text=text_queries, images=image, return_tensors="pt")

    # Inference only: no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports (width, height); post-processing is given (height, width).
    target_sizes = torch.tensor([image.size[::-1]])

    # `post_process` is deprecated in transformers in favour of
    # `post_process_object_detection`; use the newer entry point when the
    # installed version has it and fall back otherwise.
    post_process = getattr(processor, "post_process_object_detection", None)
    if post_process is None:
        results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
    else:
        results = post_process(
            outputs=outputs, threshold=threshold, target_sizes=target_sizes
        )

    # Single image in the batch, so only the first result entry is relevant.
    result = results[0]
    detected_boxes = []
    for box, score, label in zip(result["boxes"], result["scores"], result["labels"]):
        score = score.item()
        # Re-check the threshold so both post-processing paths behave alike.
        if score > threshold:
            detected_boxes.append((box, text_queries[label.item()], score))

    return detected_boxes

def visualize(image, texts):
    """Return a copy of ``image`` annotated with the detections for ``texts``.

    Args:
        image: PIL image to annotate, or ``None`` when no image was supplied.
        texts: Comma-separated text queries, forwarded to ``detect_objects``.

    Returns:
        A new image with a red rectangle and a ``"label: score"`` caption for
        each detection; the input image itself is not modified. ``None`` is
        returned when ``image`` is ``None``.
    """
    # Nothing to annotate (e.g. the form was submitted without an upload):
    # return early instead of failing on a None image further down.
    if image is None:
        return None

    detections = detect_objects(image, texts)

    # Draw on a copy so the caller's image stays untouched.
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    for box, label, score in detections:
        # Box values arrive as tensor floats; round them to whole pixels.
        corners = [round(coord) for coord in box.tolist()]
        draw.rectangle(corners, outline="red", width=3)
        # Caption is anchored at the box's first corner.
        draw.text((corners[0], corners[1]), f"{label}: {score:.2f}", fill="red")

    return annotated

# Gradio Interface
def gradio_interface(image, texts):
    """Gradio callback: annotate ``image`` with detections for ``texts``.

    Delegates to ``visualize`` and returns its result unchanged.
    """
    annotated = visualize(image, texts)
    return annotated

interface = gr.Interface(
    title="Owl-ViT Object Detection",
    description="Upload an image and provide comma-separated text queries for object detection.",
    fn=gradio_interface,
    # Inputs, in callback-argument order: the picture, then the query string.
    inputs=[
        gr.Image(type="pil", label="Upload an Image"),
        gr.Textbox(label="Comma-separated Text Queries"),
    ],
    # Output: the annotated picture returned by the callback.
    outputs=gr.Image(type="pil", label="Object Detection Output"),
    allow_flagging="never",
)

# Start the web UI only when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()