Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch | |
import os | |
import json | |
import requests | |
import matplotlib.pyplot as plt | |
from ultralytics import YOLO | |
# Coral AI model files hosted in the Hugging Face model repository
model_names = [
    "yolov8_xlarge_latest.pt",
]
# Model loaded at startup; coral_ai_inference reloads the global model when
# a different name is selected.
current_model_name = model_names[0]
model_dir = "models"
examples_dir = "examples"
os.makedirs(model_dir, exist_ok=True)
os.makedirs(examples_dir, exist_ok=True)

# Fetch any model weights that are not cached locally yet
HF_MODELS_REPO = "https://huggingface.co/reefsupport/coral-ai/resolve/main/models"
for model_name in model_names:
    model_path = os.path.join(model_dir, model_name)
    if os.path.exists(model_path):
        continue  # already downloaded on a previous start
    print(f"Downloading {model_name}...")
    model_url = f"{HF_MODELS_REPO}/{model_name}"
    torch.hub.download_url_to_file(model_url, model_path)

# Load the initial model
model = YOLO(os.path.join(model_dir, current_model_name))
def download_examples(json_file="examples.json", timeout=30):
    """
    Download the example images listed in a JSON manifest into ``examples_dir``.

    The manifest's ``"examples"`` list holds entries with ``"name"`` and
    ``"url"`` keys. Each image is saved as
    ``<name, lower-cased, spaces -> underscores>.jpg``. Images that already
    exist locally are skipped.

    Failures (non-200 status, network errors, write errors) are reported with
    ``print`` and do not abort the remaining downloads or app startup.

    Args:
        json_file: Path to the examples manifest. If it does not exist, a
            message is printed and nothing is downloaded.
        timeout: Seconds ``requests`` waits on connect/read for each image.
    """
    if not os.path.exists(json_file):
        print(f"Examples JSON file '{json_file}' not found. Skipping example download.")
        return
    with open(json_file, "r", encoding="utf-8") as f:
        data = json.load(f)
    for example in data.get("examples", []):
        image_name = example["name"].replace(" ", "_").lower() + ".jpg"
        image_path = os.path.join(examples_dir, image_name)
        if os.path.exists(image_path):  # Skip if already downloaded
            continue
        print(f"Downloading {example['name']}...")
        # Stream into a temporary file and move it into place only once it is
        # complete; otherwise an interrupted download would leave a truncated
        # image that the existence check above would skip on every restart.
        tmp_path = image_path + ".part"
        try:
            # The context manager closes the streamed connection on every path.
            with requests.get(example["url"], stream=True, timeout=timeout) as response:
                if response.status_code != 200:
                    print(f"Failed to download {example['name']}: {response.status_code}")
                    continue
                with open(tmp_path, "wb") as img_file:
                    for chunk in response.iter_content(8192):
                        img_file.write(chunk)
            os.replace(tmp_path, image_path)
            print(f"Saved {example['name']} to {image_path}")
        except (requests.RequestException, OSError) as err:
            print(f"Failed to download {example['name']}: {err}")
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
# Fetch the example images listed in examples.json at startup
# (skipped with a message when the manifest is absent).
download_examples(json_file="examples.json")
def compute_class_areas(predictions, image_shape):
    """
    Compute the percentage of the image covered by each predicted class.

    Instance masks of the same class are merged (logical OR) before counting,
    so overlapping instances of one class are not double-counted.

    Args:
        predictions: One ultralytics segmentation ``Results`` object (e.g.
            ``results[0]``). Reads ``masks.data``, ``boxes.cls`` and ``names``.
        image_shape: ``(height, width)`` of the input image. Kept for
            backward compatibility only; the denominator is taken from the
            masks' own resolution, because ultralytics returns ``masks.data``
            at the inference resolution, which generally differs from the
            original image size.

    Returns:
        dict mapping class name -> coverage percentage (0-100). Empty when the
        model produced no masks.
    """
    masks = predictions.masks
    # ultralytics sets ``masks`` to None when nothing was detected.
    if masks is None or len(masks.data) == 0:
        return {}
    mask_data = masks.data
    # NOTE(review): at inference resolution the masks may include letterbox
    # padding, slightly inflating the denominator; predicting with
    # retina_masks=True yields original-resolution masks -- confirm if exact
    # percentages matter.
    height, width = mask_data.shape[-2:]
    total_pixels = height * width
    class_unions = {}
    for mask, cls_id in zip(mask_data, predictions.boxes.cls):
        class_name = predictions.names[int(cls_id.item())]
        covered = mask > 0  # fresh boolean tensor, safe to update in place
        if class_name in class_unions:
            class_unions[class_name] |= covered
        else:
            class_unions[class_name] = covered
    return {
        name: union.sum().item() / total_pixels * 100
        for name, union in class_unions.items()
    }
