coral-ai-demo / app.py
import gradio as gr
import torch
import os
import matplotlib.pyplot as plt
from ultralytics import YOLO
# Coral AI model files hosted in the Hugging Face model repository
model_names = [
    "yolov8_xlarge_latest.pt",
    "yolov8_xlarge_v1.pt",
    "yolov8_xlarge_v2.pt",
]
# Set the initial model
current_model_name = "yolov8_xlarge_latest.pt"
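# Local directory where the model weights are downloaded and cached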
model_dir = "models"
os.makedirs(model_dir, exist_ok=True)
# Download models if not already present locally
HF_MODELS_REPO = "https://huggingface.co/reefsupport/coral-ai/resolve/main/models"
for model_name in model_names:
    model_path = os.path.join(model_dir, model_name)
    if not os.path.exists(model_path):
        print(f"Downloading {model_name}...")
        model_url = f"{HF_MODELS_REPO}/{model_name}"
        torch.hub.download_url_to_file(model_url, model_path)
# Load the initial model
model = YOLO(os.path.join(model_dir, current_model_name))

def generate_coverage_graph(results):
    """
    Generates a coverage percentage graph for the predicted classes.
    """
    # Extract segmentation masks and the class ID of each detected instance
    masks = results[0].masks.data     # Segmentation masks (N x H x W tensor)
    class_ids = results[0].boxes.cls  # Class IDs live on the boxes, not the masks

    # Calculate coverage percentages per class, accumulating over all instances
    # of the same class so repeated detections are not overwritten
    total_pixels = masks.sum().item()
    class_percentages = {}
    for cls, mask in zip(class_ids, masks):
        cls = int(cls)
        class_percentages[cls] = class_percentages.get(cls, 0.0) + (mask.sum().item() / total_pixels) * 100

    # Generate the graph
    plt.figure(figsize=(8, 6))
    plt.bar(list(class_percentages.keys()), list(class_percentages.values()))
    plt.xlabel("Class ID")
    plt.ylabel("Coverage Percentage")
    plt.title("Class Coverage in Segmentation")
    plt.xticks(list(class_percentages.keys()))
    graph_path = "class_coverage.png"
    plt.savefig(graph_path)
    plt.close()
    return graph_path


def coral_ai_inference(image: str, model_name: str):
    """
    Coral AI inference function with class coverage graph.

    Args:
        image: Input image filepath
        model_name: Name of the model

    Returns:
        Rendered image and class coverage graph
    """
    global model
    global current_model_name

    # Reload the model only when a different one is selected
    if model_name != current_model_name:
        model = YOLO(os.path.join(model_dir, model_name))
        current_model_name = model_name

    # Perform inference (predict returns a list of Results objects)
    results = model.predict(image)

    # Rendered image: plot() returns a BGR array, so flip channels to RGB for display
    rendered_image = results[0].plot()[:, :, ::-1]

    # Generate class coverage graph
    graph_path = generate_coverage_graph(results)
    return rendered_image, graph_path

# Define Gradio inputs and outputs
inputs = [
    gr.Image(type="filepath", label="Input Image"),
    gr.Dropdown(
        model_names,
        value=current_model_name,
        label="Model Type",
    ),
]
outputs = [
    gr.Image(type="filepath", label="Segmented Image"),
    gr.Image(type="filepath", label="Class Coverage Graph"),
]
title = "Coral AI YOLOv8 Segmentation Demo"
examples = [
    ["examples/coral_image1.jpg", "yolov8_xlarge_latest.pt"],
    ["examples/coral_image2.jpg", "yolov8_xlarge_latest.pt"],
    ["examples/coral_image3.jpg", "yolov8_xlarge_latest.pt"],
]
# Create and launch the Gradio interface
demo_app = gr.Interface(
    fn=coral_ai_inference,
    inputs=inputs,
    outputs=outputs,
    title=title,
    examples=examples,
    cache_examples=True,
    theme="default",
)
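# Queue incoming requests and launch the app with debug output enabled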
demo_app.queue().launch(debug=True)