File size: 6,266 Bytes
0340c57
 
 
56ae707
 
c27b4ef
548fbd0
0340c57
548fbd0
 
4222ab1
0340c57
 
548fbd0
 
56ae707
548fbd0
56ae707
0340c57
548fbd0
 
 
 
0340c57
548fbd0
 
 
0340c57
548fbd0
 
0340c57
 
56ae707
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4222ab1
c27b4ef
4222ab1
c27b4ef
cccde55
4222ab1
c27b4ef
39a561f
cccde55
 
39a561f
cccde55
 
 
39a561f
cccde55
 
 
4222ab1
 
 
 
e91cfe6
4222ab1
e91cfe6
4222ab1
e91cfe6
 
 
 
 
 
c27b4ef
e91cfe6
 
 
 
 
 
 
 
 
 
 
c27b4ef
e91cfe6
c27b4ef
 
4222ab1
548fbd0
4222ab1
548fbd0
 
 
 
 
 
0340c57
4222ab1
 
 
 
e91cfe6
0340c57
4222ab1
c27b4ef
 
e91cfe6
0340c57
 
56ae707
 
 
 
 
 
 
 
 
 
 
 
4222ab1
548fbd0
4222ab1
 
548fbd0
0340c57
c27b4ef
4222ab1
e91cfe6
c27b4ef
 
d5cc629
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39a561f
0340c57
56ae707
0340c57
d5cc629
 
1570cf6
d5cc629
 
 
 
 
 
 
 
 
 
1570cf6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import gradio as gr
import torch
import os
import json
import requests
import matplotlib.pyplot as plt
from ultralytics import YOLO

# Weight files published in the Coral AI Hugging Face model repository.
# The first entry is the model loaded when the app starts.
model_names = [
    "yolov8_xlarge_latest.pt"
]

current_model_name = model_names[0]
model_dir = "models"
examples_dir = "examples"
os.makedirs(model_dir, exist_ok=True)
os.makedirs(examples_dir, exist_ok=True)

# Fetch any weight file that is not cached locally yet.
HF_MODELS_REPO = "https://huggingface.co/reefsupport/coral-ai/resolve/main/models"
for model_name in model_names:
    model_path = os.path.join(model_dir, model_name)
    if os.path.exists(model_path):
        continue
    model_url = f"{HF_MODELS_REPO}/{model_name}"
    print(f"Downloading {model_name}...")
    torch.hub.download_url_to_file(model_url, model_path)

# Instantiate the startup model from the local weights.
model = YOLO(os.path.join(model_dir, current_model_name))


