# VisionScout/evaluation_metrics.py
import numpy as np
import matplotlib.pyplot as plt
from typing import Any, Dict, Optional, Tuple
class EvaluationMetrics:
"""Class for computing detection metrics, generating statistics and visualization data"""
@staticmethod
def calculate_basic_stats(result: Any) -> Dict:
"""
Calculate basic statistics for a single detection result
Args:
result: Detection result object
Returns:
Dictionary with basic statistics
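        Example (illustrative sketch; assumes ``result`` is an ultralytics YOLO
        ``Results`` object and ``model`` is a loaded detection model):
            result = model("image.jpg")[0]
            stats = EvaluationMetrics.calculate_basic_stats(result)
            print(stats["total_objects"], stats["average_confidence"])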
"""
if result is None:
return {"error": "No detection result provided"}
# Get classes and confidences
classes = result.boxes.cls.cpu().numpy().astype(int)
confidences = result.boxes.conf.cpu().numpy()
names = result.names
# Count by class
class_counts = {}
for cls, conf in zip(classes, confidences):
cls_name = names[int(cls)]
if cls_name not in class_counts:
class_counts[cls_name] = {"count": 0, "total_confidence": 0, "confidences": []}
class_counts[cls_name]["count"] += 1
class_counts[cls_name]["total_confidence"] += float(conf)
class_counts[cls_name]["confidences"].append(float(conf))
        # Compute per-class average confidence and its standard deviation
        for cls_name, cls_stats in class_counts.items():
            if cls_stats["count"] > 0:
                cls_stats["average_confidence"] = cls_stats["total_confidence"] / cls_stats["count"]
                cls_stats["confidence_std"] = float(np.std(cls_stats["confidences"])) if len(cls_stats["confidences"]) > 1 else 0.0
            cls_stats.pop("total_confidence")  # Remove the intermediate accumulator
# Prepare summary
stats = {
"total_objects": len(classes),
"class_statistics": class_counts,
"average_confidence": float(np.mean(confidences)) if len(confidences) > 0 else 0
}
return stats
@staticmethod
    def generate_visualization_data(result: Any, class_colors: Optional[Dict] = None) -> Dict:
"""
Generate structured data suitable for visualization
Args:
result: Detection result object
class_colors: Dictionary mapping class names to color codes (optional)
Returns:
Dictionary with visualization-ready data
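        Example (sketch; ``result`` as in calculate_basic_stats, with an
        arbitrary illustrative color mapping):
            viz_data = EvaluationMetrics.generate_visualization_data(
                result, class_colors={"person": "#1F77B4"}
            )
            most_common = viz_data["class_data"][0]["name"]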
"""
if result is None:
return {"error": "No detection result provided"}
# Get basic stats first
stats = EvaluationMetrics.calculate_basic_stats(result)
# Create visualization-specific data structure
viz_data = {
"total_objects": stats["total_objects"],
"average_confidence": stats["average_confidence"],
"class_data": []
}
# Sort classes by count (descending)
sorted_classes = sorted(
stats["class_statistics"].items(),
key=lambda x: x[1]["count"],
reverse=True
)
# Create class-specific visualization data
for cls_name, cls_stats in sorted_classes:
class_id = -1
# Find the class ID based on the name
for idx, name in result.names.items():
if name == cls_name:
class_id = idx
break
cls_data = {
"name": cls_name,
"class_id": class_id,
"count": cls_stats["count"],
"average_confidence": cls_stats.get("average_confidence", 0),
"confidence_std": cls_stats.get("confidence_std", 0),
"color": class_colors.get(cls_name, "#CCCCCC") if class_colors else "#CCCCCC"
}
viz_data["class_data"].append(cls_data)
return viz_data
@staticmethod
def create_stats_plot(viz_data: Dict, figsize: Tuple[int, int] = (10, 7), max_classes: int = 30) -> plt.Figure:
"""
Create a horizontal bar chart showing detection statistics
Args:
viz_data: Visualization data generated by generate_visualization_data
figsize: Figure size (width, height) in inches
max_classes: Maximum number of classes to display
Returns:
Matplotlib figure object
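        Example (sketch; the output filename is illustrative):
            viz_data = EvaluationMetrics.generate_visualization_data(result)
            fig = EvaluationMetrics.create_stats_plot(viz_data, figsize=(12, 8))
            fig.savefig("detection_stats.png")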
"""
# Use the enhanced version
return EvaluationMetrics.create_enhanced_stats_plot(viz_data, figsize, max_classes)
@staticmethod
def create_enhanced_stats_plot(viz_data: Dict, figsize: Tuple[int, int] = (10, 7), max_classes: int = 30) -> plt.Figure:
"""
Create an enhanced horizontal bar chart with larger fonts and better styling
Args:
viz_data: Visualization data dictionary
figsize: Figure size (width, height) in inches
max_classes: Maximum number of classes to display
Returns:
Matplotlib figure with enhanced styling
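        Note:
            create_stats_plot delegates to this method, so both accept the
            same arguments; see its docstring for a usage sketch.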
"""
if "error" in viz_data:
# Create empty plot if error
fig, ax = plt.subplots(figsize=figsize)
ax.text(0.5, 0.5, viz_data["error"],
ha='center', va='center', fontsize=14, fontfamily='Arial')
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.axis('off')
return fig
if "class_data" not in viz_data or not viz_data["class_data"]:
# Create empty plot if no data
fig, ax = plt.subplots(figsize=figsize)
ax.text(0.5, 0.5, "No detection data available",
ha='center', va='center', fontsize=14, fontfamily='Arial')
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.axis('off')
return fig
# Limit to max_classes
class_data = viz_data["class_data"][:max_classes]
# Extract data for plotting
class_names = [item["name"] for item in class_data]
counts = [item["count"] for item in class_data]
colors = [item["color"] for item in class_data]
# Create figure and horizontal bar chart with improved styling
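        # Note: assigning to plt.rcParams mutates global matplotlib state, and
        # matplotlib falls back to its default font if Arial is unavailable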
plt.rcParams['font.family'] = 'Arial'
fig, ax = plt.subplots(figsize=figsize)
# Set background color to white
fig.patch.set_facecolor('white')
ax.set_facecolor('white')
y_pos = np.arange(len(class_names))
# Create horizontal bars with class-specific colors
bars = ax.barh(y_pos, counts, color=colors, alpha=0.8, height=0.6)
# Add count values at end of each bar with larger font
for i, bar in enumerate(bars):
width = bar.get_width()
conf = class_data[i]["average_confidence"]
ax.text(width + 0.3, bar.get_y() + bar.get_height()/2,
f"{width:.0f} (conf: {conf:.2f})",
va='center', fontsize=12, fontfamily='Arial')
# Customize axis and labels with larger fonts
ax.set_yticks(y_pos)
ax.set_yticklabels(class_names, fontsize=14, fontfamily='Arial')
ax.invert_yaxis() # Labels read top-to-bottom
ax.set_xlabel('Count', fontsize=14, fontfamily='Arial')
ax.set_title(f'Objects Detected: {viz_data["total_objects"]} Total',
fontsize=16, fontfamily='Arial', fontweight='bold')
# Add grid for better readability
ax.set_axisbelow(True)
ax.grid(axis='x', linestyle='--', alpha=0.7, color='#E5E7EB')
# Increase tick label font size
ax.tick_params(axis='both', which='major', labelsize=12)
# Add detection summary as a text box with improved styling
summary_text = (
f"Total Objects: {viz_data['total_objects']}\n"
f"Average Confidence: {viz_data['average_confidence']:.2f}\n"
f"Unique Classes: {len(viz_data['class_data'])}"
)
plt.figtext(0.02, 0.02, summary_text, fontsize=12, fontfamily='Arial',
bbox=dict(facecolor='white', alpha=0.9, boxstyle='round,pad=0.5',
edgecolor='#E5E7EB'))
plt.tight_layout()
return fig
@staticmethod
def format_detection_summary(viz_data: Dict) -> str:
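        """
        Format visualization data into a human-readable, multi-line text summary
        Args:
            viz_data: Visualization data generated by generate_visualization_data
        Returns:
            Summary string with one line per detected class
        Example (sketch; ``result`` as in calculate_basic_stats):
            viz_data = EvaluationMetrics.generate_visualization_data(result)
            print(EvaluationMetrics.format_detection_summary(viz_data))
        """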
if "error" in viz_data:
return viz_data["error"]
if "total_objects" not in viz_data:
return "No detection data available."
total_objects = viz_data["total_objects"]
avg_confidence = viz_data["average_confidence"]
lines = [
f"Detected {total_objects} objects.",
f"Average confidence: {avg_confidence:.2f}",
"Objects by class:"
]
if "class_data" in viz_data and viz_data["class_data"]:
for item in viz_data["class_data"]:
count = item['count']
item_text = "item" if count == 1 else "items"
lines.append(f"• {item['name']}: {count} {item_text} (Confidence: {item['average_confidence']:.2f})")
else:
lines.append("No class information available.")
return "\n".join(lines)
@staticmethod
def calculate_distance_metrics(result: Any) -> Dict:
"""
Calculate distance-related metrics for detected objects
Args:
result: Detection result object
Returns:
Dictionary with distance metrics
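        Example (sketch; ``result`` as in calculate_basic_stats):
            metrics = EvaluationMetrics.calculate_distance_metrics(result)
            for pair in metrics["proximity"]:
                print(pair["class1"], pair["class2"], round(pair["distance"], 3))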
"""
if result is None:
return {"error": "No detection result provided"}
boxes = result.boxes.xyxy.cpu().numpy()
classes = result.boxes.cls.cpu().numpy().astype(int)
names = result.names
# Initialize metrics
metrics = {
"proximity": {}, # Classes that appear close to each other
"spatial_distribution": {}, # Distribution across the image
"size_distribution": {} # Size distribution of objects
}
        # Determine image dimensions: prefer the original shape from the result;
        # otherwise fall back to 1x1, which treats coordinates as already normalized
img_width, img_height = 1, 1
if hasattr(result, "orig_shape"):
img_height, img_width = result.orig_shape[:2]
# Calculate bounding box areas and centers
areas = []
centers = []
class_names = []
for box, cls in zip(boxes, classes):
x1, y1, x2, y2 = box
width, height = x2 - x1, y2 - y1
area = width * height
center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2
areas.append(area)
centers.append((center_x, center_y))
class_names.append(names[int(cls)])
# Calculate spatial distribution
if centers:
x_coords = [c[0] for c in centers]
y_coords = [c[1] for c in centers]
metrics["spatial_distribution"] = {
"x_mean": float(np.mean(x_coords)) / img_width,
"y_mean": float(np.mean(y_coords)) / img_height,
"x_std": float(np.std(x_coords)) / img_width,
"y_std": float(np.std(y_coords)) / img_height
}
# Calculate size distribution
if areas:
metrics["size_distribution"] = {
"mean_area": float(np.mean(areas)) / (img_width * img_height),
"std_area": float(np.std(areas)) / (img_width * img_height),
"min_area": float(np.min(areas)) / (img_width * img_height),
"max_area": float(np.max(areas)) / (img_width * img_height)
}
# Calculate proximity between different classes
class_centers = {}
for cls_name, center in zip(class_names, centers):
if cls_name not in class_centers:
class_centers[cls_name] = []
class_centers[cls_name].append(center)
        # Find the closest pair of objects for each pair of distinct classes
        img_diagonal = np.sqrt(img_width**2 + img_height**2)  # for distance normalization
        proximity_pairs = []
for i, cls1 in enumerate(class_centers.keys()):
for j, cls2 in enumerate(class_centers.keys()):
if i >= j: # Avoid duplicate pairs and self-comparison
continue
# Calculate minimum distance between any two objects of these classes
min_distance = float('inf')
for center1 in class_centers[cls1]:
for center2 in class_centers[cls2]:
dist = np.sqrt((center1[0] - center2[0])**2 + (center1[1] - center2[1])**2)
min_distance = min(min_distance, dist)
                # Normalize by the image diagonal so distances are comparable across image sizes
                norm_distance = min_distance / img_diagonal
proximity_pairs.append({
"class1": cls1,
"class2": cls2,
"distance": float(norm_distance)
})
# Sort by distance and keep the closest pairs
proximity_pairs.sort(key=lambda x: x["distance"])
metrics["proximity"] = proximity_pairs[:5] # Keep top 5 closest pairs
return metrics
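if __name__ == "__main__":
    # Minimal end-to-end sketch, not part of the module's API. It assumes the
    # `ultralytics` package is installed and that "image.jpg" exists locally;
    # the checkpoint name and image path are illustrative placeholders.
    from ultralytics import YOLO
    model = YOLO("yolov8n.pt")      # hypothetical model checkpoint
    result = model("image.jpg")[0]  # first Results object in the batch
    viz_data = EvaluationMetrics.generate_visualization_data(result)
    print(EvaluationMetrics.format_detection_summary(viz_data))
    fig = EvaluationMetrics.create_stats_plot(viz_data)
    fig.savefig("detection_stats.png")
    print(EvaluationMetrics.calculate_distance_metrics(result)["proximity"])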