Spaces:
Runtime error
Runtime error
Yohan Runhaar
commited on
Commit
Β·
c27b4ef
1
Parent(s):
96309ac
fix gradio
Browse files- app.py +39 -6
- requirements.txt +2 -1
app.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
import os
|
|
|
4 |
from ultralytics import YOLO
|
5 |
|
6 |
# Coral AI model files hosted in the Hugging Face model repository
|
@@ -28,14 +29,39 @@ for model_name in model_names:
|
|
28 |
model = YOLO(os.path.join(model_dir, current_model_name))
|
29 |
|
30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
def coral_ai_inference(image: str, model_name: str):
|
32 |
"""
|
33 |
-
Coral AI inference function
|
34 |
Args:
|
35 |
image: Input image filepath
|
36 |
model_name: Name of the model
|
37 |
Returns:
|
38 |
-
Rendered image
|
39 |
"""
|
40 |
global model
|
41 |
global current_model_name
|
@@ -46,9 +72,12 @@ def coral_ai_inference(image: str, model_name: str):
|
|
46 |
# Perform inference
|
47 |
results = model.predict(image, return_outputs=True)
|
48 |
|
49 |
-
#
|
50 |
-
rendered_image = results[0].plot()
|
51 |
-
|
|
|
|
|
|
|
52 |
|
53 |
|
54 |
# Define Gradio inputs and outputs
|
@@ -61,7 +90,11 @@ inputs = [
|
|
61 |
),
|
62 |
]
|
63 |
|
64 |
-
outputs =
|
|
|
|
|
|
|
|
|
65 |
title = "Coral AI YOLOv8 Segmentation Demo"
|
66 |
|
67 |
examples = [
|
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
import os
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
from ultralytics import YOLO
|
6 |
|
7 |
# Coral AI model files hosted in the Hugging Face model repository
|
|
|
29 |
model = YOLO(os.path.join(model_dir, current_model_name))
|
30 |
|
31 |
|
32 |
+
def generate_coverage_graph(results):
|
33 |
+
"""
|
34 |
+
Generates a coverage percentage graph for the output classes.
|
35 |
+
"""
|
36 |
+
# Extract segmentation masks and class labels
|
37 |
+
masks = results[0].masks.data # Segmentation masks
|
38 |
+
class_ids = results[0].masks.cls # Class IDs
|
39 |
+
|
40 |
+
# Calculate coverage percentages for each class
|
41 |
+
total_pixels = masks.sum()
|
42 |
+
class_percentages = {int(cls): (mask.sum() / total_pixels) * 100 for cls, mask in zip(class_ids, masks)}
|
43 |
+
|
44 |
+
# Generate the graph
|
45 |
+
plt.figure(figsize=(8, 6))
|
46 |
+
plt.bar(class_percentages.keys(), class_percentages.values())
|
47 |
+
plt.xlabel("Class ID")
|
48 |
+
plt.ylabel("Coverage Percentage")
|
49 |
+
plt.title("Class Coverage in Segmentation")
|
50 |
+
plt.xticks(list(class_percentages.keys()))
|
51 |
+
graph_path = "class_coverage.png"
|
52 |
+
plt.savefig(graph_path)
|
53 |
+
plt.close()
|
54 |
+
return graph_path
|
55 |
+
|
56 |
+
|
57 |
def coral_ai_inference(image: str, model_name: str):
|
58 |
"""
|
59 |
+
Coral AI inference function with class coverage graph.
|
60 |
Args:
|
61 |
image: Input image filepath
|
62 |
model_name: Name of the model
|
63 |
Returns:
|
64 |
+
Rendered image and class coverage graph
|
65 |
"""
|
66 |
global model
|
67 |
global current_model_name
|
|
|
72 |
# Perform inference
|
73 |
results = model.predict(image, return_outputs=True)
|
74 |
|
75 |
+
# Rendered image
|
76 |
+
rendered_image = results[0].plot()
|
77 |
+
|
78 |
+
# Generate class coverage graph
|
79 |
+
graph_path = generate_coverage_graph(results)
|
80 |
+
return rendered_image, graph_path
|
81 |
|
82 |
|
83 |
# Define Gradio inputs and outputs
|
|
|
90 |
),
|
91 |
]
|
92 |
|
93 |
+
outputs = [
|
94 |
+
gr.Image(type="filepath", label="Segmented Image"),
|
95 |
+
gr.Image(type="filepath", label="Class Coverage Graph"),
|
96 |
+
]
|
97 |
+
|
98 |
title = "Coral AI YOLOv8 Segmentation Demo"
|
99 |
|
100 |
examples = [
|
requirements.txt
CHANGED
@@ -1,3 +1,4 @@
|
|
1 |
gradio
|
2 |
torch
|
3 |
-
|
|
|
|
1 |
gradio
|
2 |
torch
|
3 |
+
ultralytics
|
4 |
+
matplotlib
|