import gradio as gr
import torch
import os
import json
import requests
import matplotlib.pyplot as plt
from ultralytics import YOLO
# Coral AI model files hosted in the Hugging Face model repository
model_names = [
    "yolov8_xlarge_latest.pt",
]
current_model_name = "yolov8_xlarge_latest.pt"
model_dir = "models"
examples_dir = "examples"
os.makedirs(model_dir, exist_ok=True)
os.makedirs(examples_dir, exist_ok=True)
# Download models if not already present locally
HF_MODELS_REPO = "https://huggingface.co/reefsupport/coral-ai/resolve/main/models"
for model_name in model_names:
    model_path = os.path.join(model_dir, model_name)
    if not os.path.exists(model_path):
        print(f"Downloading {model_name}...")
        model_url = f"{HF_MODELS_REPO}/{model_name}"
        torch.hub.download_url_to_file(model_url, model_path)
# Load the initial model
model = YOLO(os.path.join(model_dir, current_model_name))
def download_examples(json_file="examples.json"):
    """
    Download example images specified in the JSON file and save them in the examples directory.
    """
    if not os.path.exists(json_file):
        print(f"Examples JSON file '{json_file}' not found. Skipping example download.")
        return
    with open(json_file, "r") as f:
        data = json.load(f)
    for example in data["examples"]:
        image_name = example["name"].replace(" ", "_").lower() + ".jpg"
        image_path = os.path.join(examples_dir, image_name)
        if not os.path.exists(image_path):  # Skip if already downloaded
            print(f"Downloading {example['name']}...")
            response = requests.get(example["url"], stream=True)
            if response.status_code == 200:
                with open(image_path, "wb") as img_file:
                    for chunk in response.iter_content(1024):
                        img_file.write(chunk)
                print(f"Saved {example['name']} to {image_path}")
            else:
                print(f"Failed to download {example['name']}: {response.status_code}")
# Download example images
download_examples()
def compute_class_areas(predictions, image_shape):
    """
    Compute the area percentage covered by each class for the prediction.
    """
    total_pixels = image_shape[0] * image_shape[1]  # Total pixels in the image
    class_areas = {}
    if predictions.masks is None:  # No segmentation masks predicted
        return class_areas
    # Assumes the mask tensors match the input image resolution; otherwise the
    # percentages are approximate.
    for mask, cls_id in zip(predictions.masks.data, predictions.boxes.cls):
        class_name = predictions.names[int(cls_id.item())]  # Get the class name
        mask_area = mask.sum().item()  # Count non-zero pixels in the mask
        class_areas[class_name] = class_areas.get(class_name, 0) + mask_area
    # Convert areas to percentages relative to the total image
    for class_name in class_areas:
        class_areas[class_name] = (class_areas[class_name] / total_pixels) * 100
    return class_areas
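# Example return value (illustrative class names; the real ones come from the
# model's `names` mapping): {"hard_coral": 34.2, "soft_coral": 12.7}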
def generate_coverage_pie_chart(class_areas):
    """
    Generates a pie chart for class coverage percentages.
    """
    total_percentage = sum(class_areas.values())
    other_percentage = max(0, 100 - total_percentage)
    labels = list(class_areas.keys()) + (["Other"] if other_percentage > 0 else [])
    sizes = list(class_areas.values()) + ([other_percentage] if other_percentage > 0 else [])
    plt.figure(figsize=(8, 6))
    plt.pie(
        sizes,
        labels=[f"{label} ({size:.1f}%)" for label, size in zip(labels, sizes)],
        autopct="%1.1f%%",
        startangle=90,
        colors=plt.cm.tab20.colors,
    )
    plt.axis("equal")
    plt.title("Class Coverage Distribution")
    chart_path = "class_coverage_pie.png"
    plt.savefig(chart_path)
    plt.close()
    return chart_path
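# Note: the chart is saved to a fixed path in the working directory, so concurrent
# requests would overwrite each other's chart; acceptable for a simple demo.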
def coral_ai_inference(image, model_name):
    """
    Perform inference and generate class coverage data.
    """
    global model
    global current_model_name
    # Reload the model only when a different one is selected in the dropdown
    if model_name != current_model_name:
        model = YOLO(os.path.join(model_dir, model_name))
        current_model_name = model_name
    results = model(image)  # Perform inference
    # Generate class coverage
    class_areas = compute_class_areas(results[0], image.shape[:2])
    pie_chart_path = generate_coverage_pie_chart(class_areas)
    # Render the prediction
    rendered_image = results[0].plot()
    return rendered_image, pie_chart_path
# Dynamically generate Gradio examples
def generate_examples(json_file="examples.json"):
    """
    Generate Gradio examples from the examples.json file.
    """
    if not os.path.exists(json_file):
        return []
    with open(json_file, "r") as f:
        data = json.load(f)
    return [
        [os.path.join(examples_dir, example["name"].replace(" ", "_").lower() + ".jpg"), current_model_name]
        for example in data["examples"]
    ]
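# Each example pairs a local image path with the default model, e.g. (illustrative):
# [["examples/example_reef.jpg", "yolov8_xlarge_latest.pt"], ...]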
# Define Gradio interface
inputs = [
    gr.Image(type="numpy", label="Input Image"),
    gr.Dropdown(model_names, value=current_model_name, label="Model Type"),
]
outputs = [
    gr.Image(type="numpy", label="Segmented Image"),
    gr.Image(type="filepath", label="Class Coverage Pie Chart"),
]
# Markdown block for icon links (left empty here; add badge HTML if desired)
icons = """
"""
title = "Coral AI Demo"
examples = generate_examples()
with gr.Blocks() as demo:
    gr.Markdown(icons)
    gr.Interface(
        fn=coral_ai_inference,
        inputs=inputs,
        outputs=outputs,
        title=title,
        examples=examples,
        cache_examples=True,
    )
demo.queue().launch(debug=True, share=True)
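# To try it locally (assuming this script is saved as app.py next to examples.json):
#   pip install gradio ultralytics torch requests matplotlib
#   python app.py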