DawnC committed on
Commit 1487b33 · verified · 1 Parent(s): 5888da9

Upload 2 files

Files changed (2)
  1. app.py +110 -401
  2. image_processor.py +336 -0
app.py CHANGED
@@ -1,309 +1,111 @@
import os
import numpy as np
- import torch
- import cv2
import matplotlib.pyplot as plt
import gradio as gr
- import io
- from PIL import Image, ImageDraw, ImageFont
- import spaces
from typing import Dict, List, Any, Optional, Tuple
- from ultralytics import YOLO

from detection_model import DetectionModel
from color_mapper import ColorMapper
- from visualization_helper import VisualizationHelper
from evaluation_metrics import EvaluationMetrics
from style import Style

-
- color_mapper = ColorMapper()
- model_instances = {}
-
- @spaces.GPU
- def process_image(image, model_instance, confidence_threshold, filter_classes=None):
-     """
-     Process an image for object detection
-
-     Args:
-         image: Input image (numpy array or PIL Image)
-         model_instance: DetectionModel instance to use
-         confidence_threshold: Confidence threshold for detection
-         filter_classes: Optional list of classes to filter results
-
-     Returns:
-         Tuple of (result_image, result_text, stats_data)
-     """
-     # Initialize key variables
-     result = None
-     stats = {}
-     temp_path = None
-
-     try:
-         # Update the confidence threshold
-         model_instance.confidence = confidence_threshold
-
-         # Process the input image
-         if isinstance(image, np.ndarray):
-             # Convert BGR to RGB if needed
-             if image.shape[2] == 3:
-                 image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-             else:
-                 image_rgb = image
-             pil_image = Image.fromarray(image_rgb)
-         elif image is None:
-             return None, "No image provided. Please upload an image.", {}
-         else:
-             pil_image = image
-
-         # Store a temp file for the detector
-         import uuid
-         import tempfile
-
-         temp_dir = tempfile.gettempdir()  # use the system temp directory
-         temp_filename = f"temp_{uuid.uuid4().hex}.jpg"
-         temp_path = os.path.join(temp_dir, temp_filename)
-         pil_image.save(temp_path)
-
-         # Object detection
-         result = model_instance.detect(temp_path)
-
-         if result is None:
-             return None, "Detection failed. Please try again with a different image.", {}
-
-         # Calculate stats
-         stats = EvaluationMetrics.calculate_basic_stats(result)
-
-         # Add spatial metrics
-         spatial_metrics = EvaluationMetrics.calculate_distance_metrics(result)
-         stats["spatial_metrics"] = spatial_metrics
-
-         if filter_classes and len(filter_classes) > 0:
-             # Get classes, boxes, confidence
-             classes = result.boxes.cls.cpu().numpy().astype(int)
-             confs = result.boxes.conf.cpu().numpy()
-             boxes = result.boxes.xyxy.cpu().numpy()
-
-             mask = np.zeros_like(classes, dtype=bool)
-             for cls_id in filter_classes:
-                 mask = np.logical_or(mask, classes == cls_id)
-
-             filtered_stats = {
-                 "total_objects": int(np.sum(mask)),
-                 "class_statistics": {},
-                 "average_confidence": float(np.mean(confs[mask])) if np.any(mask) else 0,
-                 "spatial_metrics": stats["spatial_metrics"]
-             }
-
-             # Update stats
-             names = result.names
-             for cls, conf in zip(classes[mask], confs[mask]):
-                 cls_name = names[int(cls)]
-                 if cls_name not in filtered_stats["class_statistics"]:
-                     filtered_stats["class_statistics"][cls_name] = {
-                         "count": 0,
-                         "average_confidence": 0
-                     }
-
-                 filtered_stats["class_statistics"][cls_name]["count"] += 1
-                 filtered_stats["class_statistics"][cls_name]["average_confidence"] = conf
-
-             stats = filtered_stats
-
-         viz_data = EvaluationMetrics.generate_visualization_data(
-             result,
-             color_mapper.get_all_colors()
-         )
-
-         result_image = VisualizationHelper.visualize_detection(
-             temp_path, result, color_mapper=color_mapper, figsize=(12, 12), return_pil=True
-         )
-
-         result_text = EvaluationMetrics.format_detection_summary(viz_data)
-
-         return result_image, result_text, stats
-
-     except Exception as e:
-         error_message = f"Error occurred: {str(e)}"
-         import traceback
-         traceback.print_exc()
-         print(error_message)
-         return None, error_message, {}
-
-     finally:
-         if temp_path and os.path.exists(temp_path):
-             try:
-                 os.remove(temp_path)
-             except Exception as e:
-                 print(f"Cannot delete temp file {temp_path}: {str(e)}")
-
- def format_result_text(stats):
-     """
-     Format detection statistics into readable text with improved spacing
-
-     Args:
-         stats: Dictionary containing detection statistics
-
-     Returns:
-         Formatted text summary
-     """
-     if not stats or "total_objects" not in stats:
-         return "No objects detected."
-
-     # Keep unnecessary blank lines to a minimum
-     lines = [
-         f"Detected {stats['total_objects']} objects.",
-         f"Average confidence: {stats.get('average_confidence', 0):.2f}",
-         "Objects by class:"
-     ]
-
-     if "class_statistics" in stats and stats["class_statistics"]:
-         # Sort classes by count
-         sorted_classes = sorted(
-             stats["class_statistics"].items(),
-             key=lambda x: x[1]["count"],
-             reverse=True
-         )
-
-         for cls_name, cls_stats in sorted_classes:
-             count = cls_stats["count"]
-             conf = cls_stats.get("average_confidence", 0)
-
-             item_text = "item" if count == 1 else "items"
-             lines.append(f"• {cls_name}: {count} {item_text} (avg conf: {conf:.2f})")
-     else:
-         lines.append("No class information available.")
-
-     # Add spatial information
-     if "spatial_metrics" in stats and "spatial_distribution" in stats["spatial_metrics"]:
-         lines.append("Object Distribution:")
-
-         dist = stats["spatial_metrics"]["spatial_distribution"]
-         x_mean = dist.get("x_mean", 0)
-         y_mean = dist.get("y_mean", 0)
-
-         # Describe the approximate position of the objects
-         if x_mean < 0.33:
-             h_pos = "on the left side"
-         elif x_mean < 0.67:
-             h_pos = "in the center"
-         else:
-             h_pos = "on the right side"
-
-         if y_mean < 0.33:
-             v_pos = "in the upper part"
-         elif y_mean < 0.67:
-             v_pos = "in the middle"
-         else:
-             v_pos = "in the lower part"
-
-         lines.append(f"• Most objects appear {h_pos} {v_pos} of the image")
-
-     return "\n".join(lines)
-
- def format_json_for_display(stats):
-     """
-     Format statistics JSON for better display
-
-     Args:
-         stats: Raw statistics dictionary
-
-     Returns:
-         Formatted statistics structure for display
-     """
-     # Create a cleaner copy of the stats for display
-     display_stats = {}
-
-     # Add summary section
-     display_stats["summary"] = {
-         "total_objects": stats.get("total_objects", 0),
-         "average_confidence": round(stats.get("average_confidence", 0), 3)
-     }
-
-     # Add class statistics in a more organized way
-     if "class_statistics" in stats and stats["class_statistics"]:
-         # Sort classes by count (descending)
-         sorted_classes = sorted(
-             stats["class_statistics"].items(),
-             key=lambda x: x[1].get("count", 0),
-             reverse=True
-         )
-
-         class_stats = {}
-         for cls_name, cls_data in sorted_classes:
-             class_stats[cls_name] = {
-                 "count": cls_data.get("count", 0),
-                 "average_confidence": round(cls_data.get("average_confidence", 0), 3)
-             }
-
-         display_stats["detected_objects"] = class_stats
-
-     # Simplify spatial metrics
-     if "spatial_metrics" in stats:
-         spatial = stats["spatial_metrics"]
-
-         # Simplify spatial distribution
-         if "spatial_distribution" in spatial:
-             dist = spatial["spatial_distribution"]
-             display_stats["spatial"] = {
-                 "distribution": {
-                     "x_mean": round(dist.get("x_mean", 0), 3),
-                     "y_mean": round(dist.get("y_mean", 0), 3),
-                     "x_std": round(dist.get("x_std", 0), 3),
-                     "y_std": round(dist.get("y_std", 0), 3)
-                 }
-             }
-
-         # Add simplified size information
-         if "size_distribution" in spatial:
-             size = spatial["size_distribution"]
-             display_stats["spatial"]["size"] = {
-                 "mean_area": round(size.get("mean_area", 0), 3),
-                 "min_area": round(size.get("min_area", 0), 3),
-                 "max_area": round(size.get("max_area", 0), 3)
-             }
-
-     return display_stats

def get_all_classes():
    """
    Get all available COCO classes from the currently active model or fall back to standard COCO classes
-
    Returns:
        List of tuples (class_id, class_name)
    """
-     global model_instances
-
    # Try to get class names from any loaded model
-     for model_name, model_instance in model_instances.items():
        if model_instance and model_instance.is_model_loaded:
            try:
                class_names = model_instance.class_names
                return [(idx, name) for idx, name in class_names.items()]
            except Exception:
                pass
-
    # Fall back to standard COCO classes
    return [
        (0, 'person'), (1, 'bicycle'), (2, 'car'), (3, 'motorcycle'), (4, 'airplane'),
        (5, 'bus'), (6, 'train'), (7, 'truck'), (8, 'boat'), (9, 'traffic light'),
-         (10, 'fire hydrant'), (11, 'stop sign'), (12, 'parking meter'), (13, 'bench'),
        (14, 'bird'), (15, 'cat'), (16, 'dog'), (17, 'horse'), (18, 'sheep'), (19, 'cow'),
        (20, 'elephant'), (21, 'bear'), (22, 'zebra'), (23, 'giraffe'), (24, 'backpack'),
        (25, 'umbrella'), (26, 'handbag'), (27, 'tie'), (28, 'suitcase'), (29, 'frisbee'),
        (30, 'skis'), (31, 'snowboard'), (32, 'sports ball'), (33, 'kite'), (34, 'baseball bat'),
-         (35, 'baseball glove'), (36, 'skateboard'), (37, 'surfboard'), (38, 'tennis racket'),
        (39, 'bottle'), (40, 'wine glass'), (41, 'cup'), (42, 'fork'), (43, 'knife'),
-         (44, 'spoon'), (45, 'bowl'), (46, 'banana'), (47, 'apple'), (48, 'sandwich'),
        (49, 'orange'), (50, 'broccoli'), (51, 'carrot'), (52, 'hot dog'), (53, 'pizza'),
-         (54, 'donut'), (55, 'cake'), (56, 'chair'), (57, 'couch'), (58, 'potted plant'),
        (59, 'bed'), (60, 'dining table'), (61, 'toilet'), (62, 'tv'), (63, 'laptop'),
-         (64, 'mouse'), (65, 'remote'), (66, 'keyboard'), (67, 'cell phone'), (68, 'microwave'),
        (69, 'oven'), (70, 'toaster'), (71, 'sink'), (72, 'refrigerator'), (73, 'book'),
-         (74, 'clock'), (75, 'vase'), (76, 'scissors'), (77, 'teddy bear'), (78, 'hair drier'),
        (79, 'toothbrush')
    ]

def create_interface():
    """Create the Gradio interface with polished visual styling"""
    css = Style.get_css()
@@ -312,11 +114,11 @@ def create_interface():
    available_models = DetectionModel.get_available_models()
    model_choices = [model["model_file"] for model in available_models]
    model_labels = [f"{model['name']} - {model['inference_speed']}" for model in available_models]
-
    # Available class filter options
    available_classes = get_all_classes()
    class_choices = [f"{id}: {name}" for id, name in available_classes]
-
    # Create the Gradio Blocks interface
    with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="teal", secondary_hue="blue")) as demo:
        # Page title
@@ -330,36 +132,36 @@ def create_interface():
        """)

        current_model = gr.State("yolov8m.pt")  # use the medium-size model as the default
-
-         # Main content area - input and output panels
        with gr.Row(equal_height=True):
            # Left side - input control area (image upload)
            with gr.Column(scale=4, elem_classes="input-panel"):
                with gr.Group():
                    gr.HTML('<div class="section-heading">Upload Image</div>')
                    image_input = gr.Image(type="pil", label="Upload an image", elem_classes="upload-box")
-
                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Row():
                        model_dropdown = gr.Dropdown(
                            choices=model_choices,
-                             value="yolov8m.pt",
                            label="Select Model",
                            info="Choose different models based on your needs for speed vs. accuracy"
                        )
-
                    # Display model info
                    model_info = gr.Markdown(DetectionModel.get_model_description("yolov8m.pt"))

                    confidence = gr.Slider(
-                         minimum=0.1,
-                         maximum=0.9,
-                         value=0.25,
-                         step=0.05,
                        label="Confidence Threshold",
                        info="Higher values show fewer but more confident detections"
                    )
-
                with gr.Accordion("Filter Classes", open=False):
                    # Quick-select buttons for common object categories
                    gr.HTML('<div class="section-heading" style="font-size: 1rem;">Common Categories</div>')
@@ -368,7 +170,7 @@ def create_interface():
                    vehicles_btn = gr.Button("Vehicles", size="sm")
                    animals_btn = gr.Button("Animals", size="sm")
                    objects_btn = gr.Button("Common Objects", size="sm")
-
                    # Class selection dropdown
                    class_filter = gr.Dropdown(
                        choices=class_choices,
@@ -376,10 +178,10 @@ def create_interface():
                        label="Select Classes to Display",
                        info="Leave empty to show all detected objects"
                    )
-
                # Detect button
                detect_btn = gr.Button("Detect Objects", variant="primary", elem_classes="detect-btn")
-
                # How-to-use section
                with gr.Group(elem_classes="how-to-use"):
                    gr.HTML('<div class="section-heading">How to Use</div>')
@@ -388,19 +190,19 @@ def create_interface():
                    2. (Optional) Adjust settings like confidence threshold or model size (n, m, x)
                    3. Optionally filter to specific object classes
                    4. Click the "Detect Objects" button
-
                    The model will identify objects in your image and display them with bounding boxes.
-
                    **Note:** Detection quality depends on image clarity and model settings.
                    """)
-
            # Right side - results display area
            with gr.Column(scale=6, elem_classes="output-panel"):
                with gr.Tabs(elem_classes="tabs"):
                    with gr.Tab("Detection Result"):
                        result_image = gr.Image(type="pil", label="Detection Result")
-
-                         # Text box formatting
                        with gr.Group(elem_classes="result-details-box"):
                            gr.HTML('<div class="section-heading">Detection Details</div>')
                            # Text box settings so the display is wider
@@ -410,20 +212,20 @@ def create_interface():
                                max_lines=15,
                                elem_classes="wide-result-text",
                                elem_id="detection-details",
-                                 container=False,
-                                 scale=2,
-                                 min_width=600
                            )
-
                    with gr.Tab("Statistics"):
                        with gr.Row():
                            with gr.Column(scale=3, elem_classes="plot-column"):
                                gr.HTML('<div class="section-heading">Object Distribution</div>')
                                plot_output = gr.Plot(
-                                     label=None,
                                    elem_classes="large-plot-container"
                                )
-
                            # JSON data reads more clearly on the right side
                            with gr.Column(scale=2, elem_classes="stats-column"):
                                gr.HTML('<div class="section-heading">Detection Statistics</div>')
@@ -431,9 +233,9 @@ def create_interface():
                                    label=None,  # remove label
                                    elem_classes="enhanced-json-display"
                                )
-
        detect_btn.click(
-             fn=lambda img, model, conf, classes: process_and_plot(img, model, conf, classes),
            inputs=[image_input, current_model, confidence, class_filter],
            outputs=[result_image, result_text, stats_json, plot_output]
        )
@@ -444,155 +246,62 @@ def create_interface():
            inputs=[model_dropdown],
            outputs=[current_model, model_info]
        )
-
        # Class ID groups for the quick-select buttons
        people_classes = [0]                                   # People
        vehicles_classes = [1, 2, 3, 4, 5, 6, 7, 8]            # Various vehicles
        animals_classes = list(range(14, 24))                  # Animals in COCO
        common_objects = [41, 42, 43, 44, 45, 67, 73, 74, 76]  # Common household items
-
        # Wire up the quick-select buttons
        people_btn.click(
            lambda: [f"{id}: {name}" for id, name in available_classes if id in people_classes],
            outputs=class_filter
        )
-
        vehicles_btn.click(
            lambda: [f"{id}: {name}" for id, name in available_classes if id in vehicles_classes],
            outputs=class_filter
        )
-
        animals_btn.click(
            lambda: [f"{id}: {name}" for id, name in available_classes if id in animals_classes],
            outputs=class_filter
        )
-
        objects_btn.click(
            lambda: [f"{id}: {name}" for id, name in available_classes if id in common_objects],
            outputs=class_filter
        )
-
        example_images = [
            "room_01.jpg",
            "street_01.jpg",
            "street_02.jpg",
            "street_03.jpg"
        ]
-
        # Add example images
        gr.Examples(
            examples=example_images,
            inputs=image_input,
-             outputs=None,
-             fn=None,
-             cache_examples=False,
        )
-
-         # Footer
        gr.HTML("""
        <div class="footer">
            <p>Powered by YOLOv8 and Ultralytics • Created with Gradio</p>
            <p>Model can detect 80 different classes of objects</p>
        </div>
        """)
-
-     return demo
-
- @spaces.GPU
- def process_and_plot(image, model_name, confidence_threshold, filter_classes=None):
-     """
-     Process image and create plots for statistics with enhanced visualization
-
-     Args:
-         image: Input image
-         model_name: Name of the model to use
-         confidence_threshold: Confidence threshold for detection
-         filter_classes: Optional list of classes to filter results
-
-     Returns:
-         Tuple of (result_image, result_text, formatted_stats, plot_figure)
-     """
-     global model_instances
-
-     if model_name not in model_instances:
-         print(f"Creating new model instance for {model_name}")
-         model_instances[model_name] = DetectionModel(model_name=model_name, confidence=confidence_threshold, iou=0.45)
-     else:
-         print(f"Using existing model instance for {model_name}")
-         model_instances[model_name].confidence = confidence_threshold
-
-     class_ids = None
-     if filter_classes:
-         class_ids = []
-         for class_str in filter_classes:
-             try:
-                 # Extract the ID from the "id: name" format
-                 class_id = int(class_str.split(":")[0].strip())
-                 class_ids.append(class_id)
-             except:
-                 continue
-
-     # Execute detection
-     result_image, result_text, stats = process_image(
-         image,
-         model_instances[model_name],
-         confidence_threshold,
-         class_ids
-     )
-
-     # Format the statistics for better display
-     formatted_stats = format_json_for_display(stats)
-
-     if not stats or "class_statistics" not in stats or not stats["class_statistics"]:
-         # Create an empty placeholder figure
-         fig, ax = plt.subplots(figsize=(8, 6))
-         ax.text(0.5, 0.5, "No detection data available",
-                 ha='center', va='center', fontsize=14, fontfamily='Arial')
-         ax.set_xlim(0, 1)
-         ax.set_ylim(0, 1)
-         ax.axis('off')
-         plot_figure = fig
-     else:
-         # Prepare visualization data
-         viz_data = {
-             "total_objects": stats.get("total_objects", 0),
-             "average_confidence": stats.get("average_confidence", 0),
-             "class_data": []
-         }
-
-         # Get the color mapper
-         color_mapper_instance = ColorMapper()
-
-         # Class data
-         available_classes = dict(get_all_classes())
-         for cls_name, cls_stats in stats.get("class_statistics", {}).items():
-             # Look up the class ID by name
-             class_id = -1
-             for id, name in available_classes.items():
-                 if name == cls_name:
-                     class_id = id
-                     break
-
-             cls_data = {
-                 "name": cls_name,
-                 "class_id": class_id,
-                 "count": cls_stats.get("count", 0),
-                 "average_confidence": cls_stats.get("average_confidence", 0),
-                 "color": color_mapper_instance.get_color(class_id if class_id >= 0 else cls_name)
-             }
-
-             viz_data["class_data"].append(cls_data)
-
-         # Sort by count in descending order
-         viz_data["class_data"].sort(key=lambda x: x["count"], reverse=True)
-
-         plot_figure = EvaluationMetrics.create_enhanced_stats_plot(viz_data)
-
-     return result_image, result_text, formatted_stats, plot_figure


if __name__ == "__main__":
    import time
-
    demo = create_interface()
    demo.launch()
 
import os
import numpy as np
import matplotlib.pyplot as plt
import gradio as gr
from typing import Dict, List, Any, Optional, Tuple
+ import spaces

from detection_model import DetectionModel
from color_mapper import ColorMapper
from evaluation_metrics import EvaluationMetrics
from style import Style
+ from image_processor import ImageProcessor

+ # Initialize image processor
+ image_processor = ImageProcessor()
def get_all_classes():
    """
    Get all available COCO classes from the currently active model or fall back to standard COCO classes
+
    Returns:
        List of tuples (class_id, class_name)
    """
    # Try to get class names from any loaded model
+     for model_name, model_instance in image_processor.model_instances.items():
        if model_instance and model_instance.is_model_loaded:
            try:
                class_names = model_instance.class_names
                return [(idx, name) for idx, name in class_names.items()]
            except Exception:
                pass
+
    # Fall back to standard COCO classes
    return [
        (0, 'person'), (1, 'bicycle'), (2, 'car'), (3, 'motorcycle'), (4, 'airplane'),
        (5, 'bus'), (6, 'train'), (7, 'truck'), (8, 'boat'), (9, 'traffic light'),
+         (10, 'fire hydrant'), (11, 'stop sign'), (12, 'parking meter'), (13, 'bench'),
        (14, 'bird'), (15, 'cat'), (16, 'dog'), (17, 'horse'), (18, 'sheep'), (19, 'cow'),
        (20, 'elephant'), (21, 'bear'), (22, 'zebra'), (23, 'giraffe'), (24, 'backpack'),
        (25, 'umbrella'), (26, 'handbag'), (27, 'tie'), (28, 'suitcase'), (29, 'frisbee'),
        (30, 'skis'), (31, 'snowboard'), (32, 'sports ball'), (33, 'kite'), (34, 'baseball bat'),
+         (35, 'baseball glove'), (36, 'skateboard'), (37, 'surfboard'), (38, 'tennis racket'),
        (39, 'bottle'), (40, 'wine glass'), (41, 'cup'), (42, 'fork'), (43, 'knife'),
+         (44, 'spoon'), (45, 'bowl'), (46, 'banana'), (47, 'apple'), (48, 'sandwich'),
        (49, 'orange'), (50, 'broccoli'), (51, 'carrot'), (52, 'hot dog'), (53, 'pizza'),
+         (54, 'donut'), (55, 'cake'), (56, 'chair'), (57, 'couch'), (58, 'potted plant'),
        (59, 'bed'), (60, 'dining table'), (61, 'toilet'), (62, 'tv'), (63, 'laptop'),
+         (64, 'mouse'), (65, 'remote'), (66, 'keyboard'), (67, 'cell phone'), (68, 'microwave'),
        (69, 'oven'), (70, 'toaster'), (71, 'sink'), (72, 'refrigerator'), (73, 'book'),
+         (74, 'clock'), (75, 'vase'), (76, 'scissors'), (77, 'teddy bear'), (78, 'hair drier'),
        (79, 'toothbrush')
    ]
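The "id: name" strings built from these tuples are what the class-filter dropdown displays, and process_and_plot below parses the numeric ID back out of them. A minimal round-trip sketch (plain Python, no Gradio required):

classes = get_all_classes()
class_choices = [f"{id}: {name}" for id, name in classes]      # e.g. "0: person"
selected = ["0: person", "2: car"]                             # what the dropdown returns
class_ids = [int(s.split(":")[0].strip()) for s in selected]   # [0, 2]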
 
+ @spaces.GPU
+ def process_and_plot(image, model_name, confidence_threshold, filter_classes=None):
+     """
+     Process image and create plots for statistics with enhanced visualization
+
+     Args:
+         image: Input image
+         model_name: Name of the model to use
+         confidence_threshold: Confidence threshold for detection
+         filter_classes: Optional list of classes to filter results
+
+     Returns:
+         Tuple of (result_image, result_text, formatted_stats, plot_figure)
+     """
+     class_ids = None
+     if filter_classes:
+         class_ids = []
+         for class_str in filter_classes:
+             try:
+                 # Extract the ID from the "id: name" format
+                 class_id = int(class_str.split(":")[0].strip())
+                 class_ids.append(class_id)
+             except ValueError:
+                 continue
+
+     # Execute detection
+     result_image, result_text, stats = image_processor.process_image(
+         image,
+         model_name,
+         confidence_threshold,
+         class_ids
+     )
+
+     # Format the statistics for better display
+     formatted_stats = image_processor.format_json_for_display(stats)
+
+     if not stats or "class_statistics" not in stats or not stats["class_statistics"]:
+         # Create an empty placeholder figure
+         fig, ax = plt.subplots(figsize=(8, 6))
+         ax.text(0.5, 0.5, "No detection data available",
+                 ha='center', va='center', fontsize=14, fontfamily='Arial')
+         ax.set_xlim(0, 1)
+         ax.set_ylim(0, 1)
+         ax.axis('off')
+         plot_figure = fig
+     else:
+         # Prepare visualization data
+         available_classes = dict(get_all_classes())
+         viz_data = image_processor.prepare_visualization_data(stats, available_classes)
+
+         # Create plot
+         plot_figure = EvaluationMetrics.create_enhanced_stats_plot(viz_data)
+
+     return result_image, result_text, formatted_stats, plot_figure
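For reference, process_and_plot can also be exercised outside the Gradio UI. A sketch, assuming one of the bundled example images is on disk:

from PIL import Image

img = Image.open("street_01.jpg")
result_image, result_text, stats_json, fig = process_and_plot(
    img,
    model_name="yolov8m.pt",
    confidence_threshold=0.25,
    filter_classes=["0: person", "2: car"],  # same "id: name" strings the dropdown produces
)
print(result_text)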
def create_interface():
    """Create the Gradio interface with polished visual styling"""
    css = Style.get_css()

    available_models = DetectionModel.get_available_models()
    model_choices = [model["model_file"] for model in available_models]
    model_labels = [f"{model['name']} - {model['inference_speed']}" for model in available_models]
+
    # Available class filter options
    available_classes = get_all_classes()
    class_choices = [f"{id}: {name}" for id, name in available_classes]
+
    # Create the Gradio Blocks interface
    with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="teal", secondary_hue="blue")) as demo:
        # Page title
 
        """)

        current_model = gr.State("yolov8m.pt")  # use the medium-size model as the default
+
+         # Main content area
        with gr.Row(equal_height=True):
            # Left side - input control area (image upload)
            with gr.Column(scale=4, elem_classes="input-panel"):
                with gr.Group():
                    gr.HTML('<div class="section-heading">Upload Image</div>')
                    image_input = gr.Image(type="pil", label="Upload an image", elem_classes="upload-box")
+
                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Row():
                        model_dropdown = gr.Dropdown(
                            choices=model_choices,
+                             value="yolov8m.pt",
                            label="Select Model",
                            info="Choose different models based on your needs for speed vs. accuracy"
                        )
+
                    # Display model info
                    model_info = gr.Markdown(DetectionModel.get_model_description("yolov8m.pt"))

                    confidence = gr.Slider(
+                         minimum=0.1,
+                         maximum=0.9,
+                         value=0.25,
+                         step=0.05,
                        label="Confidence Threshold",
                        info="Higher values show fewer but more confident detections"
                    )
+
                with gr.Accordion("Filter Classes", open=False):
                    # Quick-select buttons for common object categories
                    gr.HTML('<div class="section-heading" style="font-size: 1rem;">Common Categories</div>')
                    vehicles_btn = gr.Button("Vehicles", size="sm")
                    animals_btn = gr.Button("Animals", size="sm")
                    objects_btn = gr.Button("Common Objects", size="sm")
+
                    # Class selection dropdown
                    class_filter = gr.Dropdown(
                        choices=class_choices,
                        label="Select Classes to Display",
                        info="Leave empty to show all detected objects"
                    )
+
                # Detect button
                detect_btn = gr.Button("Detect Objects", variant="primary", elem_classes="detect-btn")
+
                # How-to-use section
                with gr.Group(elem_classes="how-to-use"):
                    gr.HTML('<div class="section-heading">How to Use</div>')
                    2. (Optional) Adjust settings like confidence threshold or model size (n, m, x)
                    3. Optionally filter to specific object classes
                    4. Click the "Detect Objects" button
+
                    The model will identify objects in your image and display them with bounding boxes.
+
                    **Note:** Detection quality depends on image clarity and model settings.
                    """)
+
            # Right side - results display area
            with gr.Column(scale=6, elem_classes="output-panel"):
                with gr.Tabs(elem_classes="tabs"):
                    with gr.Tab("Detection Result"):
                        result_image = gr.Image(type="pil", label="Detection Result")
+
+                         # Details summary
                        with gr.Group(elem_classes="result-details-box"):
                            gr.HTML('<div class="section-heading">Detection Details</div>')
                            # Text box settings so the display is wider
                                max_lines=15,
                                elem_classes="wide-result-text",
                                elem_id="detection-details",
+                                 container=False,
+                                 scale=2,
+                                 min_width=600
                            )
+
                    with gr.Tab("Statistics"):
                        with gr.Row():
                            with gr.Column(scale=3, elem_classes="plot-column"):
                                gr.HTML('<div class="section-heading">Object Distribution</div>')
                                plot_output = gr.Plot(
+                                     label=None,
                                    elem_classes="large-plot-container"
                                )
+
                            # JSON data reads more clearly on the right side
                            with gr.Column(scale=2, elem_classes="stats-column"):
                                gr.HTML('<div class="section-heading">Detection Statistics</div>')
                                    label=None,  # remove label
                                    elem_classes="enhanced-json-display"
                                )
+
        detect_btn.click(
+             fn=process_and_plot,
            inputs=[image_input, current_model, confidence, class_filter],
            outputs=[result_image, result_text, stats_json, plot_output]
        )

            inputs=[model_dropdown],
            outputs=[current_model, model_info]
        )
+
        # Class ID groups for the quick-select buttons
        people_classes = [0]                                   # People
        vehicles_classes = [1, 2, 3, 4, 5, 6, 7, 8]            # Various vehicles
        animals_classes = list(range(14, 24))                  # Animals in COCO
        common_objects = [41, 42, 43, 44, 45, 67, 73, 74, 76]  # Common household items
+
        # Wire up the quick-select buttons
        people_btn.click(
            lambda: [f"{id}: {name}" for id, name in available_classes if id in people_classes],
            outputs=class_filter
        )
+
        vehicles_btn.click(
            lambda: [f"{id}: {name}" for id, name in available_classes if id in vehicles_classes],
            outputs=class_filter
        )
+
        animals_btn.click(
            lambda: [f"{id}: {name}" for id, name in available_classes if id in animals_classes],
            outputs=class_filter
        )
+
        objects_btn.click(
            lambda: [f"{id}: {name}" for id, name in available_classes if id in common_objects],
            outputs=class_filter
        )
+
        example_images = [
            "room_01.jpg",
            "street_01.jpg",
            "street_02.jpg",
            "street_03.jpg"
        ]
+
        # Add example images
        gr.Examples(
            examples=example_images,
            inputs=image_input,
+             outputs=None,
+             fn=None,
+             cache_examples=False,
        )
+
+         # Footer
        gr.HTML("""
        <div class="footer">
            <p>Powered by YOLOv8 and Ultralytics • Created with Gradio</p>
            <p>Model can detect 80 different classes of objects</p>
        </div>
        """)

+     return demo

if __name__ == "__main__":
    import time
+
    demo = create_interface()
    demo.launch()
image_processor.py ADDED
@@ -0,0 +1,336 @@
+ import os
+ import numpy as np
+ import torch
+ import cv2
+ from PIL import Image
+ import tempfile
+ import uuid
+ from typing import Dict, List, Any, Optional, Tuple
+
+ from detection_model import DetectionModel
+ from color_mapper import ColorMapper
+ from visualization_helper import VisualizationHelper
+ from evaluation_metrics import EvaluationMetrics
+
+ class ImageProcessor:
+     """
+     Class for handling image processing and object detection operations.
+     Separates processing logic from UI components.
+     """
+
+     def __init__(self):
+         """Initialize the image processor with required components"""
+         self.color_mapper = ColorMapper()
+         self.model_instances = {}
+
+     def get_model_instance(self, model_name: str, confidence: float = 0.25, iou: float = 0.35) -> DetectionModel:
+         """
+         Get or create a model instance based on model name
+
+         Args:
+             model_name: Name of the model to use
+             confidence: Confidence threshold for detection
+             iou: IoU threshold for non-maximum suppression
+
+         Returns:
+             DetectionModel instance
+         """
+         if model_name not in self.model_instances:
+             print(f"Creating new model instance for {model_name}")
+             self.model_instances[model_name] = DetectionModel(
+                 model_name=model_name,
+                 confidence=confidence,
+                 iou=iou
+             )
+         else:
+             print(f"Using existing model instance for {model_name}")
+             self.model_instances[model_name].confidence = confidence
+
+         return self.model_instances[model_name]
+
51
+ def process_image(self, image, model_name: str, confidence_threshold: float, filter_classes: Optional[List[int]] = None) -> Tuple[Any, str, Dict]:
52
+ """
53
+ Process an image for object detection
54
+
55
+ Args:
56
+ image: Input image (numpy array or PIL Image)
57
+ model_name: Name of the model to use
58
+ confidence_threshold: Confidence threshold for detection
59
+ filter_classes: Optional list of classes to filter results
60
+
61
+ Returns:
62
+ Tuple of (result_image, result_text, stats_data)
63
+ """
64
+ # Get model instance
65
+ model_instance = self.get_model_instance(model_name, confidence_threshold)
66
+
67
+ # Initialize key variables
68
+ result = None
69
+ stats = {}
70
+ temp_path = None
71
+
72
+ try:
73
+ # Processing input image
74
+ if isinstance(image, np.ndarray):
75
+ # Convert BGR to RGB if needed
76
+ if image.shape[2] == 3:
77
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
78
+ else:
79
+ image_rgb = image
80
+ pil_image = Image.fromarray(image_rgb)
81
+ elif image is None:
82
+ return None, "No image provided. Please upload an image.", {}
83
+ else:
84
+ pil_image = image
85
+
86
+ # Store temp files
87
+ temp_dir = tempfile.gettempdir() # Use system temp directory
88
+ temp_filename = f"temp_{uuid.uuid4().hex}.jpg"
89
+ temp_path = os.path.join(temp_dir, temp_filename)
90
+ pil_image.save(temp_path)
91
+
92
+ # Object detection
93
+ result = model_instance.detect(temp_path)
94
+
95
+ if result is None:
96
+ return None, "Detection failed. Please try again with a different image.", {}
97
+
98
+ # Calculate stats
99
+ stats = EvaluationMetrics.calculate_basic_stats(result)
100
+
101
+ # Add space calculation
102
+ spatial_metrics = EvaluationMetrics.calculate_distance_metrics(result)
103
+ stats["spatial_metrics"] = spatial_metrics
104
+
105
+ # Apply filter if specified
106
+ if filter_classes and len(filter_classes) > 0:
107
+ # Get classes, boxes, confidence
108
+ classes = result.boxes.cls.cpu().numpy().astype(int)
109
+ confs = result.boxes.conf.cpu().numpy()
110
+ boxes = result.boxes.xyxy.cpu().numpy()
111
+
112
+ mask = np.zeros_like(classes, dtype=bool)
113
+ for cls_id in filter_classes:
114
+ mask = np.logical_or(mask, classes == cls_id)
115
+
116
+ filtered_stats = {
117
+ "total_objects": int(np.sum(mask)),
118
+ "class_statistics": {},
119
+ "average_confidence": float(np.mean(confs[mask])) if np.any(mask) else 0,
120
+ "spatial_metrics": stats["spatial_metrics"]
121
+ }
122
+
123
+ # Update stats
124
+ names = result.names
125
+ for cls, conf in zip(classes[mask], confs[mask]):
126
+ cls_name = names[int(cls)]
127
+ if cls_name not in filtered_stats["class_statistics"]:
128
+ filtered_stats["class_statistics"][cls_name] = {
129
+ "count": 0,
130
+ "average_confidence": 0
131
+ }
132
+
133
+ filtered_stats["class_statistics"][cls_name]["count"] += 1
134
+ filtered_stats["class_statistics"][cls_name]["average_confidence"] = conf
135
+
136
+ stats = filtered_stats
137
+
138
+ viz_data = EvaluationMetrics.generate_visualization_data(
139
+ result,
140
+ self.color_mapper.get_all_colors()
141
+ )
142
+
143
+ result_image = VisualizationHelper.visualize_detection(
144
+ temp_path, result, color_mapper=self.color_mapper, figsize=(12, 12), return_pil=True
145
+ )
146
+
147
+ result_text = EvaluationMetrics.format_detection_summary(viz_data)
148
+
149
+ return result_image, result_text, stats
150
+
151
+ except Exception as e:
152
+ error_message = f"Error Occurs: {str(e)}"
153
+ import traceback
154
+ traceback.print_exc()
155
+ print(error_message)
156
+ return None, error_message, {}
157
+
158
+ finally:
159
+ if temp_path and os.path.exists(temp_path):
160
+ try:
161
+ os.remove(temp_path)
162
+ except Exception as e:
163
+ print(f"Cannot delete temp files {temp_path}: {str(e)}")
164
+
165
+ def format_result_text(self, stats: Dict) -> str:
166
+ """
167
+ Format detection statistics into readable text with improved spacing
168
+
169
+ Args:
170
+ stats: Dictionary containing detection statistics
171
+
172
+ Returns:
173
+ Formatted text summary
174
+ """
175
+ if not stats or "total_objects" not in stats:
176
+ return "No objects detected."
177
+
178
+ # 減少不必要的空行
179
+ lines = [
180
+ f"Detected {stats['total_objects']} objects.",
181
+ f"Average confidence: {stats.get('average_confidence', 0):.2f}",
182
+ "Objects by class:"
183
+ ]
184
+
185
+ if "class_statistics" in stats and stats["class_statistics"]:
186
+ # 按計數排序類別
187
+ sorted_classes = sorted(
188
+ stats["class_statistics"].items(),
189
+ key=lambda x: x[1]["count"],
190
+ reverse=True
191
+ )
192
+
193
+ for cls_name, cls_stats in sorted_classes:
194
+ count = cls_stats["count"]
195
+ conf = cls_stats.get("average_confidence", 0)
196
+
197
+ item_text = "item" if count == 1 else "items"
198
+ lines.append(f"• {cls_name}: {count} {item_text} (avg conf: {conf:.2f})")
199
+ else:
200
+ lines.append("No class information available.")
201
+
202
+ # 添加空間信息
203
+ if "spatial_metrics" in stats and "spatial_distribution" in stats["spatial_metrics"]:
204
+ lines.append("Object Distribution:")
205
+
206
+ dist = stats["spatial_metrics"]["spatial_distribution"]
207
+ x_mean = dist.get("x_mean", 0)
208
+ y_mean = dist.get("y_mean", 0)
209
+
210
+ # 描述物體的大致位置
211
+ if x_mean < 0.33:
212
+ h_pos = "on the left side"
213
+ elif x_mean < 0.67:
214
+ h_pos = "in the center"
215
+ else:
216
+ h_pos = "on the right side"
217
+
218
+ if y_mean < 0.33:
219
+ v_pos = "in the upper part"
220
+ elif y_mean < 0.67:
221
+ v_pos = "in the middle"
222
+ else:
223
+ v_pos = "in the lower part"
224
+
225
+ lines.append(f"• Most objects appear {h_pos} {v_pos} of the image")
226
+
227
+ return "\n".join(lines)
228
+
229
+ def format_json_for_display(self, stats: Dict) -> Dict:
230
+ """
231
+ Format statistics JSON for better display
232
+
233
+ Args:
234
+ stats: Raw statistics dictionary
235
+
236
+ Returns:
237
+ Formatted statistics structure for display
238
+ """
239
+ # Create a cleaner copy of the stats for display
240
+ display_stats = {}
241
+
242
+ # Add summary section
243
+ display_stats["summary"] = {
244
+ "total_objects": stats.get("total_objects", 0),
245
+ "average_confidence": round(stats.get("average_confidence", 0), 3)
246
+ }
247
+
248
+ # Add class statistics in a more organized way
249
+ if "class_statistics" in stats and stats["class_statistics"]:
250
+ # Sort classes by count (descending)
251
+ sorted_classes = sorted(
252
+ stats["class_statistics"].items(),
253
+ key=lambda x: x[1].get("count", 0),
254
+ reverse=True
255
+ )
256
+
257
+ class_stats = {}
258
+ for cls_name, cls_data in sorted_classes:
259
+ class_stats[cls_name] = {
260
+ "count": cls_data.get("count", 0),
261
+ "average_confidence": round(cls_data.get("average_confidence", 0), 3)
262
+ }
263
+
264
+ display_stats["detected_objects"] = class_stats
265
+
266
+ # Simplify spatial metrics
267
+ if "spatial_metrics" in stats:
268
+ spatial = stats["spatial_metrics"]
269
+
270
+ # Simplify spatial distribution
271
+ if "spatial_distribution" in spatial:
272
+ dist = spatial["spatial_distribution"]
273
+ display_stats["spatial"] = {
274
+ "distribution": {
275
+ "x_mean": round(dist.get("x_mean", 0), 3),
276
+ "y_mean": round(dist.get("y_mean", 0), 3),
277
+ "x_std": round(dist.get("x_std", 0), 3),
278
+ "y_std": round(dist.get("y_std", 0), 3)
279
+ }
280
+ }
281
+
282
+ # Add simplified size information
283
+ if "size_distribution" in spatial:
284
+ size = spatial["size_distribution"]
285
+ display_stats["spatial"]["size"] = {
286
+ "mean_area": round(size.get("mean_area", 0), 3),
287
+ "min_area": round(size.get("min_area", 0), 3),
288
+ "max_area": round(size.get("max_area", 0), 3)
289
+ }
290
+
291
+ return display_stats
292
+
293
+ def prepare_visualization_data(self, stats: Dict, available_classes: Dict[int, str]) -> Dict:
294
+ """
295
+ Prepare data for visualization based on detection statistics
296
+
297
+ Args:
298
+ stats: Detection statistics
299
+ available_classes: Dictionary of available class IDs and names
300
+
301
+ Returns:
302
+ Visualization data dictionary
303
+ """
304
+ if not stats or "class_statistics" not in stats or not stats["class_statistics"]:
305
+ return {"error": "No detection data available"}
306
+
307
+ # Prepare visualization data
308
+ viz_data = {
309
+ "total_objects": stats.get("total_objects", 0),
310
+ "average_confidence": stats.get("average_confidence", 0),
311
+ "class_data": []
312
+ }
313
+
314
+ # Class data
315
+ for cls_name, cls_stats in stats.get("class_statistics", {}).items():
316
+ # Search class ID
317
+ class_id = -1
318
+ for id, name in available_classes.items():
319
+ if name == cls_name:
320
+ class_id = id
321
+ break
322
+
323
+ cls_data = {
324
+ "name": cls_name,
325
+ "class_id": class_id,
326
+ "count": cls_stats.get("count", 0),
327
+ "average_confidence": cls_stats.get("average_confidence", 0),
328
+ "color": self.color_mapper.get_color(class_id if class_id >= 0 else cls_name)
329
+ }
330
+
331
+ viz_data["class_data"].append(cls_data)
332
+
333
+ # Descending order
334
+ viz_data["class_data"].sort(key=lambda x: x["count"], reverse=True)
335
+
336
+ return viz_data
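Taken together with get_all_classes from app.py, the whole statistics path can be sketched as follows, assuming image is an already-loaded PIL image:

processor = ImageProcessor()
_, _, stats = processor.process_image(image, "yolov8m.pt", 0.25)
viz = processor.prepare_visualization_data(stats, dict(get_all_classes()))
if "error" not in viz:
    fig = EvaluationMetrics.create_enhanced_stats_plot(viz)  # the plot shown in the Statistics tab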