def generate_coverage_pie_chart(class_areas, chart_path="class_coverage_pie.png"):
    """
    Render per-class coverage percentages as a pie chart and save it as a PNG.

    Whatever share of the image is not attributed to a class (100 minus the
    sum, if positive) is shown as an "Other" slice. Each wedge label shows the
    class's share of the *image*. The in-wedge ``autopct`` text shows its
    share of the *pie*; the two differ when the percentages sum to more than
    100.

    Args:
        class_areas: dict mapping class name -> coverage percentage (0-100).
        chart_path: Output file path. Defaults to ``class_coverage_pie.png``
            in the working directory.

    Returns:
        str: ``chart_path``.
    """
    total_percentage = sum(class_areas.values())
    other_percentage = max(0, 100 - total_percentage)
    labels = list(class_areas.keys())
    sizes = list(class_areas.values())
    if other_percentage > 0:
        labels.append("Other")
        sizes.append(other_percentage)
    # Work on an explicit Figure rather than pyplot's global "current figure",
    # and always close that exact figure (even if saving fails) so figures are
    # never leaked or closed by the wrong caller.
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        ax.pie(
            sizes,
            labels=[f"{label} ({size:.1f}%)" for label, size in zip(labels, sizes)],
            autopct="%1.1f%%",
            startangle=90,
            colors=plt.cm.tab20.colors,
        )
        ax.axis("equal")
        ax.set_title("Class Coverage Distribution")
        fig.savefig(chart_path)
    finally:
        plt.close(fig)
    return chart_path
def coral_ai_inference(image, model_name):
    """
    Run segmentation on one image and summarize per-class coverage.

    Args:
        image: numpy image from ``gr.Image(type="numpy")`` (RGB channel
            order by default), or ``None`` when the user submits without
            uploading an image.
        model_name: Model file name inside ``model_dir``. The global model is
            reloaded only when this differs from the currently loaded one.

    Returns:
        tuple ``(rendered_image, pie_chart_path)``: the RGB image with the
        predictions drawn on it, and the path of the coverage pie chart PNG.

    Raises:
        gr.Error: if no image was provided.
    """
    global model
    global current_model_name
    if image is None:
        raise gr.Error("Please upload an image.")
    if model_name != current_model_name:
        model = YOLO(os.path.join(model_dir, model_name))
        current_model_name = model_name
    # Gradio delivers RGB, but ultralytics treats numpy inputs as BGR
    # (OpenCV convention). Flip the channels so the model sees true colors;
    # .copy() makes the reversed view contiguous for OpenCV preprocessing.
    image_bgr = image[..., ::-1].copy()
    results = model(image_bgr)  # Perform inference
    # Generate class coverage
    class_areas = compute_class_areas(results[0], image.shape[:2])
    pie_chart_path = generate_coverage_pie_chart(class_areas)
    # Results.plot() returns a BGR array; convert back to RGB for display.
    rendered_image = results[0].plot()[..., ::-1].copy()
    return rendered_image, pie_chart_path
# Dynamically generate Gradio examples
def generate_examples(json_file="examples.json"):
    """
    Build the Gradio examples table from the examples.json manifest.

    Each row is ``[image_path, model_name]``. ``image_path`` is the local
    file name that ``download_examples`` uses for the entry, and
    ``model_name`` is the currently selected model. Entries whose image file
    is missing (e.g. the download failed) are skipped with a message, since
    Gradio cannot load or cache an example that does not exist.

    Args:
        json_file: Path to the examples manifest.

    Returns:
        list of ``[str, str]`` rows; empty if the manifest does not exist.
    """
    if not os.path.exists(json_file):
        return []
    with open(json_file, "r", encoding="utf-8") as f:
        data = json.load(f)
    rows = []
    for example in data.get("examples", []):
        image_name = example["name"].replace(" ", "_").lower() + ".jpg"
        image_path = os.path.join(examples_dir, image_name)
        if os.path.exists(image_path):
            rows.append([image_path, current_model_name])
        else:
            print(f"Example image '{image_path}' not found. Skipping example.")
    return rows
# Define Gradio interface
inputs = [
    gr.Image(type="numpy", label="Input Image"),
    gr.Dropdown(model_names, value=current_model_name, label="Model Type"),
]
outputs = [
    gr.Image(type="numpy", label="Segmented Image"),
    gr.Image(type="filepath", label="Class Coverage Pie Chart"),
]
# Icons with links
icons = """
<a href="https://reef.support" target="_blank" style="margin-right: 20px;">
<img src="https://img.icons8.com/ios-filled/50/000000/home.png" alt="Home" style="width: 30px; height: 30px;">
</a>
<a href="https://docs.reef.support" target="_blank" style="margin-right: 20px;">
<img src="https://img.icons8.com/ios-filled/50/000000/document.png" alt="Documentation" style="width: 30px; height: 30px;">
</a>
<a href="https://github.com/reefsupport" target="_blank" style="margin-right: 20px;">
<img src="https://img.icons8.com/ios-filled/50/000000/github.png" alt="GitHub" style="width: 30px; height: 30px;">
</a>
<a href="https://url.reef.support/join-discord" target="_blank">
<img src="https://img.icons8.com/ios-filled/50/000000/discord-logo.png" alt="Discord" style="width: 30px; height: 30px;">
</a>
"""
title = "Coral AI Demo"
examples = generate_examples()

# Build the Interface on its own and embed it in the page layout with
# .render(), the documented way to place an Interface inside a Blocks app.
interface = gr.Interface(
    fn=coral_ai_inference,
    inputs=inputs,
    outputs=outputs,
    title=title,
    examples=examples,
    cache_examples=True,
)

with gr.Blocks() as demo:
    gr.Markdown(icons)
    interface.render()

if __name__ == "__main__":
    # share=True requests a public gradio.live link when run locally.
    demo.queue().launch(debug=True, share=True)