henry000 commited on
Commit
21a413f
·
1 Parent(s): 7d42a25

✨ [New] local mAP calculation, AP.5:.95 - mAP.5

Browse files
yolo/config/task/validation.yaml CHANGED
@@ -8,5 +8,5 @@ data:
8
  pin_memory: True
9
  data_augment: {}
10
  nms:
11
- min_confidence: 0.001
12
- min_iou: 0.7
 
8
  pin_memory: True
9
  data_augment: {}
10
  nms:
11
+ min_confidence: 0.05
12
+ min_iou: 0.9
yolo/tools/solver.py CHANGED
@@ -18,7 +18,7 @@ from yolo.model.yolo import YOLO
18
  from yolo.tools.data_loader import StreamDataLoader, create_dataloader
19
  from yolo.tools.drawer import draw_bboxes, draw_model
20
  from yolo.tools.loss_functions import create_loss_function
21
- from yolo.utils.bounding_box_utils import Vec2Box
22
  from yolo.utils.logging_utils import ProgressLogger, log_model_structure
23
  from yolo.utils.model_utils import (
24
  ExponentialMovingAverage,
@@ -198,16 +198,18 @@ class ModelValidator:
198
  def solve(self, dataloader, epoch_idx=-1):
199
  # logger.info("🧪 Start Validation!")
200
  self.model.eval()
201
- predict_json = []
202
  self.progress.start_one_epoch(len(dataloader))
203
  for images, targets, rev_tensor, img_paths in dataloader:
204
  images, targets, rev_tensor = images.to(self.device), targets.to(self.device), rev_tensor.to(self.device)
205
  with torch.no_grad():
206
  predicts = self.model(images)
207
- predicts = self.post_proccess(predicts, rev_tensor)
208
- self.progress.one_batch()
 
 
209
 
210
- predict_json.extend(predicts_to_json(img_paths, predicts))
211
  self.progress.finish_one_epoch()
212
  with open(self.json_path, "w") as f:
213
  json.dump(predict_json, f)
 
18
  from yolo.tools.data_loader import StreamDataLoader, create_dataloader
19
  from yolo.tools.drawer import draw_bboxes, draw_model
20
  from yolo.tools.loss_functions import create_loss_function
21
+ from yolo.utils.bounding_box_utils import Vec2Box, calculate_map
22
  from yolo.utils.logging_utils import ProgressLogger, log_model_structure
23
  from yolo.utils.model_utils import (
24
  ExponentialMovingAverage,
 
198
  def solve(self, dataloader, epoch_idx=-1):
199
  # logger.info("🧪 Start Validation!")
200
  self.model.eval()
201
+ mAPs, predict_json = [], []
202
  self.progress.start_one_epoch(len(dataloader))
203
  for images, targets, rev_tensor, img_paths in dataloader:
204
  images, targets, rev_tensor = images.to(self.device), targets.to(self.device), rev_tensor.to(self.device)
205
  with torch.no_grad():
206
  predicts = self.model(images)
207
+ predicts = self.post_proccess(predicts)
208
+ for idx, predict in enumerate(predicts):
209
+ mAPs.append(calculate_map(predict, targets[idx]))
210
+ self.progress.one_batch(mAP=Tensor(mAPs))
211
 
212
+ predict_json.extend(predicts_to_json(img_paths, predicts, rev_tensor))
213
  self.progress.finish_one_epoch()
214
  with open(self.json_path, "w") as f:
215
  json.dump(predict_json, f)
yolo/utils/bounding_box_utils.py CHANGED
@@ -5,7 +5,7 @@ import torch
5
  import torch.nn.functional as F
6
  from einops import rearrange
7
  from loguru import logger
8
- from torch import Tensor
9
  from torchvision.ops import batched_nms
10
 
11
  from yolo.config.config import MatcherConfig, ModelConfig, NMSConfig
@@ -338,8 +338,8 @@ def bbox_nms(cls_dist: Tensor, bbox: Tensor, nms_cfg: NMSConfig):
338
  return predicts_nms
339
 
340
 
341
- def calculate_map(predictions, ground_truths, iou_thresholds):
342
- # TODO: Refactor this block
343
  device = predictions.device
344
  n_preds = predictions.size(0)
345
  n_gts = (ground_truths[:, 0] != -1).sum()
@@ -369,13 +369,16 @@ def calculate_map(predictions, ground_truths, iou_thresholds):
369
  precision = tp_cumsum / (tp_cumsum + fp_cumsum + 1e-6)
370
  recall = tp_cumsum / (n_gts + 1e-6)
371
 
372
- recall_thresholds = torch.arange(0, 1, 0.1)
373
- precision_at_recall = torch.zeros_like(recall_thresholds)
374
- for i, r in enumerate(recall_thresholds):
375
- precision_at_recall[i] = precision[recall >= r].max().item() if torch.any(recall >= r) else 0
 
 
 
 
376
 
377
- ap = precision_at_recall.mean()
378
  aps.append(ap)
379
 
380
  mean_ap = torch.mean(torch.stack(aps))
381
- return mean_ap, aps
 
5
  import torch.nn.functional as F
6
  from einops import rearrange
7
  from loguru import logger
8
+ from torch import Tensor, arange
9
  from torchvision.ops import batched_nms
10
 
11
  from yolo.config.config import MatcherConfig, ModelConfig, NMSConfig
 
338
  return predicts_nms
339
 
340
 
341
+ def calculate_map(predictions, ground_truths, iou_thresholds=arange(0.5, 1, 0.05)):
342
+ # TODO: Refactor this block, Flexible for calculate different mAP condition?
343
  device = predictions.device
344
  n_preds = predictions.size(0)
345
  n_gts = (ground_truths[:, 0] != -1).sum()
 
369
  precision = tp_cumsum / (tp_cumsum + fp_cumsum + 1e-6)
370
  recall = tp_cumsum / (n_gts + 1e-6)
371
 
372
+ precision = torch.cat([torch.ones(1, device=device), precision, torch.zeros(1, device=device)])
373
+ recall = torch.cat([torch.zeros(1, device=device), recall, torch.ones(1, device=device)])
374
+
375
+ precision, _ = torch.cummax(precision.flip(0), dim=0)
376
+ precision = precision.flip(0)
377
+
378
+ indices = (recall[1:] != recall[:-1]).nonzero(as_tuple=True)[0]
379
+ ap = torch.sum((recall[indices + 1] - recall[indices]) * precision[indices + 1])
380
 
 
381
  aps.append(ap)
382
 
383
  mean_ap = torch.mean(torch.stack(aps))
384
+ return mean_ap, aps[0]
yolo/utils/logging_utils.py CHANGED
@@ -95,9 +95,11 @@ class ProgressLogger(Progress):
95
  self.wandb.log({f"Learning Rate/{lr_name}": lr_value}, step=epoch_idx)
96
  self.batch_task = self.add_task("[green]Batches", total=num_batches)
97
 
98
- def one_batch(self, loss_dict: Dict[str, Tensor] = None):
99
  if loss_dict is None:
100
- self.update(self.batch_task, advance=1, description=f"[green]Validating")
 
 
101
  return
102
  if self.use_wandb:
103
  for loss_name, loss_value in loss_dict.items():
 
95
  self.wandb.log({f"Learning Rate/{lr_name}": lr_value}, step=epoch_idx)
96
  self.batch_task = self.add_task("[green]Batches", total=num_batches)
97
 
98
+ def one_batch(self, loss_dict: Dict[str, Tensor] = None, mAP: Tensor = None):
99
  if loss_dict is None:
100
+ # refactor this block & class
101
+ mAP_50, mAP_50_95 = mAP.mean(0)
102
+ self.update(self.batch_task, advance=1, description=f"[green]Validating {mAP_50: .2f} {mAP_50_95: .2f}")
103
  return
104
  if self.use_wandb:
105
  for loss_name, loss_value in loss_dict.items():
yolo/utils/model_utils.py CHANGED
@@ -106,7 +106,7 @@ class PostProccess:
106
  self.vec2box = vec2box
107
  self.nms = nms_cfg
108
 
109
- def __call__(self, predict, rev_tensor: Optional[Tensor]):
110
  pred_class, _, pred_bbox = self.vec2box(predict["Main"])
111
  if rev_tensor is not None:
112
  pred_bbox = (pred_bbox - rev_tensor[:, None, 1:]) / rev_tensor[:, 0:1, None]
@@ -114,13 +114,15 @@ class PostProccess:
114
  return pred_bbox
115
 
116
 
117
- def predicts_to_json(img_paths, predicts):
118
  """
119
  TODO: function document
120
  turn a batch of imagepath and predicts(n x 6 for each image) to a List of diction(Detection output)
121
  """
122
  batch_json = []
123
- for img_path, bboxes in zip(img_paths, predicts):
 
 
124
  bboxes[:, 1:5] = transform_bbox(bboxes[:, 1:5], "xyxy -> xywh")
125
  for cls, *pos, conf in bboxes:
126
  bbox = {
 
106
  self.vec2box = vec2box
107
  self.nms = nms_cfg
108
 
109
+ def __call__(self, predict, rev_tensor: Optional[Tensor] = None):
110
  pred_class, _, pred_bbox = self.vec2box(predict["Main"])
111
  if rev_tensor is not None:
112
  pred_bbox = (pred_bbox - rev_tensor[:, None, 1:]) / rev_tensor[:, 0:1, None]
 
114
  return pred_bbox
115
 
116
 
117
+ def predicts_to_json(img_paths, predicts, rev_tensor):
118
  """
119
  TODO: function document
120
  turn a batch of imagepath and predicts(n x 6 for each image) to a List of diction(Detection output)
121
  """
122
  batch_json = []
123
+ for img_path, bboxes, box_reverse in zip(img_paths, predicts, rev_tensor):
124
+ scale, shift = box_reverse.split([1, 4])
125
+ bboxes[:, 1:5] = (bboxes[:, 1:5] - shift[None]) / scale[None]
126
  bboxes[:, 1:5] = transform_bbox(bboxes[:, 1:5], "xyxy -> xywh")
127
  for cls, *pos, conf in bboxes:
128
  bbox = {