Yohan Runhaar commited on
Commit
4222ab1
Β·
1 Parent(s): 45d03df
Files changed (1) hide show
  1. app.py +39 -38
app.py CHANGED
@@ -6,12 +6,9 @@ from ultralytics import YOLO
6
 
7
  # Coral AI model files hosted in the Hugging Face model repository
8
  model_names = [
9
- "yolov8_xlarge_latest.pt",
10
- # "yolov8_xlarge_v1.pt",
11
- # "yolov8_xlarge_v2.pt",
12
  ]
13
 
14
- # Set the initial model
15
  current_model_name = "yolov8_xlarge_latest.pt"
16
  model_dir = "models"
17
  os.makedirs(model_dir, exist_ok=True)
@@ -29,39 +26,48 @@ for model_name in model_names:
29
  model = YOLO(os.path.join(model_dir, current_model_name))
30
 
31
 
32
- def generate_coverage_graph(results):
33
  """
34
- Generates a coverage percentage graph for the output classes.
35
  """
36
- # Extract segmentation masks and class labels
37
- masks = results[0].masks.data # Segmentation masks
38
- class_ids = results[0].masks.cls # Class IDs
39
 
40
- # Calculate coverage percentages for each class
41
- total_pixels = masks.sum()
42
- class_percentages = {int(cls): (mask.sum() / total_pixels) * 100 for cls, mask in zip(class_ids, masks)}
 
 
 
43
 
44
- # Generate the graph
 
 
 
 
 
 
 
 
 
 
45
  plt.figure(figsize=(8, 6))
46
- plt.bar(class_percentages.keys(), class_percentages.values())
 
 
47
  plt.xlabel("Class ID")
48
  plt.ylabel("Coverage Percentage")
49
- plt.title("Class Coverage in Segmentation")
50
- plt.xticks(list(class_percentages.keys()))
51
  graph_path = "class_coverage.png"
52
  plt.savefig(graph_path)
53
  plt.close()
54
  return graph_path
55
 
56
 
57
- def coral_ai_inference(image: str, model_name: str):
58
  """
59
- Coral AI inference function with class coverage graph.
60
- Args:
61
- image: Input image filepath
62
- model_name: Name of the model
63
- Returns:
64
- Rendered image and class coverage graph
65
  """
66
  global model
67
  global current_model_name
@@ -69,29 +75,26 @@ def coral_ai_inference(image: str, model_name: str):
69
  model = YOLO(os.path.join(model_dir, model_name))
70
  current_model_name = model_name
71
 
72
- # Perform inference
73
- results = model.predict(image, return_outputs=True)
 
 
 
74
 
75
- # Rendered image
76
  rendered_image = results[0].plot()
77
 
78
- # Generate class coverage graph
79
- graph_path = generate_coverage_graph(results)
80
  return rendered_image, graph_path
81
 
82
 
83
- # Define Gradio inputs and outputs
84
  inputs = [
85
- gr.Image(type="filepath", label="Input Image"),
86
- gr.Dropdown(
87
- model_names,
88
- value=current_model_name,
89
- label="Model Type",
90
- ),
91
  ]
92
 
93
  outputs = [
94
- gr.Image(type="filepath", label="Segmented Image"),
95
  gr.Image(type="filepath", label="Class Coverage Graph"),
96
  ]
97
 
@@ -103,7 +106,6 @@ examples = [
103
  # ["examples/coral_image3.jpg", "yolov8_xlarge_latest.pt"],
104
  ]
105
 
106
- # Create and launch the Gradio interface
107
  demo_app = gr.Interface(
108
  fn=coral_ai_inference,
109
  inputs=inputs,
@@ -111,6 +113,5 @@ demo_app = gr.Interface(
111
  title=title,
112
  examples=examples,
113
  cache_examples=True,
114
- theme="default",
115
  )
116
  demo_app.queue().launch(debug=True)
 
6
 
7
  # Coral AI model files hosted in the Hugging Face model repository
8
  model_names = [
9
+ "yolov8_xlarge_latest.pt"
 
 
10
  ]
11
 
 
12
  current_model_name = "yolov8_xlarge_latest.pt"
13
  model_dir = "models"
14
  os.makedirs(model_dir, exist_ok=True)
 
26
  model = YOLO(os.path.join(model_dir, current_model_name))
27
 
28
 
29
+ def compute_class_areas(predictions, image_shape):
30
  """
31
+ Compute the area percentage covered by each class for the prediction.
32
  """
33
+ total_pixels = image_shape[0] * image_shape[1]
34
+ class_areas = {}
35
+ merged_masks = {}
36
 
37
+ for mask, cls in zip(predictions.masks.data, predictions.masks.cls):
38
+ cls_id = int(cls.item())
39
+ if cls_id not in merged_masks:
40
+ merged_masks[cls_id] = mask
41
+ else:
42
+ merged_masks[cls_id] = torch.logical_or(merged_masks[cls_id], mask)
43
 
44
+ for cls_id, mask in merged_masks.items():
45
+ mask_area = mask.sum().item()
46
+ class_areas[cls_id] = (mask_area / total_pixels) * 100
47
+
48
+ return class_areas
49
+
50
+
51
+ def generate_coverage_graph(class_areas):
52
+ """
53
+ Generates a graph for class coverage percentages.
54
+ """
55
  plt.figure(figsize=(8, 6))
56
+ classes = list(class_areas.keys())
57
+ coverage = list(class_areas.values())
58
+ plt.bar(classes, coverage, color="skyblue")
59
  plt.xlabel("Class ID")
60
  plt.ylabel("Coverage Percentage")
61
+ plt.title("Class Coverage")
 
62
  graph_path = "class_coverage.png"
63
  plt.savefig(graph_path)
64
  plt.close()
65
  return graph_path
66
 
67
 
68
+ def coral_ai_inference(image, model_name):
69
  """
70
+ Perform inference and generate class coverage data.
 
 
 
 
 
71
  """
72
  global model
73
  global current_model_name
 
75
  model = YOLO(os.path.join(model_dir, model_name))
76
  current_model_name = model_name
77
 
78
+ results = model(image) # Perform inference
79
+
80
+ # Generate class coverage
81
+ class_areas = compute_class_areas(results[0], image.shape[:2])
82
+ graph_path = generate_coverage_graph(class_areas)
83
 
84
+ # Render the prediction
85
  rendered_image = results[0].plot()
86
 
 
 
87
  return rendered_image, graph_path
88
 
89
 
90
+ # Define Gradio interface
91
  inputs = [
92
+ gr.Image(type="numpy", label="Input Image"),
93
+ gr.Dropdown(model_names, value=current_model_name, label="Model Type"),
 
 
 
 
94
  ]
95
 
96
  outputs = [
97
+ gr.Image(type="numpy", label="Segmented Image"),
98
  gr.Image(type="filepath", label="Class Coverage Graph"),
99
  ]
100
 
 
106
  # ["examples/coral_image3.jpg", "yolov8_xlarge_latest.pt"],
107
  ]
108
 
 
109
  demo_app = gr.Interface(
110
  fn=coral_ai_inference,
111
  inputs=inputs,
 
113
  title=title,
114
  examples=examples,
115
  cache_examples=True,
 
116
  )
117
  demo_app.queue().launch(debug=True)