Spaces: Running on Zero

Upload 2 files

Files changed:
- evaluation_metrics.py (+72 -90)
- visualization_helper.py (+37 -28)
evaluation_metrics.py
CHANGED (+72 -90; whitespace-only changes hidden)

@@ -4,85 +4,85 @@ from typing import Dict, List, Any, Optional, Tuple

class EvaluationMetrics:
    """Class for computing detection metrics, generating statistics and visualization data"""

    @staticmethod
    def calculate_basic_stats(result: Any) -> Dict:
        """
        Calculate basic statistics for a single detection result

        Args:
            result: Detection result object

        Returns:
            Dictionary with basic statistics
        """
        if result is None:
            return {"error": "No detection result provided"}

        # Get classes and confidences
        classes = result.boxes.cls.cpu().numpy().astype(int)
        confidences = result.boxes.conf.cpu().numpy()
        names = result.names

        # Count by class
        class_counts = {}
        for cls, conf in zip(classes, confidences):
            cls_name = names[int(cls)]
            if cls_name not in class_counts:
                class_counts[cls_name] = {"count": 0, "total_confidence": 0, "confidences": []}

            class_counts[cls_name]["count"] += 1
            class_counts[cls_name]["total_confidence"] += float(conf)
            class_counts[cls_name]["confidences"].append(float(conf))

        # Calculate average confidence
        for cls_name, stats in class_counts.items():
            if stats["count"] > 0:
                stats["average_confidence"] = stats["total_confidence"] / stats["count"]
                stats["confidence_std"] = float(np.std(stats["confidences"])) if len(stats["confidences"]) > 1 else 0
                stats.pop("total_confidence")  # Remove intermediate calculation

        # Prepare summary
        stats = {
            "total_objects": len(classes),
            "class_statistics": class_counts,
            "average_confidence": float(np.mean(confidences)) if len(confidences) > 0 else 0
        }

        return stats

    @staticmethod
    def generate_visualization_data(result: Any, class_colors: Dict = None) -> Dict:
        """
        Generate structured data suitable for visualization

        Args:
            result: Detection result object
            class_colors: Dictionary mapping class names to color codes (optional)

        Returns:
            Dictionary with visualization-ready data
        """
        if result is None:
            return {"error": "No detection result provided"}

        # Get basic stats first
        stats = EvaluationMetrics.calculate_basic_stats(result)

        # Create visualization-specific data structure
        viz_data = {
            "total_objects": stats["total_objects"],
            "average_confidence": stats["average_confidence"],
            "class_data": []
        }

        # Sort classes by count (descending)
        sorted_classes = sorted(
            stats["class_statistics"].items(),
            key=lambda x: x[1]["count"],
            reverse=True
        )

        # Create class-specific visualization data
        for cls_name, cls_stats in sorted_classes:
            class_id = -1
@@ -91,7 +91,7 @@ class EvaluationMetrics:
                if name == cls_name:
                    class_id = idx
                    break

            cls_data = {
                "name": cls_name,
                "class_id": class_id,
@@ -100,21 +100,21 @@ class EvaluationMetrics:
                "confidence_std": cls_stats.get("confidence_std", 0),
                "color": class_colors.get(cls_name, "#CCCCCC") if class_colors else "#CCCCCC"
            }

            viz_data["class_data"].append(cls_data)

        return viz_data

    @staticmethod
    def create_stats_plot(viz_data: Dict, figsize: Tuple[int, int] = (10, 7), max_classes: int = 30) -> plt.Figure:
        """
        Create a horizontal bar chart showing detection statistics

        Args:
            viz_data: Visualization data generated by generate_visualization_data
            figsize: Figure size (width, height) in inches
            max_classes: Maximum number of classes to display

        Returns:
            Matplotlib figure object
        """
@@ -125,79 +125,79 @@ class EvaluationMetrics:
    def create_enhanced_stats_plot(viz_data: Dict, figsize: Tuple[int, int] = (10, 7), max_classes: int = 30) -> plt.Figure:
        """
        Create an enhanced horizontal bar chart with larger fonts and better styling

        Args:
            viz_data: Visualization data dictionary
            figsize: Figure size (width, height) in inches
            max_classes: Maximum number of classes to display

        Returns:
            Matplotlib figure with enhanced styling
        """
        if "error" in viz_data:
            # Create empty plot if error
            fig, ax = plt.subplots(figsize=figsize)
            ax.text(0.5, 0.5, viz_data["error"],
                    ha='center', va='center', fontsize=14, fontfamily='Arial')
            ax.set_xlim(0, 1)
            ax.set_ylim(0, 1)
            ax.axis('off')
            return fig

        if "class_data" not in viz_data or not viz_data["class_data"]:
            # Create empty plot if no data
            fig, ax = plt.subplots(figsize=figsize)
            ax.text(0.5, 0.5, "No detection data available",
                    ha='center', va='center', fontsize=14, fontfamily='Arial')
            ax.set_xlim(0, 1)
            ax.set_ylim(0, 1)
            ax.axis('off')
            return fig

        # Limit to max_classes
        class_data = viz_data["class_data"][:max_classes]

        # Extract data for plotting
        class_names = [item["name"] for item in class_data]
        counts = [item["count"] for item in class_data]
        colors = [item["color"] for item in class_data]

        # Create figure and horizontal bar chart with improved styling
        plt.rcParams['font.family'] = 'Arial'
        fig, ax = plt.subplots(figsize=figsize)

        # Set background color to white
        fig.patch.set_facecolor('white')
        ax.set_facecolor('white')

        y_pos = np.arange(len(class_names))

        # Create horizontal bars with class-specific colors
        bars = ax.barh(y_pos, counts, color=colors, alpha=0.8, height=0.6)

        # Add count values at end of each bar with larger font
        for i, bar in enumerate(bars):
            width = bar.get_width()
            conf = class_data[i]["average_confidence"]
            ax.text(width + 0.3, bar.get_y() + bar.get_height()/2,
                    f"{width:.0f} (conf: {conf:.2f})",
                    va='center', fontsize=12, fontfamily='Arial')

        # Customize axis and labels with larger fonts
        ax.set_yticks(y_pos)
        ax.set_yticklabels(class_names, fontsize=14, fontfamily='Arial')
        ax.invert_yaxis()  # Labels read top-to-bottom
        ax.set_xlabel('Count', fontsize=14, fontfamily='Arial')
        ax.set_title(f'Objects Detected: {viz_data["total_objects"]} Total',
                     fontsize=16, fontfamily='Arial', fontweight='bold')

        # Add grid for better readability
        ax.set_axisbelow(True)
        ax.grid(axis='x', linestyle='--', alpha=0.7, color='#E5E7EB')

        # Increase tick label font size
        ax.tick_params(axis='both', which='major', labelsize=12)

        # Add detection summary as a text box with improved styling
        summary_text = (
            f"Total Objects: {viz_data['total_objects']}\n"
@@ -205,114 +205,96 @@ class EvaluationMetrics:
            f"Unique Classes: {len(viz_data['class_data'])}"
        )
        plt.figtext(0.02, 0.02, summary_text, fontsize=12, fontfamily='Arial',
                    bbox=dict(facecolor='white', alpha=0.9, boxstyle='round,pad=0.5',
                              edgecolor='#E5E7EB'))

        plt.tight_layout()
        return fig

    @staticmethod
    def format_detection_summary(viz_data: Dict) -> str:
-        """
-        Format detection results as a readable text summary with improved spacing
-
-        Args:
-            viz_data: Visualization data generated by generate_visualization_data
-
-        Returns:
-            Formatted text with proper spacing
-        """
        if "error" in viz_data:
            return viz_data["error"]

        if "total_objects" not in viz_data:
            return "No detection data available."

-        # Get the basic statistics
        total_objects = viz_data["total_objects"]
        avg_confidence = viz_data["average_confidence"]

-        # Build the header, using extra blank lines for readability
        lines = [
            f"Detected {total_objects} objects.",
            f"Average confidence: {avg_confidence:.2f}",
-            "",
-            "Objects by class:",
+            "Objects by class:"
        ]

-        # Add class details, giving each class more room
        if "class_data" in viz_data and viz_data["class_data"]:
            for item in viz_data["class_data"]:
                count = item['count']
-                # Use the correct singular/plural form
                item_text = "item" if count == 1 else "items"
-
-                # Add a blank line before each item and indent its details
-                lines.append("\n")  # Blank line before each item
-                lines.append(f"• {item['name']}: {count} {item_text}")
-                lines.append(f"  Confidence: {item['average_confidence']:.2f}")
+                lines.append(f"• {item['name']}: {count} {item_text} (Confidence: {item['average_confidence']:.2f})")
        else:
            lines.append("No class information available.")

        return "\n".join(lines)

    @staticmethod
    def calculate_distance_metrics(result: Any) -> Dict:
        """
        Calculate distance-related metrics for detected objects

        Args:
            result: Detection result object

        Returns:
            Dictionary with distance metrics
        """
        if result is None:
            return {"error": "No detection result provided"}

        boxes = result.boxes.xyxy.cpu().numpy()
        classes = result.boxes.cls.cpu().numpy().astype(int)
        names = result.names

        # Initialize metrics
        metrics = {
            "proximity": {},  # Classes that appear close to each other
            "spatial_distribution": {},  # Distribution across the image
            "size_distribution": {}  # Size distribution of objects
        }

        # Calculate image dimensions (assuming normalized coordinates or extract from result)
        img_width, img_height = 1, 1
        if hasattr(result, "orig_shape"):
            img_height, img_width = result.orig_shape[:2]

        # Calculate bounding box areas and centers
        areas = []
        centers = []
        class_names = []

        for box, cls in zip(boxes, classes):
            x1, y1, x2, y2 = box
            width, height = x2 - x1, y2 - y1
            area = width * height
            center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2

            areas.append(area)
            centers.append((center_x, center_y))
            class_names.append(names[int(cls)])

        # Calculate spatial distribution
        if centers:
            x_coords = [c[0] for c in centers]
            y_coords = [c[1] for c in centers]

            metrics["spatial_distribution"] = {
                "x_mean": float(np.mean(x_coords)) / img_width,
                "y_mean": float(np.mean(y_coords)) / img_height,
                "x_std": float(np.std(x_coords)) / img_width,
                "y_std": float(np.std(y_coords)) / img_height
            }

        # Calculate size distribution
        if areas:
            metrics["size_distribution"] = {
@@ -321,40 +303,40 @@ class EvaluationMetrics:
                "min_area": float(np.min(areas)) / (img_width * img_height),
                "max_area": float(np.max(areas)) / (img_width * img_height)
            }

        # Calculate proximity between different classes
        class_centers = {}
        for cls_name, center in zip(class_names, centers):
            if cls_name not in class_centers:
                class_centers[cls_name] = []
            class_centers[cls_name].append(center)

        # Find classes that appear close to each other
        proximity_pairs = []
        for i, cls1 in enumerate(class_centers.keys()):
            for j, cls2 in enumerate(class_centers.keys()):
                if i >= j:  # Avoid duplicate pairs and self-comparison
                    continue

                # Calculate minimum distance between any two objects of these classes
                min_distance = float('inf')
                for center1 in class_centers[cls1]:
                    for center2 in class_centers[cls2]:
                        dist = np.sqrt((center1[0] - center2[0])**2 + (center1[1] - center2[1])**2)
                        min_distance = min(min_distance, dist)

                # Normalize by image diagonal
                img_diagonal = np.sqrt(img_width**2 + img_height**2)
                norm_distance = min_distance / img_diagonal

                proximity_pairs.append({
                    "class1": cls1,
                    "class2": cls2,
                    "distance": float(norm_distance)
                })

        # Sort by distance and keep the closest pairs
        proximity_pairs.sort(key=lambda x: x["distance"])
        metrics["proximity"] = proximity_pairs[:5]  # Keep top 5 closest pairs

        return metrics
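Usage note: a minimal sketch of how EvaluationMetrics can be driven end to end. It assumes the detections come from an Ultralytics YOLO model, whose Results objects expose the result.boxes / result.names interface the methods above read; the weight file and image name are placeholders.

from ultralytics import YOLO

from evaluation_metrics import EvaluationMetrics

model = YOLO("yolov8n.pt")        # placeholder weights
result = model("street.jpg")[0]   # one Results object per input image

# Per-class counts, mean confidence, and confidence spread
stats = EvaluationMetrics.calculate_basic_stats(result)

# Plot-ready structure (classes sorted by count), then a text summary
viz_data = EvaluationMetrics.generate_visualization_data(result)
print(EvaluationMetrics.format_detection_summary(viz_data))

# Horizontal bar chart of counts per class
fig = EvaluationMetrics.create_enhanced_stats_plot(viz_data)
fig.savefig("detection_stats.png", dpi=150, bbox_inches="tight")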
visualization_helper.py
CHANGED (+37 -28; whitespace-only changes hidden)

@@ -1,34 +1,35 @@
import cv2
import numpy as np
import matplotlib.pyplot as plt
+import matplotlib.patheffects as path_effects
from typing import Any, List, Dict, Tuple, Optional
import io
from PIL import Image

class VisualizationHelper:
    """Helper class for visualizing detection results"""

    @staticmethod
    def visualize_detection(image: Any, result: Any, color_mapper: Optional[Any] = None,
                            figsize: Tuple[int, int] = (12, 12),
                            return_pil: bool = False) -> Optional[Image.Image]:
        """
        Visualize detection results on a single image

        Args:
            image: Image path or numpy array
            result: Detection result object
            color_mapper: ColorMapper instance for consistent colors
            figsize: Figure size
            return_pil: If True, returns a PIL Image object

        Returns:
            PIL Image if return_pil is True, otherwise displays the plot
        """
        if result is None:
            print('No data for visualization')
            return None

        # Read image if path is provided
        if isinstance(image, str):
            img = cv2.imread(image)
@@ -40,19 +41,19 @@ class VisualizationHelper:
        if isinstance(img, np.ndarray):
            # Assuming BGR format from OpenCV
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Create figure
        fig, ax = plt.subplots(figsize=figsize)
        ax.imshow(img)

        # Get bounding boxes, classes and confidences
        boxes = result.boxes.xyxy.cpu().numpy()
        classes = result.boxes.cls.cpu().numpy()
        confs = result.boxes.conf.cpu().numpy()

        # Get class names
        names = result.names

        # Create a default color mapper if none is provided
        if color_mapper is None:
            # For backward compatibility, fallback to a simple color function
@@ -67,29 +68,37 @@ class VisualizationHelper:
            # Convert hex to RGB float values for matplotlib
            hex_color = hex_color.lstrip('#')
            return tuple(int(hex_color[i:i+2], 16) / 255 for i in (0, 2, 4)) + (1.0,)

        # Draw detection results
        for box, cls, conf in zip(boxes, classes, confs):
            x1, y1, x2, y2 = box
            cls_id = int(cls)
            cls_name = names[cls_id]

            # Get color for this class
            box_color = get_color(cls_id)

+            box_width = x2 - x1
+            box_height = y2 - y1
+            box_area = box_width * box_height
+
+            # Scale the font size with the box size, within limits
+            adaptive_fontsize = max(10, min(14, int(10 + box_area / 10000)))
+
+            ax.text(x1, y1 - 8, f'{cls_name}: {conf:.2f}',
+                    color='white', fontsize=adaptive_fontsize, fontweight="bold",
+                    bbox=dict(facecolor=box_color[:3], alpha=0.85, pad=3, boxstyle="round,pad=0.3"),
+                    path_effects=[path_effects.withStroke(linewidth=1.5, foreground="black")])
+
            # Add bounding box
            ax.add_patch(plt.Rectangle((x1, y1), x2-x1, y2-y1,
                         fill=False, edgecolor=box_color[:3], linewidth=2))

        ax.axis('off')
        # ax.set_title('Detection Result')
        plt.tight_layout()

        if return_pil:
            # Convert plot to PIL Image
            buf = io.BytesIO()
@@ -101,47 +110,47 @@ class VisualizationHelper:
        else:
            plt.show()
            return None

    @staticmethod
    def create_summary(result: Any) -> Dict:
        """
        Create a summary of detection results

        Args:
            result: Detection result object

        Returns:
            Dictionary with detection summary statistics
        """
        if result is None:
            return {"error": "No detection result provided"}

        # Get classes and confidences
        classes = result.boxes.cls.cpu().numpy().astype(int)
        confidences = result.boxes.conf.cpu().numpy()
        names = result.names

        # Count detections by class
        class_counts = {}
        for cls, conf in zip(classes, confidences):
            cls_name = names[int(cls)]
            if cls_name not in class_counts:
                class_counts[cls_name] = {"count": 0, "confidences": []}

            class_counts[cls_name]["count"] += 1
            class_counts[cls_name]["confidences"].append(float(conf))

        # Calculate average confidence for each class
        for cls_name, stats in class_counts.items():
            if stats["confidences"]:
                stats["average_confidence"] = float(np.mean(stats["confidences"]))
                stats.pop("confidences")  # Remove detailed confidences list to keep summary concise

        # Prepare summary
        summary = {
            "total_objects": len(classes),
            "class_counts": class_counts,
            "unique_classes": len(class_counts)
        }

        return summary
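Usage note: a matching sketch for VisualizationHelper, under the same Ultralytics assumption; file names are again placeholders. On the adaptive label size added above: a box of area 10,000 px² gets int(10 + 10000/10000) = 11 pt, and the max(10, min(14, ...)) clamp keeps every label between 10 pt and 14 pt regardless of box size.

from ultralytics import YOLO

from visualization_helper import VisualizationHelper

model = YOLO("yolov8n.pt")        # placeholder weights
result = model("street.jpg")[0]

# return_pil=True renders the matplotlib figure to a PIL image, which is
# convenient to return from a Gradio handler on a Space
annotated = VisualizationHelper.visualize_detection("street.jpg", result,
                                                    return_pil=True)
if annotated is not None:
    annotated.save("annotated.jpg")

print(VisualizationHelper.create_summary(result))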