import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, List, Any, Optional, Tuple


class EvaluationMetrics:
    """Class for computing detection metrics, generating statistics and visualization data.

    All methods are stateless ``@staticmethod``s operating on a detection
    ``result`` object (presumably an ultralytics YOLO ``Results`` instance —
    it must expose ``boxes.cls``, ``boxes.conf``, ``boxes.xyxy`` (tensor-like,
    with ``.cpu().numpy()``), ``names`` (id -> class-name dict), and optionally
    ``orig_shape`` — TODO confirm against the caller).
    """

    @staticmethod
    def calculate_basic_stats(result: Any) -> Dict:
        """
        Calculate basic statistics for a single detection result

        Args:
            result: Detection result object

        Returns:
            Dictionary with basic statistics: ``total_objects``,
            ``class_statistics`` (per-class count / average_confidence /
            confidence_std / raw confidences), and overall ``average_confidence``.
            If ``result`` is None, a dict with an ``"error"`` key is returned.
        """
        if result is None:
            return {"error": "No detection result provided"}

        # Get classes and confidences as plain numpy arrays
        classes = result.boxes.cls.cpu().numpy().astype(int)
        confidences = result.boxes.conf.cpu().numpy()
        names = result.names

        # Accumulate per-class counts and confidence sums
        class_counts = {}
        for cls, conf in zip(classes, confidences):
            cls_name = names[int(cls)]
            if cls_name not in class_counts:
                class_counts[cls_name] = {"count": 0, "total_confidence": 0, "confidences": []}
            class_counts[cls_name]["count"] += 1
            class_counts[cls_name]["total_confidence"] += float(conf)
            class_counts[cls_name]["confidences"].append(float(conf))

        # Derive average/std confidence per class.  Every key present here has
        # count >= 1, so no division-by-zero guard is needed.
        for cls_name, stats in class_counts.items():
            stats["average_confidence"] = stats["total_confidence"] / stats["count"]
            # np.std of a single value is 0 anyway, but the explicit branch keeps
            # the result a plain int 0 for the single-detection case (original behavior).
            stats["confidence_std"] = (
                float(np.std(stats["confidences"])) if len(stats["confidences"]) > 1 else 0
            )
            stats.pop("total_confidence")  # Remove intermediate calculation

        # Prepare summary
        stats = {
            "total_objects": len(classes),
            "class_statistics": class_counts,
            "average_confidence": float(np.mean(confidences)) if len(confidences) > 0 else 0
        }

        return stats

    @staticmethod
    def generate_visualization_data(result: Any, class_colors: Optional[Dict] = None) -> Dict:
        """
        Generate structured data suitable for visualization

        Args:
            result: Detection result object
            class_colors: Dictionary mapping class names to color codes (optional)

        Returns:
            Dictionary with visualization-ready data: ``total_objects``,
            ``average_confidence`` and ``class_data`` (a list sorted by
            descending count, one entry per detected class).
        """
        if result is None:
            return {"error": "No detection result provided"}

        # Get basic stats first
        stats = EvaluationMetrics.calculate_basic_stats(result)

        # Create visualization-specific data structure
        viz_data = {
            "total_objects": stats["total_objects"],
            "average_confidence": stats["average_confidence"],
            "class_data": []
        }

        # Sort classes by count (descending)
        sorted_classes = sorted(
            stats["class_statistics"].items(),
            key=lambda x: x[1]["count"],
            reverse=True
        )

        # Create class-specific visualization data
        for cls_name, cls_stats in sorted_classes:
            # Reverse-lookup the numeric class ID from its display name;
            # -1 marks "not found" (should not happen for names produced by
            # calculate_basic_stats, but kept defensive).
            class_id = -1
            for idx, name in result.names.items():
                if name == cls_name:
                    class_id = idx
                    break

            cls_data = {
                "name": cls_name,
                "class_id": class_id,
                "count": cls_stats["count"],
                "average_confidence": cls_stats.get("average_confidence", 0),
                "confidence_std": cls_stats.get("confidence_std", 0),
                # Fall back to neutral grey when no color mapping is supplied
                "color": class_colors.get(cls_name, "#CCCCCC") if class_colors else "#CCCCCC"
            }
            viz_data["class_data"].append(cls_data)

        return viz_data

    @staticmethod
    def create_stats_plot(viz_data: Dict, figsize: Tuple[int, int] = (10, 7),
                          max_classes: int = 30) -> plt.Figure:
        """
        Create a horizontal bar chart showing detection statistics

        Args:
            viz_data: Visualization data generated by generate_visualization_data
            figsize: Figure size (width, height) in inches
            max_classes: Maximum number of classes to display

        Returns:
            Matplotlib figure object
        """
        # Use the enhanced version (kept as a thin alias for backward compatibility)
        return EvaluationMetrics.create_enhanced_stats_plot(viz_data, figsize, max_classes)

    @staticmethod
    def create_enhanced_stats_plot(viz_data: Dict, figsize: Tuple[int, int] = (10, 7),
                                   max_classes: int = 30) -> plt.Figure:
        """
        Create an enhanced horizontal bar chart with larger fonts and better styling

        Args:
            viz_data: Visualization data dictionary
            figsize: Figure size (width, height) in inches
            max_classes: Maximum number of classes to display

        Returns:
            Matplotlib figure with enhanced styling
        """
        if "error" in viz_data:
            # Create empty plot if error
            fig, ax = plt.subplots(figsize=figsize)
            ax.text(0.5, 0.5, viz_data["error"],
                    ha='center', va='center', fontsize=14, fontfamily='Arial')
            ax.set_xlim(0, 1)
            ax.set_ylim(0, 1)
            ax.axis('off')
            return fig

        if "class_data" not in viz_data or not viz_data["class_data"]:
            # Create empty plot if no data
            fig, ax = plt.subplots(figsize=figsize)
            ax.text(0.5, 0.5, "No detection data available",
                    ha='center', va='center', fontsize=14, fontfamily='Arial')
            ax.set_xlim(0, 1)
            ax.set_ylim(0, 1)
            ax.axis('off')
            return fig

        # Limit to max_classes
        class_data = viz_data["class_data"][:max_classes]

        # Extract data for plotting
        class_names = [item["name"] for item in class_data]
        counts = [item["count"] for item in class_data]
        colors = [item["color"] for item in class_data]

        # Create figure and horizontal bar chart with improved styling.
        # NOTE(review): this mutates global matplotlib state (rcParams) for the
        # whole process; kept for behavioral compatibility.
        plt.rcParams['font.family'] = 'Arial'
        fig, ax = plt.subplots(figsize=figsize)

        # Set background color to white
        fig.patch.set_facecolor('white')
        ax.set_facecolor('white')

        y_pos = np.arange(len(class_names))

        # Create horizontal bars with class-specific colors
        bars = ax.barh(y_pos, counts, color=colors, alpha=0.8, height=0.6)

        # Add count values at end of each bar with larger font
        for i, bar in enumerate(bars):
            width = bar.get_width()
            conf = class_data[i]["average_confidence"]
            ax.text(width + 0.3, bar.get_y() + bar.get_height() / 2,
                    f"{width:.0f} (conf: {conf:.2f})",
                    va='center', fontsize=12, fontfamily='Arial')

        # Customize axis and labels with larger fonts
        ax.set_yticks(y_pos)
        ax.set_yticklabels(class_names, fontsize=14, fontfamily='Arial')
        ax.invert_yaxis()  # Labels read top-to-bottom
        ax.set_xlabel('Count', fontsize=14, fontfamily='Arial')
        ax.set_title(f'Objects Detected: {viz_data["total_objects"]} Total',
                     fontsize=16, fontfamily='Arial', fontweight='bold')

        # Add grid for better readability
        ax.set_axisbelow(True)
        ax.grid(axis='x', linestyle='--', alpha=0.7, color='#E5E7EB')

        # Increase tick label font size
        ax.tick_params(axis='both', which='major', labelsize=12)

        # Add detection summary as a text box with improved styling
        summary_text = (
            f"Total Objects: {viz_data['total_objects']}\n"
            f"Average Confidence: {viz_data['average_confidence']:.2f}\n"
            f"Unique Classes: {len(viz_data['class_data'])}"
        )
        plt.figtext(0.02, 0.02, summary_text, fontsize=12, fontfamily='Arial',
                    bbox=dict(facecolor='white', alpha=0.9, boxstyle='round,pad=0.5',
                              edgecolor='#E5E7EB'))

        plt.tight_layout()
        return fig

    @staticmethod
    def format_detection_summary(viz_data: Dict) -> str:
        """Format visualization data as a human-readable multi-line summary.

        Args:
            viz_data: Output of ``generate_visualization_data``.

        Returns:
            A newline-joined summary string, or the error/fallback message when
            the data is missing or carries an ``"error"`` key.
        """
        if "error" in viz_data:
            return viz_data["error"]

        if "total_objects" not in viz_data:
            return "No detection data available."

        total_objects = viz_data["total_objects"]
        avg_confidence = viz_data["average_confidence"]

        lines = [
            f"Detected {total_objects} objects.",
            f"Average confidence: {avg_confidence:.2f}",
            "Objects by class:"
        ]

        if "class_data" in viz_data and viz_data["class_data"]:
            for item in viz_data["class_data"]:
                count = item['count']
                # Singular/plural noun for the count
                item_text = "item" if count == 1 else "items"
                lines.append(f"• {item['name']}: {count} {item_text} "
                             f"(Confidence: {item['average_confidence']:.2f})")
        else:
            lines.append("No class information available.")

        return "\n".join(lines)

    @staticmethod
    def calculate_distance_metrics(result: Any) -> Dict:
        """
        Calculate distance-related metrics for detected objects

        Args:
            result: Detection result object

        Returns:
            Dictionary with distance metrics: ``proximity`` (up to the 5
            closest cross-class pairs, distance normalized by the image
            diagonal), ``spatial_distribution`` (normalized center mean/std)
            and ``size_distribution`` (normalized box-area stats).
        """
        if result is None:
            return {"error": "No detection result provided"}

        boxes = result.boxes.xyxy.cpu().numpy()
        classes = result.boxes.cls.cpu().numpy().astype(int)
        names = result.names

        # Initialize metrics.  "proximity" is always replaced by a list of
        # pair dicts below, so it is initialized as a list (the original
        # initialized it as a dict, inconsistently with its final type).
        metrics = {
            "proximity": [],             # Classes that appear close to each other
            "spatial_distribution": {},  # Distribution across the image
            "size_distribution": {}      # Size distribution of objects
        }

        # Image dimensions: default to 1x1 (i.e. treat coordinates as
        # normalized) unless the result carries the original shape.
        img_width, img_height = 1, 1
        if hasattr(result, "orig_shape"):
            img_height, img_width = result.orig_shape[:2]

        # Calculate bounding box areas and centers
        areas = []
        centers = []
        class_names = []

        for box, cls in zip(boxes, classes):
            x1, y1, x2, y2 = box
            width, height = x2 - x1, y2 - y1
            area = width * height
            center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2

            areas.append(area)
            centers.append((center_x, center_y))
            class_names.append(names[int(cls)])

        # Calculate spatial distribution (normalized to image size)
        if centers:
            x_coords = [c[0] for c in centers]
            y_coords = [c[1] for c in centers]

            metrics["spatial_distribution"] = {
                "x_mean": float(np.mean(x_coords)) / img_width,
                "y_mean": float(np.mean(y_coords)) / img_height,
                "x_std": float(np.std(x_coords)) / img_width,
                "y_std": float(np.std(y_coords)) / img_height
            }

        # Calculate size distribution (normalized to image area)
        if areas:
            metrics["size_distribution"] = {
                "mean_area": float(np.mean(areas)) / (img_width * img_height),
                "std_area": float(np.std(areas)) / (img_width * img_height),
                "min_area": float(np.min(areas)) / (img_width * img_height),
                "max_area": float(np.max(areas)) / (img_width * img_height)
            }

        # Group object centers by class name
        class_centers = {}
        for cls_name, center in zip(class_names, centers):
            if cls_name not in class_centers:
                class_centers[cls_name] = []
            class_centers[cls_name].append(center)

        # Find classes that appear close to each other
        proximity_pairs = []
        for i, cls1 in enumerate(class_centers.keys()):
            for j, cls2 in enumerate(class_centers.keys()):
                if i >= j:  # Avoid duplicate pairs and self-comparison
                    continue

                # Calculate minimum distance between any two objects of these classes
                min_distance = float('inf')
                for center1 in class_centers[cls1]:
                    for center2 in class_centers[cls2]:
                        dist = np.sqrt((center1[0] - center2[0]) ** 2 +
                                       (center1[1] - center2[1]) ** 2)
                        min_distance = min(min_distance, dist)

                # Normalize by image diagonal
                img_diagonal = np.sqrt(img_width ** 2 + img_height ** 2)
                norm_distance = min_distance / img_diagonal

                proximity_pairs.append({
                    "class1": cls1,
                    "class2": cls2,
                    "distance": float(norm_distance)
                })

        # Sort by distance and keep the closest pairs
        proximity_pairs.sort(key=lambda x: x["distance"])
        metrics["proximity"] = proximity_pairs[:5]  # Keep top 5 closest pairs

        return metrics