"""Gradio demo for Coral AI.

Downloads the YOLO model weights and example images on startup, then serves
an interface that runs segmentation on an uploaded image and reports the
share of the image covered by each predicted class as a pie chart.
"""

import json
import os

import gradio as gr
import matplotlib.pyplot as plt
import requests
import torch
from ultralytics import YOLO

# Coral AI model files hosted in the Hugging Face model repository
model_names = [
    "yolov8_xlarge_latest.pt",
]
current_model_name = "yolov8_xlarge_latest.pt"

model_dir = "models"
examples_dir = "examples"
os.makedirs(model_dir, exist_ok=True)
os.makedirs(examples_dir, exist_ok=True)

# Upper bound (seconds) on waiting for an example-image server, so a stalled
# connection cannot block application startup indefinitely.
DOWNLOAD_TIMEOUT_SECONDS = 30

# Download models if not already present locally
HF_MODELS_REPO = "https://huggingface.co/reefsupport/coral-ai/resolve/main/models"
for model_name in model_names:
    model_path = os.path.join(model_dir, model_name)
    if not os.path.exists(model_path):
        print(f"Downloading {model_name}...")
        model_url = f"{HF_MODELS_REPO}/{model_name}"
        torch.hub.download_url_to_file(model_url, model_path)

# Load the initial model
model = YOLO(os.path.join(model_dir, current_model_name))


def _example_image_path(example_name):
    """Return the local path under ``examples_dir`` for the named example."""
    file_name = example_name.replace(" ", "_").lower() + ".jpg"
    return os.path.join(examples_dir, file_name)