def download_examples(json_file="examples.json", *, timeout=30):
    """
    Download the example images listed in a JSON manifest into ``examples_dir``.

    The manifest is read as ``data["examples"]``, a list of entries that each
    have a ``"name"`` and a ``"url"``. Each image is saved as
    ``<name with spaces -> underscores, lowercased>.jpg``. Images that already
    exist locally are skipped. A failed download is reported and skipped, so
    one bad URL does not stop the app from starting.

    Args:
        json_file: Path to the examples manifest. If it does not exist, nothing
            is downloaded.
        timeout: Seconds passed to ``requests.get`` as its timeout, so a
            stalled server cannot hang startup indefinitely.
    """
    if not os.path.exists(json_file):
        print(f"Examples JSON file '{json_file}' not found. Skipping example download.")
        return

    with open(json_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    for example in data["examples"]:
        image_name = example["name"].replace(" ", "_").lower() + ".jpg"
        image_path = os.path.join(examples_dir, image_name)

        if os.path.exists(image_path):  # Skip if already downloaded
            continue

        print(f"Downloading {example['name']}...")
        # Stream into a sibling ".partial" file and rename it into place only
        # after the whole body has been written. An interrupted download then
        # never leaves a truncated image that the existence check above would
        # mistake for a finished one.
        partial_path = image_path + ".partial"
        try:
            # The context manager releases the streamed connection on every path.
            with requests.get(example["url"], stream=True, timeout=timeout) as response:
                if response.status_code != 200:
                    print(f"Failed to download {example['name']}: {response.status_code}")
                    continue
                with open(partial_path, "wb") as img_file:
                    for chunk in response.iter_content(1024):
                        img_file.write(chunk)
            os.replace(partial_path, image_path)
        except (requests.RequestException, OSError) as err:
            print(f"Failed to download {example['name']}: {err}")
            if os.path.exists(partial_path):
                os.remove(partial_path)
            continue

        print(f"Saved {example['name']} to {image_path}")


# Download example images
download_examples()


def compute_class_areas(predictions, image_shape):
    """
    Compute the percentage of the image covered by each predicted class.

    Args:
        predictions: One segmentation result (``results[0]`` of the model
            call), exposing ``masks`` (``None`` or an object with ``.data``),
            ``boxes.cls`` and the ``names`` id -> class-name mapping.
        image_shape: ``(height, width)`` of the input image. Kept for
            backward compatibility only. Coverage is measured at the masks'
            own resolution, because the masks may be produced at the model's
            inference size rather than the original image size.

    Returns:
        dict mapping class name -> percentage (0-100) of mask pixels covered
        by the union of that class's instance masks. Empty when there are no
        detections.
    """
    # With no detections there are no masks to measure.
    if predictions.masks is None:
        return {}

    class_masks = {}
    for mask, cls_id in zip(predictions.masks.data, predictions.boxes.cls):
        class_name = predictions.names[int(cls_id.item())]  # Get the class name
        previous = class_masks.get(class_name)
        # Union the instance masks of a class so that overlapping instances
        # are not counted twice.
        class_masks[class_name] = mask if previous is None else torch.maximum(previous, mask)

    # Normalise by the mask's own pixel count. Dividing by the original image
    # size would be wrong whenever the masks are at a different resolution.
    return {
        class_name: union.sum().item() / union.numel() * 100
        for class_name, union in class_masks.items()
    }


def generate_coverage_pie_chart(class_areas, chart_path="class_coverage_pie.png"):
    """
    Render a pie chart of class coverage percentages and save it as an image.

    Args:
        class_areas: Mapping of class name -> percentage of the image covered.
        chart_path: Destination file for the chart. Defaults to
            ``class_coverage_pie.png`` in the current working directory.

    Returns:
        The path the chart was written to (``chart_path``).
    """
    # The object-oriented Figure API is used instead of pyplot. The figure is
    # never registered in pyplot's global figure manager, so it cannot leak if
    # savefig raises, and concurrent callbacks do not share "current figure"
    # state.
    from matplotlib.figure import Figure

    total_percentage = sum(class_areas.values())
    # Area not attributed to any class is shown as "Other". If overlapping
    # classes push the total above 100, no "Other" slice is added.
    other_percentage = max(0, 100 - total_percentage)

    labels = list(class_areas.keys())
    sizes = list(class_areas.values())
    if other_percentage > 0:
        labels.append("Other")
        sizes.append(other_percentage)

    fig = Figure(figsize=(8, 6))
    ax = fig.subplots()
    ax.pie(
        sizes,
        labels=[f"{label} ({size:.1f}%)" for label, size in zip(labels, sizes)],
        autopct="%1.1f%%",
        startangle=90,
        colors=plt.cm.tab20.colors,
    )
    ax.axis("equal")
    ax.set_title("Class Coverage Distribution")
    fig.savefig(chart_path)
    return chart_path


def coral_ai_inference(image, model_name):
    """
    Run segmentation on an image and summarise per-class coverage.

    Args:
        image: Image array delivered by ``gr.Image(type="numpy")`` (RGB
            channel order), or None if the user submitted without an image.
        model_name: Weight file name; must be one of ``model_names``. The
            model is reloaded only when this differs from the one currently
            loaded.

    Returns:
        tuple: ``(rendered_image, pie_chart_path)``, i.e. the annotated
        segmentation image (RGB) and the path of the class-coverage chart.

    Raises:
        gr.Error: if no image was supplied or ``model_name`` is not a known
            model. Gradio displays the message to the user.
    """
    global model
    global current_model_name

    if image is None:
        raise gr.Error("Please upload an image.")
    # model_name arrives from the client and is joined into a filesystem path
    # below, so only accept the names offered in the dropdown.
    if model_name not in model_names:
        raise gr.Error(f"Unknown model: {model_name}")

    if model_name != current_model_name:
        model = YOLO(os.path.join(model_dir, model_name))
        current_model_name = model_name

    # Gradio supplies RGB arrays, but Ultralytics interprets NumPy inputs as
    # BGR (OpenCV order). Swap the channels so the model sees the true colours.
    # .copy() makes the reversed view contiguous for OpenCV-based preprocessing.
    is_color = image.ndim == 3 and image.shape[-1] == 3
    model_input = image[..., ::-1].copy() if is_color else image

    results = model(model_input)  # Perform inference
    prediction = results[0]

    # Generate class coverage
    class_areas = compute_class_areas(prediction, image.shape[:2])
    pie_chart_path = generate_coverage_pie_chart(class_areas)

    # Results.plot() returns a BGR array; convert it back to RGB for display.
    rendered_image = prediction.plot()
    if is_color:
        rendered_image = rendered_image[..., ::-1].copy()

    return rendered_image, pie_chart_path


# Dynamically generate Gradio examples
def generate_examples(json_file="examples.json"):
    """
    Build the Gradio ``examples`` list from the examples manifest.

    Each row is ``[image_path, current_model_name]``, matching the interface
    inputs (image, model). Entries whose image file is missing from
    ``examples_dir`` (e.g. because its download failed) are left out. With
    ``cache_examples=True``, Gradio processes every listed example at startup,
    so a missing file there would break the launch.

    Args:
        json_file: Path to the manifest whose ``"examples"`` entries each have
            a ``"name"``.

    Returns:
        A list of ``[image_path, model_name]`` rows. It is empty if the
        manifest is missing or lists no examples.
    """
    if not os.path.exists(json_file):
        return []
    with open(json_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    rows = []
    for example in data.get("examples", []):
        image_name = example["name"].replace(" ", "_").lower() + ".jpg"
        image_path = os.path.join(examples_dir, image_name)
        if os.path.exists(image_path):
            rows.append([image_path, current_model_name])
    return rows


# Gradio components, listed in the order coral_ai_inference takes its
# arguments (image, model_name) and returns its results (image, chart path).
inputs = [
    gr.Image(label="Input Image", type="numpy"),
    gr.Dropdown(choices=model_names, value=current_model_name, label="Model Type"),
]
outputs = [
    gr.Image(label="Segmented Image", type="numpy"),
    gr.Image(label="Class Coverage Pie Chart", type="filepath"),
]

# HTML row of icon links (home page, documentation, GitHub, Discord) rendered
# above the interface. Each link opens in a new tab (target="_blank").
icons = """
<a href="https://reef.support" target="_blank" style="margin-right: 20px;">
    <img src="https://img.icons8.com/ios-filled/50/000000/home.png" alt="Home" style="width: 30px; height: 30px;">
</a>
<a href="https://docs.reef.support" target="_blank" style="margin-right: 20px;">
    <img src="https://img.icons8.com/ios-filled/50/000000/document.png" alt="Documentation" style="width: 30px; height: 30px;">
</a>
<a href="https://github.com/reefsupport" target="_blank" style="margin-right: 20px;">
    <img src="https://img.icons8.com/ios-filled/50/000000/github.png" alt="GitHub" style="width: 30px; height: 30px;">
</a>
<a href="https://url.reef.support/join-discord" target="_blank">
    <img src="https://img.icons8.com/ios-filled/50/000000/discord-logo.png" alt="Discord" style="width: 30px; height: 30px;">
</a>
"""

title = "Coral AI Demo"

examples = generate_examples()

# Page layout: the icon links row followed by the inference interface.
# (The previously created, never-used `demo_app = gr.Blocks()` was removed.)
with gr.Blocks() as demo:
    gr.Markdown(icons)
    gr.Interface(
        fn=coral_ai_inference,
        inputs=inputs,
        outputs=outputs,
        title=title,
        examples=examples,
        # Runs coral_ai_inference on each example while the app is built.
        cache_examples=True,
    )

# Launch only when run as a script, so importing this module (e.g. by
# Gradio's reload mode, which looks up `demo`) does not start a server.
if __name__ == "__main__":
    demo.queue().launch(debug=True, share=True)