Spaces:
Runtime error
Runtime error
File size: 6,266 Bytes
0340c57 56ae707 c27b4ef 548fbd0 0340c57 548fbd0 4222ab1 0340c57 548fbd0 56ae707 548fbd0 56ae707 0340c57 548fbd0 0340c57 548fbd0 0340c57 548fbd0 0340c57 56ae707 4222ab1 c27b4ef 4222ab1 c27b4ef cccde55 4222ab1 c27b4ef 39a561f cccde55 39a561f cccde55 39a561f cccde55 4222ab1 e91cfe6 4222ab1 e91cfe6 4222ab1 e91cfe6 c27b4ef e91cfe6 c27b4ef e91cfe6 c27b4ef 4222ab1 548fbd0 4222ab1 548fbd0 0340c57 4222ab1 e91cfe6 0340c57 4222ab1 c27b4ef e91cfe6 0340c57 56ae707 4222ab1 548fbd0 4222ab1 548fbd0 0340c57 c27b4ef 4222ab1 e91cfe6 c27b4ef d5cc629 39a561f 0340c57 56ae707 0340c57 d5cc629 1570cf6 d5cc629 1570cf6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 |
import gradio as gr
import torch
import os
import json
import requests
import matplotlib.pyplot as plt
from ultralytics import YOLO
# Coral AI model weights (YOLOv8), hosted in the Hugging Face model repository.
model_names = [
    "yolov8_xlarge_latest.pt",
]
current_model_name = model_names[0]

model_dir = "models"
examples_dir = "examples"
for directory in (model_dir, examples_dir):
    os.makedirs(directory, exist_ok=True)

# Fetch any weights file that is not cached locally yet.
HF_MODELS_REPO = "https://huggingface.co/reefsupport/coral-ai/resolve/main/models"
for model_name in model_names:
    model_path = os.path.join(model_dir, model_name)
    if os.path.exists(model_path):
        continue
    print(f"Downloading {model_name}...")
    model_url = f"{HF_MODELS_REPO}/{model_name}"
    torch.hub.download_url_to_file(model_url, model_path)

# Load the default model once at startup.
model = YOLO(os.path.join(model_dir, current_model_name))
def download_examples(json_file="examples.json", timeout=30):
    """
    Download the example images listed in a JSON file into ``examples_dir``.

    The JSON file is expected to look like
    ``{"examples": [{"name": ..., "url": ...}, ...]}``. Each image is saved as
    ``<name, spaces replaced by underscores, lowercased>.jpg``. Images that
    already exist locally are skipped.

    Downloading is best-effort: a missing JSON file, a non-200 HTTP status or
    a network/file-system failure is reported with ``print`` and the
    remaining examples are still attempted.

    Args:
        json_file: Path to the JSON file describing the examples.
        timeout: Seconds passed to ``requests.get`` as its timeout, so a
            stalled server cannot hang application startup.
    """
    if not os.path.exists(json_file):
        print(f"Examples JSON file '{json_file}' not found. Skipping example download.")
        return
    with open(json_file, "r", encoding="utf-8") as f:
        data = json.load(f)
    for example in data.get("examples", []):
        image_name = example["name"].replace(" ", "_").lower() + ".jpg"
        image_path = os.path.join(examples_dir, image_name)
        if os.path.exists(image_path):  # Skip if already downloaded
            continue
        print(f"Downloading {example['name']}...")
        # Stream into a temporary name first: an interrupted download must not
        # leave a truncated file that the exists() check above would accept.
        partial_path = image_path + ".part"
        try:
            # Using the response as a context manager releases the streamed
            # connection even on early exit.
            with requests.get(example["url"], stream=True, timeout=timeout) as response:
                if response.status_code != 200:
                    print(f"Failed to download {example['name']}: {response.status_code}")
                    continue
                with open(partial_path, "wb") as img_file:
                    for chunk in response.iter_content(1024):
                        img_file.write(chunk)
            os.replace(partial_path, image_path)
        except (requests.RequestException, OSError) as err:
            print(f"Failed to download {example['name']}: {err}")
            try:
                os.remove(partial_path)
            except FileNotFoundError:
                pass  # Nothing was written before the failure.
            continue
        print(f"Saved {example['name']} to {image_path}")
