# Captured from a Hugging Face Spaces page whose status read "Runtime error".
import gradio as gr | |
import torch | |
import os | |
import matplotlib.pyplot as plt | |
from ultralytics import YOLO | |
# Coral AI model files hosted in the Hugging Face model repository.
# The first entry is the model loaded at startup.
model_names = [
    "yolov8_xlarge_latest.pt",
    "yolov8_xlarge_v1.pt",
    "yolov8_xlarge_v2.pt",
]
# Set the initial model. It is derived from model_names instead of repeating
# the literal, so the default is always one of the offered (and downloaded)
# models.
current_model_name = model_names[0]
model_dir = "models"
os.makedirs(model_dir, exist_ok=True)
# Download models if not already present locally (each file is fetched at
# most once; existing files are reused).
HF_MODELS_REPO = "https://huggingface.co/reefsupport/coral-ai/resolve/main/models"
for model_name in model_names:
    model_path = os.path.join(model_dir, model_name)
    if not os.path.exists(model_path):
        print(f"Downloading {model_name}...")
        model_url = f"{HF_MODELS_REPO}/{model_name}"
        torch.hub.download_url_to_file(model_url, model_path)
# Load the initial model
model = YOLO(os.path.join(model_dir, current_model_name))
def _class_coverage(result):
    """Return ``{class_id: coverage_percent}`` for a single prediction result.

    Coverage is each class's share of all segmented pixels. The mask sums of
    every instance of a class are added together, divided by the mask sum of
    all instances, and multiplied by 100, so the values add up to 100. Pixels
    where instance masks overlap are counted once per instance, as in the
    original computation. This presumes binary (0/1) masks, so that a mask's
    sum is its pixel count; confirm against the Ultralytics version in use.

    Returns an empty dict when the result has no masks or every mask is empty.
    """
    # ``masks`` is None when nothing was segmented; guard ``boxes`` likewise.
    if result.masks is None or result.boxes is None:
        return {}
    masks = result.masks.data  # one mask per detected instance
    # Per-instance class ids live on ``boxes``; ``Masks`` has no ``cls``.
    class_ids = result.boxes.cls.tolist()
    total_pixels = float(masks.sum())
    if total_pixels == 0:
        return {}
    pixels_per_class = {}
    for cls, mask in zip(class_ids, masks):
        # Accumulate: several instances may belong to the same class.
        key = int(cls)
        pixels_per_class[key] = pixels_per_class.get(key, 0.0) + float(mask.sum())
    return {
        key: pixels / total_pixels * 100
        for key, pixels in sorted(pixels_per_class.items())
    }


def generate_coverage_graph(results, graph_path="class_coverage.png"):
    """
    Generates a coverage percentage graph for the output classes.

    Args:
        results: Prediction results as returned by ``model.predict``; only the
            first result is plotted.
        graph_path: File the PNG is written to. The default keeps the original
            location, ``class_coverage.png`` in the working directory.

    Returns:
        ``graph_path``, the path of the written PNG. When no masks were
        detected, the graph contains a "No segmentation masks detected" note
        instead of bars.
    """
    result = results[0]
    coverage = _class_coverage(result)
    # Ultralytics results carry a {class_id: name} mapping; fall back to bare ids.
    names = getattr(result, "names", None)
    if not isinstance(names, dict):
        names = {}

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        if coverage:
            class_ids = list(coverage)
            ax.bar(class_ids, list(coverage.values()))
            ax.set_xticks(class_ids)
            ax.set_xticklabels(
                [f"{i}: {names[i]}" if i in names else str(i) for i in class_ids],
                rotation=45,
                ha="right",
            )
        else:
            ax.text(
                0.5,
                0.5,
                "No segmentation masks detected",
                ha="center",
                va="center",
                transform=ax.transAxes,
            )
        ax.set_xlabel("Class ID")
        ax.set_ylabel("Coverage Percentage")
        ax.set_title("Class Coverage in Segmentation")
        fig.tight_layout()
        fig.savefig(graph_path)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
    return graph_path
def coral_ai_inference(image: str, model_name: str):
    """
    Coral AI inference function with class coverage graph.

    Switches the global model when a different model is requested, runs
    segmentation on the image and plots per-class coverage.

    Args:
        image: Input image filepath (None when the user uploaded nothing).
        model_name: File name of the model weights; must be one of
            ``model_names``.

    Returns:
        Tuple of (rendered image as an RGB numpy array, path of the class
        coverage graph PNG).

    Raises:
        gr.Error: if no image was provided or ``model_name`` is not one of
            the offered models.
    """
    global model
    global current_model_name
    if image is None:
        # Without this, predict() falls back to its own default source and
        # would silently return results for unrelated images.
        raise gr.Error("Please upload an input image.")
    # The name comes from the client and is joined into a filesystem path,
    # and .pt checkpoints are unpickled on load: only accept known models.
    if model_name not in model_names:
        raise gr.Error(f"Unknown model: {model_name}")
    if model_name != current_model_name:
        model = YOLO(os.path.join(model_dir, model_name))
        current_model_name = model_name
    # Perform inference. predict() already returns a list of Results;
    # `return_outputs` is not an accepted argument in current releases.
    results = model.predict(image)
    # Results.plot() returns a BGR array (OpenCV channel order); flip the
    # channel axis so Gradio displays correct RGB colours.
    rendered_image = results[0].plot()[..., ::-1]
    # Generate class coverage graph
    graph_path = generate_coverage_graph(results)
    return rendered_image, graph_path
# Define Gradio inputs and outputs
inputs = [
    gr.Image(type="filepath", label="Input Image"),
    gr.Dropdown(
        model_names,
        value=current_model_name,
        label="Model Type",
    ),
]
outputs = [
    # coral_ai_inference returns the segmented image as a numpy array and the
    # coverage graph as a file path.
    gr.Image(type="numpy", label="Segmented Image"),
    gr.Image(type="filepath", label="Class Coverage Graph"),
]
title = "Coral AI YOLOv8 Segmentation Demo"
# Example rows are [image path, model name]. With cache_examples=True Gradio
# runs the function on every example to cache its outputs, so rows whose image
# file is missing are dropped rather than handed to the model.
examples = [
    [example_path, current_model_name]
    for example_path in (
        "examples/coral_image1.jpg",
        "examples/coral_image2.jpg",
        "examples/coral_image3.jpg",
    )
    if os.path.exists(example_path)
]
# Create and launch the Gradio interface
demo_app = gr.Interface(
    fn=coral_ai_inference,
    inputs=inputs,
    outputs=outputs,
    title=title,
    examples=examples or None,
    # Only cache when there is something to cache.
    cache_examples=bool(examples),
    theme="default",
)
demo_app.queue().launch(debug=True)