def download_examples(json_file="examples.json"):
    """Download the example images listed in *json_file* into ``examples_dir``.

    The JSON file must contain an ``"examples"`` list whose entries have a
    ``"name"`` and a ``"url"`` key. Images that already exist locally are
    skipped. A failed download is reported on stdout and does not abort the
    remaining downloads.

    Args:
        json_file: Path to the JSON file describing the examples. If it does
            not exist, a message is printed and nothing is downloaded.
    """
    if not os.path.exists(json_file):
        print(f"Examples JSON file '{json_file}' not found. Skipping example download.")
        return

    with open(json_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    for example in data["examples"]:
        image_path = _example_image_path(example["name"])
        if os.path.exists(image_path):
            continue  # Skip if already downloaded

        print(f"Downloading {example['name']}...")
        # Stream into a side file and rename on success, so an interrupted
        # transfer never leaves a truncated image that later runs would
        # mistake for a finished download.
        partial_path = image_path + ".part"
        try:
            with requests.get(
                example["url"], stream=True, timeout=DOWNLOAD_TIMEOUT_SECONDS
            ) as response:
                if response.status_code != 200:
                    print(f"Failed to download {example['name']}: {response.status_code}")
                    continue
                with open(partial_path, "wb") as img_file:
                    for chunk in response.iter_content(8192):
                        img_file.write(chunk)
            os.replace(partial_path, image_path)
        except (requests.RequestException, OSError) as err:
            print(f"Failed to download {example['name']}: {err}")
            if os.path.exists(partial_path):
                os.remove(partial_path)
            continue

        print(f"Saved {example['name']} to {image_path}")


# Download example images
download_examples()


def compute_class_areas(predictions, image_shape):
    """Compute the percentage of the image covered by each predicted class.

    Args:
        predictions: A single prediction result exposing ``masks.data``
            (one mask per instance), ``boxes.cls`` (class id per instance)
            and ``names`` (class id -> class name).
        image_shape: ``(height, width)`` of the input image. Used as the
            denominator only when the mask grid size cannot be determined.

    Returns:
        A dict mapping class name to covered area in percent. Empty when the
        prediction contains no masks.
    """
    masks = getattr(predictions, "masks", None)
    if masks is None:
        # No instances were detected (or the model produced no masks).
        return {}

    mask_data = masks.data
    # The masks are laid out on the model's inference grid, whose resolution
    # generally differs from the original image; dividing by the image's
    # pixel count would therefore skew every percentage. Normalise by the
    # mask grid itself instead.
    # NOTE(review): the grid may include a few rows/columns of letterbox
    # padding, which slightly lowers the reported coverage -- confirm whether
    # that matters for this use case.
    total_pixels = mask_data.shape[-2] * mask_data.shape[-1]
    if total_pixels == 0:
        total_pixels = image_shape[0] * image_shape[1]
    if total_pixels == 0:
        return {}

    class_areas = {}
    for mask, cls_id in zip(mask_data, predictions.boxes.cls):
        class_name = predictions.names[int(cls_id.item())]
        mask_area = mask.sum().item()  # Count non-zero pixels in the mask
        class_areas[class_name] = class_areas.get(class_name, 0) + mask_area

    # Convert pixel counts to percentages of the whole grid
    return {
        class_name: (area / total_pixels) * 100
        for class_name, area in class_areas.items()
    }


def generate_coverage_pie_chart(class_areas):
    """Render a pie chart of class coverage percentages and save it as a PNG.

    Whatever is left up to 100% is shown as an ``"Other"`` slice.

    Args:
        class_areas: Dict mapping class name to coverage in percent.

    Returns:
        Path of the saved chart image.
    """
    total_percentage = sum(class_areas.values())
    other_percentage = max(0, 100 - total_percentage)

    labels = list(class_areas.keys())
    sizes = list(class_areas.values())
    if other_percentage > 0:
        labels.append("Other")
        sizes.append(other_percentage)

    chart_path = "class_coverage_pie.png"
    # Use an explicit figure rather than pyplot's implicit "current figure",
    # and always close it so figures do not accumulate across requests.
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        ax.pie(
            sizes,
            labels=[f"{label} ({size:.1f}%)" for label, size in zip(labels, sizes)],
            autopct="%1.1f%%",
            startangle=90,
            colors=plt.cm.tab20.colors,
        )
        ax.axis("equal")
        ax.set_title("Class Coverage Distribution")
        fig.savefig(chart_path)
    finally:
        plt.close(fig)
    return chart_path


def coral_ai_inference(image, model_name):
    """Run segmentation on *image* and build the class coverage chart.

    Args:
        image: Input image as an RGB numpy array, as delivered by the Gradio
            image component.
        model_name: File name of the model to use. The model is reloaded from
            ``model_dir`` when it differs from the currently loaded one.

    Returns:
        A tuple ``(rendered_image, pie_chart_path)``: the annotated image as
        an RGB numpy array and the path of the coverage pie chart.

    Raises:
        gr.Error: If no image was provided.
    """
    global model
    global current_model_name

    if image is None:
        raise gr.Error("Please provide an input image.")

    if model_name != current_model_name:
        model = YOLO(os.path.join(model_dir, model_name))
        current_model_name = model_name

    # Gradio hands over RGB arrays, while Ultralytics interprets ndarray input
    # as BGR and also renders its plots in BGR. Flip the channel order on the
    # way in and on the way out so colours are correct on both sides.
    is_colour = image.ndim == 3 and image.shape[-1] == 3
    model_input = image[..., ::-1].copy() if is_colour else image

    results = model(model_input)  # Perform inference

    # Generate class coverage
    class_areas = compute_class_areas(results[0], image.shape[:2])
    pie_chart_path = generate_coverage_pie_chart(class_areas)

    # Render the prediction
    rendered_image = results[0].plot()
    if rendered_image.ndim == 3 and rendered_image.shape[-1] == 3:
        rendered_image = rendered_image[..., ::-1].copy()

    return rendered_image, pie_chart_path


# Dynamically generate Gradio examples
def generate_examples(json_file="examples.json"):
    """Build the Gradio examples list from *json_file*.

    Args:
        json_file: Path to the JSON file describing the examples.

    Returns:
        A list of ``[image_path, model_name]`` rows, one per example whose
        image is present on disk. Empty when *json_file* does not exist.
    """
    if not os.path.exists(json_file):
        return []

    with open(json_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    rows = []
    for example in data["examples"]:
        image_path = _example_image_path(example["name"])
        # An example whose download failed has no file; listing it would make
        # example caching fail at startup.
        if os.path.exists(image_path):
            rows.append([image_path, current_model_name])
    return rows


# Define Gradio interface
inputs = [
    gr.Image(type="numpy", label="Input Image"),
    gr.Dropdown(model_names, value=current_model_name, label="Model Type"),
]

outputs = [
    gr.Image(type="numpy", label="Segmented Image"),
    gr.Image(type="filepath", label="Class Coverage Pie Chart"),
]

# Icons with links
icons = """ Home Documentation GitHub Discord """

title = "Coral AI Demo"

examples = generate_examples()

# NOTE(review): `demo_app` is never used below; kept only so the module-level
# name remains available.
demo_app = gr.Blocks()

with gr.Blocks() as demo:
    gr.Markdown(icons)
    gr.Interface(
        fn=coral_ai_inference,
        inputs=inputs,
        outputs=outputs,
        title=title,
        examples=examples or None,
        cache_examples=True,
    )

demo.queue().launch(debug=True, share=True)