# Fetch the example images listed in examples.json into examples_dir before the UI is built.
download_examples(json_file="examples.json")
def compute_class_areas(predictions, image_shape):
    """
    Compute the percentage of the image covered by each predicted class.

    Per-instance masks of the same class are merged (logical OR) before
    counting, so overlapping detections of one class are not double-counted.

    The denominator is the spatial size of the mask tensor itself, not
    ``image_shape``. Per the Ultralytics predict docs, ``masks.data`` is
    returned at the inference resolution (e.g. 480x640) unless
    ``retina_masks=True``. Dividing mask pixels by the original image's pixel
    count would therefore shrink every percentage.

    Args:
        predictions: A single Ultralytics segmentation result (``results[0]``)
            exposing ``masks.data`` (N, H, W), ``boxes.cls`` (N,) and
            ``names`` (class id -> class name).
        image_shape: ``(height, width)`` of the original image. Kept for
            backward compatibility; no longer used in the computation.

    Returns:
        Dict mapping class name -> covered percentage (0-100), in order of
        first appearance. Empty when the result has no masks (no detections).
    """
    masks = predictions.masks
    # Ultralytics leaves ``masks`` as None when nothing was detected.
    if masks is None or len(masks.data) == 0:
        return {}

    mask_data = masks.data
    total_pixels = mask_data.shape[-2] * mask_data.shape[-1]

    class_masks = {}
    for mask, cls_id in zip(mask_data, predictions.boxes.cls):
        class_name = predictions.names[int(cls_id.item())]
        covered = mask > 0  # Fresh boolean tensor; safe to OR into in place.
        if class_name in class_masks:
            class_masks[class_name] |= covered
        else:
            class_masks[class_name] = covered

    return {
        class_name: covered.sum().item() / total_pixels * 100
        for class_name, covered in class_masks.items()
    }
def generate_coverage_pie_chart(class_areas, chart_path="class_coverage_pie.png"):
    """
    Render a pie chart of per-class coverage percentages and save it as an image.

    The part of the image not covered by any class (100 minus the sum of
    ``class_areas``) is drawn as an extra "Other" wedge when it is positive.

    Args:
        class_areas: Mapping of class name -> percentage of the image covered.
        chart_path: File path the chart is written to. Defaults to the
            original fixed file name in the working directory.

    Returns:
        ``chart_path``, the location of the saved chart.
    """
    labels = list(class_areas.keys())
    sizes = list(class_areas.values())
    other_percentage = max(0, 100 - sum(sizes))
    if other_percentage > 0:
        labels.append("Other")
        sizes.append(other_percentage)

    # Hold an explicit figure handle and always close it. pyplot keeps every
    # open figure alive, so an exception in pie()/savefig() would otherwise
    # leak a figure per request in this long-running server process.
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        ax.pie(
            sizes,
            labels=[f"{label} ({size:.1f}%)" for label, size in zip(labels, sizes)],
            autopct="%1.1f%%",
            startangle=90,
            colors=plt.cm.tab20.colors,
        )
        ax.axis("equal")
        ax.set_title("Class Coverage Distribution")
        fig.savefig(chart_path)
    finally:
        plt.close(fig)
    return chart_path
