Spaces:
Running
on
Zero
Running
on
Zero
from itertools import combinations
from typing import Dict, List, Any, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
class EvaluationMetrics:
    """Compute detection metrics, statistics and visualization data.

    Every method is a static method: none of them reads instance state, so
    they can be called either on the class or on an instance.

    The ``result`` objects handled here must expose ``result.names``
    (class id -> class name) and ``result.boxes`` with ``cls``, ``conf`` and
    ``xyxy`` members that support ``.cpu().numpy()``.
    NOTE(review): this looks like an Ultralytics YOLO result - confirm
    against the caller.
    """

    @staticmethod
    def calculate_basic_stats(result: Any) -> Dict:
        """
        Calculate basic statistics for a single detection result

        Args:
            result: Detection result object

        Returns:
            Dictionary with ``total_objects``, ``average_confidence`` and
            ``class_statistics`` (per class name: ``count``, ``confidences``,
            ``average_confidence``, ``confidence_std``), or a dictionary with
            a single ``error`` key when ``result`` is None
        """
        if result is None:
            return {"error": "No detection result provided"}

        # Get classes and confidences
        classes = result.boxes.cls.cpu().numpy().astype(int)
        confidences = result.boxes.conf.cpu().numpy()
        names = result.names

        # Group the confidences by class name (insertion order = first seen)
        confidences_by_class: Dict[str, List[float]] = {}
        for cls, conf in zip(classes, confidences):
            confidences_by_class.setdefault(names[int(cls)], []).append(float(conf))

        # Derive the per-class statistics from the grouped confidences
        class_statistics = {}
        for cls_name, values in confidences_by_class.items():
            class_statistics[cls_name] = {
                "count": len(values),
                "confidences": values,
                "average_confidence": sum(values) / len(values),
                # A single detection has no spread
                "confidence_std": float(np.std(values)) if len(values) > 1 else 0,
            }

        # Prepare summary
        return {
            "total_objects": len(classes),
            "class_statistics": class_statistics,
            "average_confidence": float(np.mean(confidences)) if len(confidences) > 0 else 0,
        }

    @staticmethod
    def generate_visualization_data(result: Any, class_colors: Optional[Dict] = None) -> Dict:
        """
        Generate structured data suitable for visualization

        Args:
            result: Detection result object
            class_colors: Dictionary mapping class names to color codes (optional)

        Returns:
            Dictionary with ``total_objects``, ``average_confidence`` and
            ``class_data`` (one entry per detected class, most frequent
            first), or a dictionary with a single ``error`` key when
            ``result`` is None
        """
        if result is None:
            return {"error": "No detection result provided"}

        # Get basic stats first
        stats = EvaluationMetrics.calculate_basic_stats(result)

        # Build the name -> id lookup once instead of scanning the names for
        # every class. Accept a mapping or a plain sequence of names, the
        # same inputs that calculate_basic_stats can index.
        names = result.names
        id_name_pairs = names.items() if hasattr(names, "items") else enumerate(names)
        name_to_id: Dict[Any, Any] = {}
        for idx, name in id_name_pairs:
            # setdefault keeps the first id when a name occurs more than once
            name_to_id.setdefault(name, idx)

        colors = class_colors or {}

        # Sort classes by count (descending); the sort is stable, so ties
        # keep their first-seen order
        sorted_classes = sorted(
            stats["class_statistics"].items(),
            key=lambda item: item[1]["count"],
            reverse=True,
        )

        # Create class-specific visualization data
        class_data = [
            {
                "name": cls_name,
                "class_id": name_to_id.get(cls_name, -1),  # -1 = name not found
                "count": cls_stats["count"],
                "average_confidence": cls_stats.get("average_confidence", 0),
                "confidence_std": cls_stats.get("confidence_std", 0),
                "color": colors.get(cls_name, "#CCCCCC"),
            }
            for cls_name, cls_stats in sorted_classes
        ]

        return {
            "total_objects": stats["total_objects"],
            "average_confidence": stats["average_confidence"],
            "class_data": class_data,
        }

    @staticmethod
    def create_stats_plot(viz_data: Dict, figsize: Tuple[int, int] = (10, 7), max_classes: int = 30) -> "plt.Figure":
        """
        Create a horizontal bar chart showing detection statistics

        Args:
            viz_data: Visualization data generated by generate_visualization_data
            figsize: Figure size (width, height) in inches
            max_classes: Maximum number of classes to display

        Returns:
            Matplotlib figure object
        """
        # Use the enhanced version
        return EvaluationMetrics.create_enhanced_stats_plot(viz_data, figsize, max_classes)

    @staticmethod
    def _create_message_plot(message: str, figsize: Tuple[int, int]) -> "plt.Figure":
        """Create an axis-less figure that only shows a centered message."""
        fig, ax = plt.subplots(figsize=figsize)
        ax.text(0.5, 0.5, message,
                ha='center', va='center', fontsize=14, fontfamily='Arial')
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.axis('off')
        return fig

    @staticmethod
    def create_enhanced_stats_plot(viz_data: Dict, figsize: Tuple[int, int] = (10, 7), max_classes: int = 30) -> "plt.Figure":
        """
        Create an enhanced horizontal bar chart with larger fonts and better styling

        Args:
            viz_data: Visualization data dictionary
            figsize: Figure size (width, height) in inches
            max_classes: Maximum number of classes to display

        Returns:
            Matplotlib figure with enhanced styling; a message-only figure
            when ``viz_data`` carries an error or no class data
        """
        if "error" in viz_data:
            return EvaluationMetrics._create_message_plot(viz_data["error"], figsize)

        if "class_data" not in viz_data or not viz_data["class_data"]:
            return EvaluationMetrics._create_message_plot("No detection data available", figsize)

        # Limit to max_classes
        class_data = viz_data["class_data"][:max_classes]

        # Extract data for plotting
        class_names = [item["name"] for item in class_data]
        counts = [item["count"] for item in class_data]
        colors = [item["color"] for item in class_data]

        # NOTE: this changes matplotlib's global default font family, so it
        # also affects figures created after this call.
        plt.rcParams['font.family'] = 'Arial'
        fig, ax = plt.subplots(figsize=figsize)

        # Set background color to white
        fig.patch.set_facecolor('white')
        ax.set_facecolor('white')

        y_pos = np.arange(len(class_names))

        # Create horizontal bars with class-specific colors
        bars = ax.barh(y_pos, counts, color=colors, alpha=0.8, height=0.6)

        # Add count and average confidence at the end of each bar
        for bar, item in zip(bars, class_data):
            width = bar.get_width()
            conf = item["average_confidence"]
            ax.text(width + 0.3, bar.get_y() + bar.get_height() / 2,
                    f"{width:.0f} (conf: {conf:.2f})",
                    va='center', fontsize=12, fontfamily='Arial')

        # Customize axis and labels with larger fonts
        ax.set_yticks(y_pos)
        ax.set_yticklabels(class_names, fontsize=14, fontfamily='Arial')
        ax.invert_yaxis()  # Labels read top-to-bottom
        ax.set_xlabel('Count', fontsize=14, fontfamily='Arial')
        ax.set_title(f'Objects Detected: {viz_data["total_objects"]} Total',
                     fontsize=16, fontfamily='Arial', fontweight='bold')

        # Add grid for better readability
        ax.set_axisbelow(True)
        ax.grid(axis='x', linestyle='--', alpha=0.7, color='#E5E7EB')

        # Increase tick label font size
        ax.tick_params(axis='both', which='major', labelsize=12)

        # Add detection summary as a text box. The count of unique classes
        # uses the full class list, not the max_classes-limited one.
        summary_text = (
            f"Total Objects: {viz_data['total_objects']}\n"
            f"Average Confidence: {viz_data['average_confidence']:.2f}\n"
            f"Unique Classes: {len(viz_data['class_data'])}"
        )
        # Work on this figure explicitly rather than on pyplot's implicit
        # "current figure".
        fig.text(0.02, 0.02, summary_text, fontsize=12, fontfamily='Arial',
                 bbox=dict(facecolor='white', alpha=0.9, boxstyle='round,pad=0.5',
                           edgecolor='#E5E7EB'))
        fig.tight_layout()
        return fig

    @staticmethod
    def format_detection_summary(viz_data: Dict) -> str:
        """
        Format visualization data as a multi-line, human-readable summary

        Args:
            viz_data: Visualization data generated by generate_visualization_data

        Returns:
            The error message when ``viz_data`` carries one, a fixed notice
            when it has no ``total_objects`` entry, otherwise the total, the
            average confidence and one bullet line per class
        """
        if "error" in viz_data:
            return viz_data["error"]
        if "total_objects" not in viz_data:
            return "No detection data available."

        total_objects = viz_data["total_objects"]
        # Missing average is reported as 0 instead of raising KeyError
        avg_confidence = viz_data.get("average_confidence", 0)

        lines = [
            f"Detected {total_objects} objects.",
            f"Average confidence: {avg_confidence:.2f}",
            "Objects by class:"
        ]

        class_data = viz_data.get("class_data")
        if class_data:
            for item in class_data:
                count = item['count']
                item_text = "item" if count == 1 else "items"
                lines.append(f"• {item['name']}: {count} {item_text} (Confidence: {item['average_confidence']:.2f})")
        else:
            lines.append("No class information available.")

        return "\n".join(lines)

    @staticmethod
    def calculate_distance_metrics(result: Any) -> Dict:
        """
        Calculate distance-related metrics for detected objects

        Args:
            result: Detection result object

        Returns:
            Dictionary with ``spatial_distribution`` (mean/std of the box
            centers, divided by image width/height), ``size_distribution``
            (mean/std/min/max box area, divided by the image area) and
            ``proximity`` (up to five class pairs with the smallest
            center-to-center distance, divided by the image diagonal), or a
            dictionary with a single ``error`` key when ``result`` is None
        """
        if result is None:
            return {"error": "No detection result provided"}

        boxes = np.asarray(result.boxes.xyxy.cpu().numpy(), dtype=float)
        classes = result.boxes.cls.cpu().numpy().astype(int)
        names = result.names

        if boxes.size == 0:
            # Give an empty detection set a well-defined (0, 4) shape
            boxes = boxes.reshape(0, 4)

        # Pair boxes and classes positionally, stopping at the shorter one
        num_objects = min(len(boxes), len(classes))
        boxes = boxes[:num_objects]
        class_names = [names[int(cls)] for cls in classes[:num_objects]]

        # Initialize metrics
        metrics = {
            "proximity": [],             # Closest pairs of different classes
            "spatial_distribution": {},  # Distribution across the image
            "size_distribution": {}      # Size distribution of objects
        }

        # Image dimensions come from result.orig_shape when present;
        # otherwise 1 x 1 is used, which leaves every value unscaled
        img_width, img_height = 1, 1
        if hasattr(result, "orig_shape"):
            img_height, img_width = result.orig_shape[:2]
        img_area = img_width * img_height
        img_diagonal = np.sqrt(img_width**2 + img_height**2)

        # Bounding box areas and centers, computed for all boxes at once
        x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
        areas = (x2 - x1) * (y2 - y1)
        centers = np.column_stack(((x1 + x2) / 2, (y1 + y2) / 2))

        if num_objects:
            # Calculate spatial distribution
            metrics["spatial_distribution"] = {
                "x_mean": float(np.mean(centers[:, 0])) / img_width,
                "y_mean": float(np.mean(centers[:, 1])) / img_height,
                "x_std": float(np.std(centers[:, 0])) / img_width,
                "y_std": float(np.std(centers[:, 1])) / img_height
            }

            # Calculate size distribution
            metrics["size_distribution"] = {
                "mean_area": float(np.mean(areas)) / img_area,
                "std_area": float(np.std(areas)) / img_area,
                "min_area": float(np.min(areas)) / img_area,
                "max_area": float(np.max(areas)) / img_area
            }

        # Group the box indices by class name (insertion order = first seen)
        indices_by_class: Dict[Any, List[int]] = {}
        for idx, cls_name in enumerate(class_names):
            indices_by_class.setdefault(cls_name, []).append(idx)

        # Every unordered pair of distinct classes, exactly once
        proximity_pairs = []
        for cls1, cls2 in combinations(indices_by_class, 2):
            points1 = centers[indices_by_class[cls1]]
            points2 = centers[indices_by_class[cls2]]
            # All pairwise center distances between the two classes
            deltas = points1[:, None, :] - points2[None, :, :]
            min_distance = np.sqrt((deltas ** 2).sum(axis=-1)).min()
            proximity_pairs.append({
                "class1": cls1,
                "class2": cls2,
                "distance": float(min_distance / img_diagonal)
            })

        # Sort by distance and keep the closest pairs
        proximity_pairs.sort(key=lambda pair: pair["distance"])
        metrics["proximity"] = proximity_pairs[:5]  # Keep top 5 closest pairs
        return metrics