# coral-ai-demo / app.py
# Last commit: "Replace with pie chart" (e91cfe6) by Yohan Runhaar
# File size at that commit: 6.27 kB
import gradio as gr
import torch
import os
import json
import requests
import matplotlib.pyplot as plt
from ultralytics import YOLO
# Weight files for the Coral AI model, published in the Hugging Face model
# repository; the first (and currently only) entry is the startup model.
model_names = ["yolov8_xlarge_latest.pt"]
current_model_name = model_names[0]

# Local folders for downloaded weights and example images (created if missing).
model_dir, examples_dir = "models", "examples"
os.makedirs(model_dir, exist_ok=True)
os.makedirs(examples_dir, exist_ok=True)
# Fetch any model weights that are not yet cached in ``model_dir``.
HF_MODELS_REPO = "https://huggingface.co/reefsupport/coral-ai/resolve/main/models"
for model_name in model_names:
    model_path = os.path.join(model_dir, model_name)
    if os.path.exists(model_path):
        continue  # already downloaded on a previous run
    model_url = f"{HF_MODELS_REPO}/{model_name}"
    print(f"Downloading {model_name}...")
    torch.hub.download_url_to_file(model_url, model_path)

# Load the startup model into memory; coral_ai_inference swaps it on demand.
model = YOLO(os.path.join(model_dir, current_model_name))
def download_examples(json_file="examples.json", timeout=30):
    """
    Download the example images listed in ``json_file`` into ``examples_dir``.

    The manifest is read as ``data["examples"]``, a list of entries with
    ``"name"`` and ``"url"`` keys. Each image is saved as
    ``<name, lower-cased, spaces -> underscores>.jpg``; images already on disk
    are skipped. A failed download is reported and skipped, so one bad URL
    does not stop the app from starting.

    Args:
        json_file: Path to the examples manifest. If it does not exist, a
            message is printed and nothing is downloaded.
        timeout: Seconds to wait on each HTTP connect/read before giving up.
    """
    if not os.path.exists(json_file):
        print(f"Examples JSON file '{json_file}' not found. Skipping example download.")
        return
    with open(json_file, "r", encoding="utf-8") as f:
        data = json.load(f)
    for example in data["examples"]:
        image_name = example["name"].replace(" ", "_").lower() + ".jpg"
        image_path = os.path.join(examples_dir, image_name)
        if os.path.exists(image_path):  # Skip if already downloaded
            continue
        print(f"Downloading {example['name']}...")
        # Stream into a temporary file and rename only once the download has
        # completed. Otherwise an interrupted transfer would leave a truncated
        # image that the "already downloaded" check above keeps forever.
        partial_path = image_path + ".part"
        try:
            # The context manager releases the streamed connection on every path.
            with requests.get(example["url"], stream=True, timeout=timeout) as response:
                if response.status_code != 200:
                    print(f"Failed to download {example['name']}: {response.status_code}")
                    continue
                with open(partial_path, "wb") as img_file:
                    for chunk in response.iter_content(1024):
                        img_file.write(chunk)
            os.replace(partial_path, image_path)
        except (requests.RequestException, OSError) as err:
            print(f"Failed to download {example['name']}: {err}")
            if os.path.exists(partial_path):
                os.remove(partial_path)
            continue
        print(f"Saved {example['name']} to {image_path}")
# Fetch the example images listed in examples.json; files already present in
# the examples folder are not downloaded again.
download_examples(json_file="examples.json")
def compute_class_areas(predictions, image_shape):
    """
    Compute the percentage of the image covered by each predicted class.

    Args:
        predictions: One Ultralytics segmentation result (e.g. ``results[0]``)
            exposing ``masks`` (None when nothing was detected, otherwise
            ``masks.data`` with shape N x H x W), ``boxes.cls`` (N class ids)
            and ``names`` (class id -> class name).
        image_shape: (height, width) of the original input image. Kept for
            backward compatibility; coverage is measured against the masks'
            own canvas, because the masks are not guaranteed to be at the
            original image resolution.

    Returns:
        dict mapping class name -> percentage (0-100) of the mask canvas
        covered by the union of that class's masks, in order of first
        appearance. Empty when nothing was detected.
    """
    # Ultralytics leaves ``masks`` as None when no objects are detected.
    masks = predictions.masks.data if predictions.masks is not None else None
    if masks is None or len(masks) == 0:
        return {}

    # Unless ``retina_masks`` is enabled, Ultralytics produces masks at the
    # (letterboxed) inference resolution rather than the original image size,
    # so divide by the masks' own pixel count. NOTE: any letterbox padding is
    # still included in this denominator.
    total_pixels = masks.shape[-2] * masks.shape[-1]

    # Union the instance masks per class so that overlapping detections of the
    # same class are not counted twice.
    class_masks = {}
    for mask, cls_id in zip(masks, predictions.boxes.cls):
        class_name = predictions.names[int(cls_id.item())]
        covered = mask != 0  # fresh boolean array; never mutates the result
        if class_name in class_masks:
            class_masks[class_name] |= covered
        else:
            class_masks[class_name] = covered

    return {
        class_name: covered.sum().item() / total_pixels * 100
        for class_name, covered in class_masks.items()
    }