def coral_ai_inference(image, model_name):
    """
    Run coral segmentation on an image and summarise per-class coverage.

    Args:
        image: RGB image as a NumPy array (H, W, 3), as delivered by
            ``gr.Image(type="numpy")``; ``None`` when nothing was uploaded.
        model_name: File name of the weights in ``model_dir``. The global
            model is reloaded only when this differs from the one currently
            loaded; an empty value keeps the current model.

    Returns:
        Tuple ``(rendered_image, pie_chart_path)``: the annotated RGB image
        and the path of the saved class-coverage pie chart.

    Raises:
        gr.Error: If no image was provided.
    """
    global model
    global current_model_name
    if image is None:
        raise gr.Error("Please upload an image first.")
    if model_name and model_name != current_model_name:
        model = YOLO(os.path.join(model_dir, model_name))
        current_model_name = model_name

    # Gradio supplies RGB arrays, while Ultralytics treats NumPy input as
    # OpenCV-style BGR. Flip the channels (copy -> contiguous memory) so the
    # model sees the colour order it was trained on.
    bgr_image = image[..., ::-1].copy()
    results = model(bgr_image)  # Perform inference

    # Generate class coverage
    class_areas = compute_class_areas(results[0], image.shape[:2])
    pie_chart_path = generate_coverage_pie_chart(class_areas)

    # plot() draws on the BGR input image; flip back to RGB for Gradio.
    rendered_image = results[0].plot()[..., ::-1].copy()
    return rendered_image, pie_chart_path
# Dynamically generate Gradio examples
def generate_examples(json_file="examples.json"):
    """
    Build the Gradio ``examples`` list from the examples JSON file.

    Each entry is ``[image_path, model_name]``. ``image_path`` is derived from
    the example's name (spaces -> underscores, lowercased, ``.jpg``) inside
    ``examples_dir``, and ``model_name`` is the currently selected model.
    Examples whose image file does not exist locally (e.g. its download
    failed) are skipped with a message. The interface caches examples at
    startup, and a missing file would break that.

    Args:
        json_file: Path to the JSON file describing the examples.

    Returns:
        A list of ``[image_path, model_name]`` pairs. Empty if the JSON file
        does not exist or lists no usable examples.
    """
    if not os.path.exists(json_file):
        return []
    with open(json_file, "r", encoding="utf-8") as f:
        data = json.load(f)
    examples = []
    for example in data.get("examples", []):
        image_name = example["name"].replace(" ", "_").lower() + ".jpg"
        image_path = os.path.join(examples_dir, image_name)
        if os.path.exists(image_path):
            examples.append([image_path, current_model_name])
        else:
            print(f"Example image '{image_path}' is missing; leaving it out of the examples.")
    return examples
# Define Gradio interface
inputs = [
    gr.Image(type="numpy", label="Input Image"),
    gr.Dropdown(model_names, value=current_model_name, label="Model Type"),
]
outputs = [
    gr.Image(type="numpy", label="Segmented Image"),
    gr.Image(type="filepath", label="Class Coverage Pie Chart"),
]
# Icons with links
icons = """
<a href="https://reef.support" target="_blank" style="margin-right: 20px;">
<img src="https://img.icons8.com/ios-filled/50/000000/home.png" alt="Home" style="width: 30px; height: 30px;">
</a>
<a href="https://docs.reef.support" target="_blank" style="margin-right: 20px;">
<img src="https://img.icons8.com/ios-filled/50/000000/document.png" alt="Documentation" style="width: 30px; height: 30px;">
</a>
<a href="https://github.com/reefsupport" target="_blank" style="margin-right: 20px;">
<img src="https://img.icons8.com/ios-filled/50/000000/github.png" alt="GitHub" style="width: 30px; height: 30px;">
</a>
<a href="https://url.reef.support/join-discord" target="_blank">
<img src="https://img.icons8.com/ios-filled/50/000000/discord-logo.png" alt="Discord" style="width: 30px; height: 30px;">
</a>
"""
title = "Coral AI Demo"
examples = generate_examples()

# Page layout: icon links above the inference interface.
with gr.Blocks() as demo:
    gr.Markdown(icons)
    gr.Interface(
        fn=coral_ai_inference,
        inputs=inputs,
        outputs=outputs,
        title=title,
        examples=examples or None,
        # Caching runs inference on every example at startup; only enable it
        # when there is at least one example to cache.
        cache_examples=bool(examples),
    )

# Launch only when run as a script, so importing this module (tests, the
# `gradio` CLI reload mode) does not start a blocking server.
if __name__ == "__main__":
    demo.queue().launch(debug=True, share=True)