File size: 3,338 Bytes
0340c57
 
 
c27b4ef
548fbd0
0340c57
548fbd0
 
0340c57
 
548fbd0
0340c57
 
548fbd0
 
 
 
0340c57
548fbd0
 
 
 
0340c57
548fbd0
 
 
0340c57
548fbd0
 
0340c57
 
c27b4ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
548fbd0
 
c27b4ef
548fbd0
 
 
 
c27b4ef
548fbd0
 
 
 
 
 
0340c57
548fbd0
 
0340c57
c27b4ef
 
 
 
 
 
0340c57
 
548fbd0
 
 
 
 
 
 
 
 
0340c57
c27b4ef
 
 
 
 
548fbd0
0340c57
548fbd0
 
 
 
 
0340c57
548fbd0
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import gradio as gr
import torch
import os
import matplotlib.pyplot as plt
from ultralytics import YOLO

# Coral AI segmentation weights published in the Hugging Face model repository.
model_names = [
    "yolov8_xlarge_latest.pt",
    "yolov8_xlarge_v1.pt",
    "yolov8_xlarge_v2.pt",
]

# Weights loaded at startup (also the dropdown's initial selection).
current_model_name = "yolov8_xlarge_latest.pt"
model_dir = "models"
os.makedirs(model_dir, exist_ok=True)

# Fetch every weight file that is not already cached under ``model_dir``.
HF_MODELS_REPO = "https://huggingface.co/reefsupport/coral-ai/resolve/main/models"
for model_name in model_names:
    model_path = os.path.join(model_dir, model_name)
    if os.path.exists(model_path):
        # Already downloaded on a previous run; nothing to do.
        continue
    print(f"Downloading {model_name}...")
    model_url = f"{HF_MODELS_REPO}/{model_name}"
    torch.hub.download_url_to_file(model_url, model_path)

# The module-level model that the inference function swaps on demand.
model = YOLO(os.path.join(model_dir, current_model_name))


def generate_coverage_graph(results, graph_path="class_coverage.png"):
    """
    Generate a bar chart of per-class coverage for the first prediction result.

    Coverage is each class's share of the total segmented area. The pixel
    counts of all instance masks of a class are summed and divided by the
    pixel count of all instance masks, so the bars add up to 100%. Pixels
    where masks overlap are counted once per instance.

    Args:
        results: Sequence of Ultralytics ``Results`` as returned by
            ``model.predict``; only ``results[0]`` is used.
        graph_path: Where to write the PNG. Defaults to
            ``"class_coverage.png"`` in the current working directory.

    Returns:
        The path of the saved PNG (``graph_path``).
    """
    result = results[0]
    masks = result.masks  # None when the model found nothing to segment

    # Accumulate masked pixel counts per class. Class IDs live on ``boxes``
    # (``Masks`` has no ``cls`` attribute) and instance i of ``boxes``
    # corresponds to mask i. Multiple instances of one class must be summed,
    # not overwritten as a dict comprehension keyed by class would do.
    class_pixels = {}
    if masks is not None and result.boxes is not None and len(masks.data):
        # .tolist() also moves CUDA tensors to plain Python floats/ints.
        pixel_counts = masks.data.flatten(1).sum(dim=1).tolist()
        class_ids = result.boxes.cls.int().tolist()
        for cls_id, count in zip(class_ids, pixel_counts):
            class_pixels[cls_id] = class_pixels.get(cls_id, 0.0) + count

    total_pixels = sum(class_pixels.values())
    ordered_ids = sorted(class_pixels)
    percentages = [
        class_pixels[cls_id] / total_pixels * 100 if total_pixels else 0.0
        for cls_id in ordered_ids
    ]
    # Prefer human-readable class names when the result carries them.
    names = getattr(result, "names", None) or {}
    labels = [str(names.get(cls_id, cls_id)) for cls_id in ordered_ids]

    plt.figure(figsize=(8, 6))
    positions = list(range(len(labels)))
    plt.bar(positions, percentages)
    plt.xticks(positions, labels, rotation=45, ha="right")
    if not labels:
        plt.text(
            0.5,
            0.5,
            "No segments detected",
            ha="center",
            va="center",
            transform=plt.gca().transAxes,
        )
    plt.xlabel("Class")
    plt.ylabel("Coverage Percentage")
    plt.title("Class Coverage in Segmentation")
    plt.tight_layout()
    plt.savefig(graph_path)
    plt.close()
    return graph_path


def coral_ai_inference(image: str, model_name: str):
    """
    Run Coral AI segmentation on one image and build its class coverage graph.

    If ``model_name`` differs from the currently loaded weights, the
    module-level ``model`` is reloaded first.

    Args:
        image: Input image filepath (as supplied by ``gr.Image(type="filepath")``).
        model_name: File name of the weights under ``model_dir``. A falsy
            value keeps the currently loaded model.

    Returns:
        Tuple of (annotated image as an RGB numpy array, path to the class
        coverage graph PNG).

    Raises:
        gr.Error: If no input image was provided.
    """
    global model
    global current_model_name

    # With no source, Ultralytics substitutes its bundled sample images, which
    # would silently show predictions for an unrelated picture.
    if image is None:
        raise gr.Error("Please provide an input image.")

    if model_name and model_name != current_model_name:
        model = YOLO(os.path.join(model_dir, model_name))
        current_model_name = model_name

    # ``return_outputs`` belonged to the pre-``Results`` API and is rejected by
    # current Ultralytics. ``predict`` already returns a list of ``Results``.
    results = model.predict(image)

    # ``Results.plot()`` returns a BGR (OpenCV-order) array; Gradio expects RGB.
    rendered_image = results[0].plot()[..., ::-1]

    # Generate class coverage graph
    graph_path = generate_coverage_graph(results)
    return rendered_image, graph_path


# UI inputs: an image file plus a selector for which weights to run.
inputs = [
    gr.Image(type="filepath", label="Input Image"),
    gr.Dropdown(model_names, value=current_model_name, label="Model Type"),
]

# UI outputs: the annotated segmentation and the per-class coverage chart.
outputs = [
    gr.Image(type="filepath", label="Segmented Image"),
    gr.Image(type="filepath", label="Class Coverage Graph"),
]

title = "Coral AI YOLOv8 Segmentation Demo"

# Sample images 1..3 from the ``examples`` folder, each run with the latest weights.
examples = [
    [f"examples/coral_image{index}.jpg", "yolov8_xlarge_latest.pt"]
    for index in range(1, 4)
]

# Assemble the interface; cache_examples=True has Gradio precompute example outputs.
demo_app = gr.Interface(
    fn=coral_ai_inference,
    inputs=inputs,
    outputs=outputs,
    title=title,
    examples=examples,
    cache_examples=True,
    theme="default",
)
# Enable request queuing, then start the server with debug output.
demo_app.queue()
demo_app.launch(debug=True)