Spaces:
Runtime error
Runtime error
Yohan Runhaar
commited on
Commit
Β·
cccde55
1
Parent(s):
381b817
fix demo graph
Browse files
app.py
CHANGED
@@ -30,20 +30,20 @@ def compute_class_areas(predictions, image_shape):
|
|
30 |
"""
|
31 |
Compute the area percentage covered by each class for the prediction.
|
32 |
"""
|
33 |
-
total_pixels = image_shape[0] * image_shape[1]
|
34 |
class_areas = {}
|
35 |
|
36 |
for mask, cls_id in zip(predictions.masks.data, predictions.boxes.cls):
|
37 |
-
|
38 |
-
mask_area = mask.sum().item()
|
39 |
|
40 |
-
if
|
41 |
-
class_areas[
|
42 |
-
class_areas[
|
43 |
|
44 |
-
# Convert to percentages
|
45 |
-
for
|
46 |
-
class_areas[
|
47 |
|
48 |
return class_areas
|
49 |
|
@@ -56,9 +56,9 @@ def generate_coverage_graph(class_areas):
|
|
56 |
classes = list(class_areas.keys())
|
57 |
coverage = list(class_areas.values())
|
58 |
plt.bar(classes, coverage, color="skyblue")
|
59 |
-
plt.xlabel("Class
|
60 |
plt.ylabel("Coverage Percentage")
|
61 |
-
plt.title("Class Coverage")
|
62 |
graph_path = "class_coverage.png"
|
63 |
plt.savefig(graph_path)
|
64 |
plt.close()
|
|
|
30 |
"""
|
31 |
Compute the area percentage covered by each class for the prediction.
|
32 |
"""
|
33 |
+
total_pixels = image_shape[0] * image_shape[1] # Total pixels in the image
|
34 |
class_areas = {}
|
35 |
|
36 |
for mask, cls_id in zip(predictions.masks.data, predictions.boxes.cls):
|
37 |
+
class_name = predictions.names[int(cls_id.item())] # Get the class name
|
38 |
+
mask_area = mask.sum().item() # Count non-zero pixels in the mask
|
39 |
|
40 |
+
if class_name not in class_areas:
|
41 |
+
class_areas[class_name] = 0
|
42 |
+
class_areas[class_name] += mask_area
|
43 |
|
44 |
+
# Convert areas to percentages relative to the total image
|
45 |
+
for class_name in class_areas:
|
46 |
+
class_areas[class_name] = (class_areas[class_name] / total_pixels) * 100
|
47 |
|
48 |
return class_areas
|
49 |
|
|
|
56 |
classes = list(class_areas.keys())
|
57 |
coverage = list(class_areas.values())
|
58 |
plt.bar(classes, coverage, color="skyblue")
|
59 |
+
plt.xlabel("Class")
|
60 |
plt.ylabel("Coverage Percentage")
|
61 |
+
plt.title("Class Coverage as % of Total Image")
|
62 |
graph_path = "class_coverage.png"
|
63 |
plt.savefig(graph_path)
|
64 |
plt.close()
|