File size: 2,403 Bytes
13170b5
 
b091d1e
13170b5
 
b091d1e
13170b5
 
b091d1e
13170b5
 
 
b091d1e
13170b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b24456a
 
 
583ae71
 
 
 
 
 
 
 
13170b5
 
 
 
 
 
b091d1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13170b5
b091d1e
 
13170b5
 
 
 
b091d1e
13170b5
 
 
b091d1e
 
 
13170b5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# Code inspired by the Ultralytics example with Gradio
import gradio as gradio
import PIL.Image as Image
import os
import shutil

from ultralytics import YOLO
from huggingface_hub import hf_hub_download

# Local folder that receives the model weights downloaded from the Hugging Face Hub.
# It is created eagerly at import time so the download step can write into it.
MODEL_DIR = "cached_models"
os.makedirs(MODEL_DIR, exist_ok=True)

# Models offered in the Gradio dropdown, keyed by their display name.
# Each entry names the Hub repository and the weight file inside it; the file
# name may include a sub-directory when the weights are not at the repo root
# (e.g. "models/v1/model.pt").
AVAILABLE_MODELS = {
    "YOLOv8m Speech Bubble (kitsumed)": dict(
        repo_id="kitsumed/yolov8m_seg-speech-bubble",
        filename="model.pt",
    ),
    # Register additional models here.
}

# Single-slot cache: the most recently loaded model and the name it was loaded under.
current_model = current_model_name = None


def load_model(model_name):
    """Return the YOLO model registered under *model_name*, loading it on demand.

    Only one model is kept in memory. If *model_name* matches the model that
    is already loaded, that instance is returned directly. Otherwise the
    weights are fetched with ``hf_hub_download`` into ``MODEL_DIR``, and the
    freshly built model replaces the cached one.

    Args:
        model_name: A key of ``AVAILABLE_MODELS`` (the dropdown label).

    Returns:
        The loaded ``YOLO`` model.

    Raises:
        ValueError: If *model_name* is not a key of ``AVAILABLE_MODELS``,
            including ``None`` when no model is selected.
    """
    global current_model, current_model_name

    # Validate before consulting the cache. Before any model has been loaded,
    # ``current_model_name`` is None. An unselected dropdown (None) would then
    # "hit" the cache and return None instead of a model.
    info = AVAILABLE_MODELS.get(model_name)
    if info is None:
        raise ValueError(
            f"Unknown model {model_name!r}; expected one of {list(AVAILABLE_MODELS)}"
        )

    if model_name == current_model_name:
        return current_model

    model_path = hf_hub_download(
        repo_id=info["repo_id"],
        filename=info["filename"],
        # Where to store the downloaded file; files already cached there are reused.
        local_dir=MODEL_DIR,
    )

    # Update the cache only after the model was built successfully. A failed
    # download or load then leaves the previously cached model untouched.
    current_model = YOLO(model_path)
    current_model_name = model_name
    return current_model


def predict_image(img, conf_threshold, iou_threshold, model_name):
    """Run the selected model on *img* and return the annotated image.

    Args:
        img: The uploaded image (the Gradio input uses ``type="pil"``), or
            ``None`` when the user submits without uploading anything.
        conf_threshold: Confidence threshold forwarded to ``model.predict``
            as ``conf``.
        iou_threshold: IoU threshold forwarded to ``model.predict`` as ``iou``.
        model_name: Key of ``AVAILABLE_MODELS`` selecting the model to use.

    Returns:
        A ``PIL.Image.Image`` with the predictions (including labels and
        confidences) drawn on it. Returns ``None`` when there is no input
        image or the model produced no result, which clears the output widget.
    """
    # Nothing to predict on: bail out before loading or downloading a model.
    if img is None:
        return None

    model = load_model(model_name)

    results = model.predict(
        source=img,
        conf=conf_threshold,
        iou=iou_threshold,
        show_labels=True,
        show_conf=True,
        imgsz=640,
    )

    # A single input image yields one Results entry. Guard the empty case
    # explicitly instead of failing on an unset variable.
    if not results:
        return None

    annotated_bgr = results[0].plot()
    # ``plot()`` returns a BGR array; reverse the channel axis to get RGB for PIL.
    return Image.fromarray(annotated_bgr[..., ::-1])


def _build_interface():
    """Assemble the Gradio UI that feeds an image, thresholds and a model choice to ``predict_image``."""
    model_names = list(AVAILABLE_MODELS.keys())

    # Input widgets, in the positional order expected by predict_image.
    image_input = gradio.Image(type="pil", label="Upload Image")
    conf_slider = gradio.Slider(minimum=0, maximum=1, value=0.20, label="Confidence threshold")
    iou_slider = gradio.Slider(minimum=0, maximum=1, value=0.40, label="IoU threshold")
    model_picker = gradio.Dropdown(choices=model_names, label="Select Model", value=model_names[0])

    return gradio.Interface(
        fn=predict_image,
        inputs=[image_input, conf_slider, iou_slider, model_picker],
        outputs=gradio.Image(type="pil", label="Result"),
        title="Try out kitsumed YOLO models",
        description="Select a model from kitsumed on Hugging Face and upload an image to perform predictions.",
    )


iface = _build_interface()

if __name__ == "__main__":
    iface.launch()