Yohan Runhaar commited on
Commit
cccde55
Β·
1 Parent(s): 381b817

fix demo graph

Browse files
Files changed (1) hide show
  1. app.py +11 -11
app.py CHANGED
@@ -30,20 +30,20 @@ def compute_class_areas(predictions, image_shape):
30
  """
31
  Compute the area percentage covered by each class for the prediction.
32
  """
33
- total_pixels = image_shape[0] * image_shape[1]
34
  class_areas = {}
35
 
36
  for mask, cls_id in zip(predictions.masks.data, predictions.boxes.cls):
37
- cls_id = int(cls_id.item())
38
- mask_area = mask.sum().item()
39
 
40
- if cls_id not in class_areas:
41
- class_areas[cls_id] = 0
42
- class_areas[cls_id] += mask_area
43
 
44
- # Convert to percentages
45
- for cls_id in class_areas:
46
- class_areas[cls_id] = (class_areas[cls_id] / total_pixels) * 100
47
 
48
  return class_areas
49
 
@@ -56,9 +56,9 @@ def generate_coverage_graph(class_areas):
56
  classes = list(class_areas.keys())
57
  coverage = list(class_areas.values())
58
  plt.bar(classes, coverage, color="skyblue")
59
- plt.xlabel("Class ID")
60
  plt.ylabel("Coverage Percentage")
61
- plt.title("Class Coverage")
62
  graph_path = "class_coverage.png"
63
  plt.savefig(graph_path)
64
  plt.close()
 
30
  """
31
  Compute the area percentage covered by each class for the prediction.
32
  """
33
+ total_pixels = image_shape[0] * image_shape[1] # Total pixels in the image
34
  class_areas = {}
35
 
36
  for mask, cls_id in zip(predictions.masks.data, predictions.boxes.cls):
37
+ class_name = predictions.names[int(cls_id.item())] # Get the class name
38
+ mask_area = mask.sum().item() # Count non-zero pixels in the mask
39
 
40
+ if class_name not in class_areas:
41
+ class_areas[class_name] = 0
42
+ class_areas[class_name] += mask_area
43
 
44
+ # Convert areas to percentages relative to the total image
45
+ for class_name in class_areas:
46
+ class_areas[class_name] = (class_areas[class_name] / total_pixels) * 100
47
 
48
  return class_areas
49
 
 
56
  classes = list(class_areas.keys())
57
  coverage = list(class_areas.values())
58
  plt.bar(classes, coverage, color="skyblue")
59
+ plt.xlabel("Class")
60
  plt.ylabel("Coverage Percentage")
61
+ plt.title("Class Coverage as % of Total Image")
62
  graph_path = "class_coverage.png"
63
  plt.savefig(graph_path)
64
  plt.close()