def generate_coverage_pie_chart(class_areas, chart_path="class_coverage_pie.png"):
    """
    Render class coverage percentages as a pie chart and save it as a PNG.

    Any share of 100% not accounted for by ``class_areas`` is drawn as an
    "Other" wedge. Wedge labels include each entry's percentage.

    Args:
        class_areas: dict mapping class name -> coverage percentage (0-100).
        chart_path: File to write the PNG to (defaults to the working
            directory, as before).

    Returns:
        ``chart_path``.
    """
    # Use the object-oriented Figure API rather than pyplot. pyplot keeps
    # global references to open figures, which matplotlib advises against in
    # web servers. A standalone Figure is freed by normal garbage collection,
    # even if savefig raises.
    from matplotlib.figure import Figure

    total_percentage = sum(class_areas.values())
    other_percentage = max(0, 100 - total_percentage)
    labels = list(class_areas.keys())
    sizes = list(class_areas.values())
    if other_percentage > 0:
        labels.append("Other")
        sizes.append(other_percentage)

    fig = Figure(figsize=(8, 6))
    ax = fig.subplots()
    ax.pie(
        sizes,
        labels=[f"{label} ({size:.1f}%)" for label, size in zip(labels, sizes)],
        autopct="%1.1f%%",
        startangle=90,
        colors=plt.cm.tab20.colors,
    )
    ax.axis("equal")
    ax.set_title("Class Coverage Distribution")
    fig.savefig(chart_path)
    return chart_path
def coral_ai_inference(image, model_name):
    """
    Segment an image with the selected Coral AI model and chart class coverage.

    Args:
        image: numpy image from ``gr.Image(type="numpy")`` (RGB channel order,
            Gradio's default image mode), or None if the user submitted no
            image.
        model_name: Weight file name; must be one of ``model_names``. The
            model is reloaded only when it differs from the one in memory.

    Returns:
        (rendered_image, pie_chart_path): the annotated image in RGB order and
        the path of the class-coverage pie chart PNG.

    Raises:
        gr.Error: if no image was provided or the model name is unknown.
    """
    global model
    global current_model_name
    if image is None:
        raise gr.Error("Please upload an image.")
    # The dropdown value comes from the client; only load advertised weights.
    if model_name not in model_names:
        raise gr.Error(f"Unknown model: {model_name}")
    if model_name != current_model_name:
        model = YOLO(os.path.join(model_dir, model_name))
        current_model_name = model_name
    # Ultralytics treats numpy input as BGR (OpenCV order), while Gradio
    # supplies RGB. Flip the channels and copy so that the array is contiguous
    # and writable for OpenCV-based pre-processing and plotting.
    bgr_image = image[..., ::-1].copy()
    results = model(bgr_image)  # Perform inference
    # Generate class coverage
    class_areas = compute_class_areas(results[0], image.shape[:2])
    pie_chart_path = generate_coverage_pie_chart(class_areas)
    # plot() returns a BGR array; flip it back to RGB for display in Gradio.
    rendered_image = results[0].plot()[..., ::-1]
    return rendered_image, pie_chart_path
# Dynamically generate Gradio examples
def generate_examples(json_file="examples.json"):
    """
    Build the Gradio ``examples`` rows from the examples manifest.

    Each row is ``[image_path, current_model_name]``, matching the interface
    inputs (image, model dropdown). Only images that exist in ``examples_dir``
    are returned. A failed download would otherwise hand Gradio a path to a
    missing file, and with ``cache_examples=True`` every example is processed
    at startup.

    Args:
        json_file: Path to the examples manifest (read as
            ``data["examples"]``, entries with a ``"name"`` key).

    Returns:
        list of ``[image_path, model_name]`` rows; empty if the manifest is
        missing.
    """
    if not os.path.exists(json_file):
        return []
    with open(json_file, "r", encoding="utf-8") as f:
        data = json.load(f)
    rows = []
    for example in data["examples"]:
        image_name = example["name"].replace(" ", "_").lower() + ".jpg"
        image_path = os.path.join(examples_dir, image_name)
        if os.path.exists(image_path):  # skip examples that failed to download
            rows.append([image_path, current_model_name])
    return rows
# Gradio interface components.
# In: a numpy image plus a model picker (defaulting to the loaded model).
# Out: the annotated image and the pie chart, returned as a saved file path.
inputs = [
    gr.Image(label="Input Image", type="numpy"),
    gr.Dropdown(choices=model_names, value=current_model_name, label="Model Type"),
]
outputs = [
    gr.Image(label="Segmented Image", type="numpy"),
    gr.Image(label="Class Coverage Pie Chart", type="filepath"),
]
# HTML link bar (home page, documentation, GitHub, Discord) shown above the
# interface via gr.Markdown; each link opens in a new tab (target="_blank").
icons = """
<a href="https://reef.support" target="_blank" style="margin-right: 20px;">
<img src="https://img.icons8.com/ios-filled/50/000000/home.png" alt="Home" style="width: 30px; height: 30px;">
</a>
<a href="https://docs.reef.support" target="_blank" style="margin-right: 20px;">
<img src="https://img.icons8.com/ios-filled/50/000000/document.png" alt="Documentation" style="width: 30px; height: 30px;">
</a>
<a href="https://github.com/reefsupport" target="_blank" style="margin-right: 20px;">
<img src="https://img.icons8.com/ios-filled/50/000000/github.png" alt="GitHub" style="width: 30px; height: 30px;">
</a>
<a href="https://url.reef.support/join-discord" target="_blank">
<img src="https://img.icons8.com/ios-filled/50/000000/discord-logo.png" alt="Discord" style="width: 30px; height: 30px;">
</a>
"""
title = "Coral AI Demo"
# Example rows ([image_path, model_name]) built from examples.json.
examples = generate_examples()

# Page layout: the link bar on top, then the inference interface. Example
# outputs are pre-computed (cache_examples=True) when the app starts.
# (The previous unused `demo_app = gr.Blocks()` instance has been removed.)
with gr.Blocks() as demo:
    gr.Markdown(icons)
    gr.Interface(
        fn=coral_ai_inference,
        inputs=inputs,
        outputs=outputs,
        title=title,
        examples=examples,
        cache_examples=True,
    )
demo.queue().launch(debug=True, share=True)