henry000 commited on
Commit
2ae492a
·
1 Parent(s): 849d290

✨ [Add] yolov9 loss function, align to origin v9

Browse files
config/config.py CHANGED
@@ -2,9 +2,15 @@ from dataclasses import dataclass
2
  from typing import Dict, List, Union
3
 
4
 
 
 
 
 
 
 
5
  @dataclass
6
  class Model:
7
- anchor: List[List[int]]
8
  model: Dict[str, List[Dict[str, Union[Dict, List, int]]]]
9
 
10
 
@@ -20,6 +26,8 @@ class DataLoaderConfig:
20
  shuffle: bool
21
  num_workers: int
22
  pin_memory: bool
 
 
23
 
24
 
25
  @dataclass
@@ -52,11 +60,19 @@ class EMAConfig:
52
  decay: float
53
 
54
 
 
 
 
 
 
 
 
55
  @dataclass
56
  class TrainConfig:
57
  optimizer: OptimizerConfig
58
  scheduler: SchedulerConfig
59
  ema: EMAConfig
 
60
 
61
 
62
  @dataclass
 
2
  from typing import Dict, List, Union
3
 
4
 
5
+ @dataclass
6
+ class AnchorConfig:
7
+ reg_max: int
8
+ strides: List[int]
9
+
10
+
11
  @dataclass
12
  class Model:
13
+ anchor: AnchorConfig
14
  model: Dict[str, List[Dict[str, Union[Dict, List, int]]]]
15
 
16
 
 
26
  shuffle: bool
27
  num_workers: int
28
  pin_memory: bool
29
+ image_size: List[int]
30
+ class_num: int
31
 
32
 
33
  @dataclass
 
60
  decay: float
61
 
62
 
63
+ @dataclass
64
+ class MatcherConfig:
65
+ iou: str
66
+ topk: int
67
+ factor: Dict[str, int]
68
+
69
+
70
  @dataclass
71
  class TrainConfig:
72
  optimizer: OptimizerConfig
73
  scheduler: SchedulerConfig
74
  ema: EMAConfig
75
+ matcher: MatcherConfig
76
 
77
 
78
  @dataclass
config/hyper/default.yaml CHANGED
@@ -3,12 +3,28 @@ data:
3
  shuffle: True
4
  num_workers: 4
5
  pin_memory: True
 
 
6
  train:
7
  optimizer:
8
  type: Adam
9
  args:
10
  lr: 0.001
11
  weight_decay: 0.0001
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  scheduler:
13
  type: StepLR
14
  args:
 
3
  shuffle: True
4
  num_workers: 4
5
  pin_memory: True
6
+ class_num: 80
7
+ image_size: [640, 640]
8
  train:
9
  optimizer:
10
  type: Adam
11
  args:
12
  lr: 0.001
13
  weight_decay: 0.0001
14
+ loss:
15
+ BCELoss:
16
+ args:
17
+ BoxLoss:
18
+ args:
19
+ alpha: 0.1
20
+ DFLoss:
21
+ args:
22
+ matcher:
23
+ iou: CIoU
24
+ topk: 10
25
+ factor:
26
+ iou: 6.0
27
+ cls: 0.5
28
  scheduler:
29
  type: StepLR
30
  args:
config/model/v7-base.yaml CHANGED
@@ -1,5 +1,9 @@
1
  nc: 80
2
 
 
 
 
 
3
  model:
4
  backbone:
5
  - Conv:
 
1
  nc: 80
2
 
3
+ anchor:
4
+ reg_max: 16
5
+ strides: [8, 16, 32]
6
+
7
  model:
8
  backbone:
9
  - Conv:
tools/bbox_helper.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List, Tuple
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import Tensor
7
+
8
+ from config.config import MatcherConfig
9
+
10
+
11
+ def calculate_iou(bbox1, bbox2, metrics="iou") -> Tensor:
12
+ metrics = metrics.lower()
13
+ EPS = 1e-9
14
+ dtype = bbox1.dtype
15
+ bbox1 = bbox1.to(torch.float32)
16
+ bbox2 = bbox2.to(torch.float32)
17
+
18
+ # Expand dimensions if necessary
19
+ if bbox1.ndim == 2 and bbox2.ndim == 2:
20
+ bbox1 = bbox1.unsqueeze(1) # (Ax4) -> (Ax1x4)
21
+ bbox2 = bbox2.unsqueeze(0) # (Bx4) -> (1xBx4)
22
+ elif bbox1.ndim == 3 and bbox2.ndim == 3:
23
+ bbox1 = bbox1.unsqueeze(2) # (BZxAx4) -> (BZxAx1x4)
24
+ bbox2 = bbox2.unsqueeze(1) # (BZxBx4) -> (BZx1xBx4)
25
+
26
+ # Calculate intersection coordinates
27
+ xmin_inter = torch.max(bbox1[..., 0], bbox2[..., 0])
28
+ ymin_inter = torch.max(bbox1[..., 1], bbox2[..., 1])
29
+ xmax_inter = torch.min(bbox1[..., 2], bbox2[..., 2])
30
+ ymax_inter = torch.min(bbox1[..., 3], bbox2[..., 3])
31
+
32
+ # Calculate intersection area
33
+ intersection_area = torch.clamp(xmax_inter - xmin_inter, min=0) * torch.clamp(ymax_inter - ymin_inter, min=0)
34
+
35
+ # Calculate area of each bbox
36
+ area_bbox1 = (bbox1[..., 2] - bbox1[..., 0]) * (bbox1[..., 3] - bbox1[..., 1])
37
+ area_bbox2 = (bbox2[..., 2] - bbox2[..., 0]) * (bbox2[..., 3] - bbox2[..., 1])
38
+
39
+ # Calculate union area
40
+ union_area = area_bbox1 + area_bbox2 - intersection_area
41
+
42
+ # Calculate IoU
43
+ iou = intersection_area / (union_area + EPS)
44
+ if metrics == "iou":
45
+ return iou
46
+
47
+ # Calculate centroid distance
48
+ cx1 = (bbox1[..., 2] + bbox1[..., 0]) / 2
49
+ cy1 = (bbox1[..., 3] + bbox1[..., 1]) / 2
50
+ cx2 = (bbox2[..., 2] + bbox2[..., 0]) / 2
51
+ cy2 = (bbox2[..., 3] + bbox2[..., 1]) / 2
52
+ cent_dis = (cx1 - cx2) ** 2 + (cy1 - cy2) ** 2
53
+
54
+ # Calculate diagonal length of the smallest enclosing box
55
+ c_x = torch.max(bbox1[..., 2], bbox2[..., 2]) - torch.min(bbox1[..., 0], bbox2[..., 0])
56
+ c_y = torch.max(bbox1[..., 3], bbox2[..., 3]) - torch.min(bbox1[..., 1], bbox2[..., 1])
57
+ diag_dis = c_x**2 + c_y**2 + EPS
58
+
59
+ diou = iou - (cent_dis / diag_dis)
60
+ if metrics == "diou":
61
+ return diou
62
+
63
+ # Compute aspect ratio penalty term
64
+ arctan = torch.atan((bbox1[..., 2] - bbox1[..., 0]) / (bbox1[..., 3] - bbox1[..., 1] + EPS)) - torch.atan(
65
+ (bbox2[..., 2] - bbox2[..., 0]) / (bbox2[..., 3] - bbox2[..., 1] + EPS)
66
+ )
67
+ v = (4 / (math.pi**2)) * (arctan**2)
68
+ alpha = v / (v - iou + 1 + EPS)
69
+ # Compute CIoU
70
+ ciou = diou - alpha * v
71
+ return ciou.to(dtype)
72
+
73
+
74
+ def transform_bbox(bbox: Tensor, indicator="xywh -> xyxy"):
75
+ data_type = bbox.dtype
76
+ in_type, out_type = indicator.replace(" ", "").split("->")
77
+
78
+ if in_type not in ["xyxy", "xywh", "xycwh"] or out_type not in ["xyxy", "xywh", "xycwh"]:
79
+ raise ValueError("Invalid input or output format")
80
+
81
+ if in_type == "xywh":
82
+ x_min = bbox[..., 0]
83
+ y_min = bbox[..., 1]
84
+ x_max = bbox[..., 0] + bbox[..., 2]
85
+ y_max = bbox[..., 1] + bbox[..., 3]
86
+ elif in_type == "xyxy":
87
+ x_min = bbox[..., 0]
88
+ y_min = bbox[..., 1]
89
+ x_max = bbox[..., 2]
90
+ y_max = bbox[..., 3]
91
+ elif in_type == "xycwh":
92
+ x_min = bbox[..., 0] - bbox[..., 2] / 2
93
+ y_min = bbox[..., 1] - bbox[..., 3] / 2
94
+ x_max = bbox[..., 0] + bbox[..., 2] / 2
95
+ y_max = bbox[..., 1] + bbox[..., 3] / 2
96
+
97
+ if out_type == "xywh":
98
+ bbox = torch.stack([x_min, y_min, x_max - x_min, y_max - y_min], dim=-1)
99
+ elif out_type == "xyxy":
100
+ bbox = torch.stack([x_min, y_min, x_max, y_max], dim=-1)
101
+ elif out_type == "xycwh":
102
+ bbox = torch.stack([(x_min + x_max) / 2, (y_min + y_max) / 2, x_max - x_min, y_max - y_min], dim=-1)
103
+
104
+ return bbox.to(dtype=data_type)
105
+
106
+
107
+ def make_anchor(image_size: List[int], strides: List[int], device):
108
+ W, H = image_size
109
+ anchors = []
110
+ scaler = []
111
+ for stride in strides:
112
+ anchor_num = W // stride * H // stride
113
+ scaler.append(torch.full((anchor_num,), stride, device=device))
114
+ shift = stride // 2
115
+ x = torch.arange(0, W, stride, device=device) + shift
116
+ y = torch.arange(0, H, stride, device=device) + shift
117
+ anchor_x, anchor_y = torch.meshgrid(x, y, indexing="ij")
118
+ anchor = torch.stack([anchor_y.flatten(), anchor_x.flatten()], dim=-1)
119
+ anchors.append(anchor)
120
+ all_anchors = torch.cat(anchors, dim=0)
121
+ all_scalers = torch.cat(scaler, dim=0)
122
+ return all_anchors, all_scalers
123
+
124
+
125
+ class BoxMatcher:
126
+ def __init__(self, cfg: MatcherConfig, class_num: int, anchors: Tensor) -> None:
127
+ self.class_num = class_num
128
+ self.anchors = anchors
129
+ for attr_name in cfg:
130
+ setattr(self, attr_name, cfg[attr_name])
131
+
132
+ def get_valid_matrix(self, target_bbox: Tensor):
133
+ """
134
+ Get a boolean mask that indicates whether each target bounding box overlaps with each anchor.
135
+
136
+ Args:
137
+ target_bbox [batch x targets x 4]: The bounding box of each targets.
138
+ Returns:
139
+ [batch x targets x anchors]: A boolean tensor indicates if target bounding box overlaps with anchors.
140
+ """
141
+ Xmin, Ymin, Xmax, Ymax = target_bbox[:, :, None].unbind(3)
142
+ anchors = self.anchors[None, None] # add a axis at first, second dimension
143
+ anchors_x, anchors_y = anchors.unbind(dim=3)
144
+ target_in_x = (Xmin < anchors_x) & (anchors_x < Xmax)
145
+ target_in_y = (Ymin < anchors_y) & (anchors_y < Ymax)
146
+ target_on_anchor = target_in_x & target_in_y
147
+ return target_on_anchor
148
+
149
+ def get_cls_matrix(self, predict_cls: Tensor, target_cls: Tensor) -> Tensor:
150
+ """
151
+ Get the (predicted class' probabilities) corresponding to the target classes across all anchors
152
+
153
+ Args:
154
+ predict_cls [batch x class x anchors]: The predicted probabilities for each class across each anchor.
155
+ target_cls [batch x targets]: The class index for each target.
156
+
157
+ Returns:
158
+ [batch x targets x anchors]: The probabilities from `pred_cls` corresponding to the class indices specified in `target_cls`.
159
+ """
160
+ target_cls = target_cls.expand(-1, -1, 8400)
161
+ predict_cls = predict_cls.transpose(1, 2)
162
+ cls_probabilities = torch.gather(predict_cls, 1, target_cls)
163
+ return cls_probabilities
164
+
165
+ def get_iou_matrix(self, predict_bbox, target_bbox) -> Tensor:
166
+ """
167
+ Get the IoU between each target bounding box and each predicted bounding box.
168
+
169
+ Args:
170
+ predict_bbox [batch x predicts x 4]: Bounding box with [x1, y1, x2, y2].
171
+ target_bbox [batch x targets x 4]: Bounding box with [x1, y1, x2, y2].
172
+ Returns:
173
+ [batch x targets x predicts]: The IoU scores between each target and predicted.
174
+ """
175
+ return calculate_iou(target_bbox, predict_bbox, self.iou).clamp(0, 1)
176
+
177
+ def filter_topk(self, target_matrix: Tensor, topk: int = 10) -> Tuple[Tensor, Tensor]:
178
+ """
179
+ Filter the top-k suitability of targets for each anchor.
180
+
181
+ Args:
182
+ target_matrix [batch x targets x anchors]: The suitability for each targets-anchors
183
+ topk (int, optional): Number of top scores to retain per anchor.
184
+
185
+ Returns:
186
+ topk_targets [batch x targets x anchors]: Only leave the topk targets for each anchor
187
+ topk_masks [batch x targets x anchors]: A boolean mask indicating the top-k scores' positions.
188
+ """
189
+ values, indices = target_matrix.topk(topk, dim=-1)
190
+ topk_targets = torch.zeros_like(target_matrix, device=target_matrix.device)
191
+ topk_targets.scatter_(dim=-1, index=indices, src=values)
192
+ topk_masks = topk_targets > 0
193
+ return topk_targets, topk_masks
194
+
195
+ def filter_duplicates(self, target_matrix: Tensor):
196
+ """
197
+ Filter the maximum suitability target index of each anchor.
198
+
199
+ Args:
200
+ target_matrix [batch x targets x anchors]: The suitability for each targets-anchors
201
+
202
+ Returns:
203
+ unique_indices [batch x anchors x 1]: The index of the best targets for each anchors
204
+ """
205
+ unique_indices = target_matrix.argmax(dim=1)
206
+ return unique_indices[..., None]
207
+
208
+ def __call__(self, target: Tensor, predict: Tensor) -> Tuple[Tensor, Tensor]:
209
+ """
210
+ 1. For each anchor prediction, find the highest suitability targets
211
+ 2. Select the targets
212
+ 2. Noramlize the class probilities of targets
213
+ """
214
+ predict_cls, predict_bbox = predict.split(self.class_num, dim=-1) # B, HW x (C B) -> B x HW x C, B x HW x B
215
+ target_cls, target_bbox = target.split([1, 4], dim=-1) # B x N x (C B) -> B x N x C, B x N x B
216
+ target_cls = target_cls.long()
217
+
218
+ # get valid matrix (each gt appear in which anchor grid)
219
+ grid_mask = self.get_valid_matrix(target_bbox)
220
+
221
+ # get iou matrix (iou with each gt bbox and each predict anchor)
222
+ iou_mat = self.get_iou_matrix(predict_bbox, target_bbox)
223
+
224
+ # get cls matrix (cls prob with each gt class and each predict class)
225
+ cls_mat = self.get_cls_matrix(predict_cls.sigmoid(), target_cls)
226
+
227
+ # TODO: alpha and beta should be set at hydra
228
+ target_matrix = grid_mask * (iou_mat ** self.factor["iou"]) * (cls_mat ** self.factor["cls"])
229
+
230
+ # choose topk
231
+ # TODO: topk should be set at hydra
232
+ topk_targets, topk_mask = self.filter_topk(target_matrix, topk=self.topk)
233
+
234
+ # delete one anchor pred assign to mutliple gts
235
+ unique_indices = self.filter_duplicates(topk_targets)
236
+
237
+ # TODO: do we need grid_mask? Filter the valid groud truth
238
+ valid_mask = (grid_mask.sum(dim=-2) * topk_mask.sum(dim=-2)).bool()
239
+
240
+ align_bbox = torch.gather(target_bbox, 1, unique_indices.repeat(1, 1, 4))
241
+ align_cls = torch.gather(target_cls, 1, unique_indices).squeeze(-1)
242
+ align_cls = F.one_hot(align_cls, self.class_num)
243
+
244
+ # normalize class ditribution
245
+ max_target = target_matrix.amax(dim=-1, keepdim=True)
246
+ max_iou = iou_mat.amax(dim=-1, keepdim=True)
247
+ normalize_term = (target_matrix / (max_target + 1e-9)) * max_iou
248
+ normalize_term = normalize_term.permute(0, 2, 1).gather(2, unique_indices)
249
+ align_cls = align_cls * normalize_term * valid_mask[:, :, None]
250
+
251
+ return torch.cat([align_cls, align_bbox], dim=-1), valid_mask.bool()
utils/loss.py CHANGED
@@ -1,2 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  def get_loss_function(*args, **kwargs):
2
  raise NotImplementedError
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import time
3
+ from typing import Any, List
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+ from hydra import main
10
+ from loguru import logger
11
+ from torch import Tensor, nn
12
+ from torch.nn import BCEWithLogitsLoss
13
+
14
+ sys.path.append("./")
15
+ from config.config import Config
16
+ from tools.bbox_helper import BoxMatcher, calculate_iou, make_anchor, transform_bbox
17
+
18
+
19
  def get_loss_function(*args, **kwargs):
20
  raise NotImplementedError
21
+
22
+
23
+ class BCELoss(nn.Module):
24
+ def __init__(self) -> None:
25
+ super().__init__()
26
+ self.bce = BCEWithLogitsLoss(pos_weight=torch.tensor([1.0], device=torch.device("cuda")), reduction="none")
27
+
28
+ def forward(self, predicts_cls: Tensor, targets_cls: Tensor, cls_norm: Tensor) -> Any:
29
+ return self.bce(predicts_cls, targets_cls).sum() / cls_norm
30
+
31
+
32
+ class BoxLoss(nn.Module):
33
+ def __init__(self) -> None:
34
+ super().__init__()
35
+
36
+ def forward(
37
+ self, predicts_bbox: Tensor, targets_bbox: Tensor, valid_masks: Tensor, box_norm: Tensor, cls_norm: Tensor
38
+ ) -> Any:
39
+ valid_bbox = valid_masks[..., None].expand(-1, -1, 4)
40
+ picked_predict = predicts_bbox[valid_bbox].view(-1, 4)
41
+ picked_targets = targets_bbox[valid_bbox].view(-1, 4)
42
+
43
+ iou = calculate_iou(picked_predict, picked_targets, "ciou").diag()
44
+ loss_iou = 1.0 - iou
45
+ loss_iou = (loss_iou * box_norm).sum() / cls_norm
46
+ return loss_iou
47
+
48
+
49
+ class DFLoss(nn.Module):
50
+ def __init__(self, anchors: Tensor, scaler: Tensor, reg_max: int) -> None:
51
+ super().__init__()
52
+ self.anchors = anchors
53
+ self.scaler = scaler
54
+ self.reg_max = reg_max
55
+
56
+ def forward(
57
+ self, predicts_anc: Tensor, targets_bbox: Tensor, valid_masks: Tensor, box_norm: Tensor, cls_norm: Tensor
58
+ ) -> Any:
59
+ valid_bbox = valid_masks[..., None].expand(-1, -1, 4)
60
+ bbox_lt, bbox_rb = targets_bbox.chunk(2, -1)
61
+ anchors_norm = (self.anchors / self.scaler[:, None])[None]
62
+ targets_dist = torch.cat(((anchors_norm - bbox_lt), (bbox_rb - anchors_norm)), -1).clamp(0, self.reg_max - 1.01)
63
+ picked_targets = targets_dist[valid_bbox].view(-1)
64
+ picked_predict = predicts_anc[valid_bbox].view(-1, self.reg_max)
65
+
66
+ label_left, label_right = picked_targets.floor(), picked_targets.floor() + 1
67
+ weight_left, weight_right = label_right - picked_targets, picked_targets - label_left
68
+
69
+ loss_left = F.cross_entropy(picked_predict, label_left.to(torch.long), reduction="none")
70
+ loss_right = F.cross_entropy(picked_predict, label_right.to(torch.long), reduction="none")
71
+ loss_dfl = loss_left * weight_left + loss_right * weight_right
72
+ loss_dfl = loss_dfl.view(-1, 4).mean(-1)
73
+ loss_dfl = (loss_dfl * box_norm).sum() / cls_norm
74
+ return loss_dfl
75
+
76
+
77
+ class YOLOLoss:
78
+ def __init__(self, cfg: Config) -> None:
79
+ self.reg_max = cfg.model.anchor.reg_max
80
+ self.class_num = cfg.hyper.data.class_num
81
+ self.image_size = list(cfg.hyper.data.image_size)
82
+ self.strides = cfg.model.anchor.strides
83
+ device = torch.device("cuda")
84
+
85
+ self.reverse_reg = torch.arange(self.reg_max, dtype=torch.float16, device=device)
86
+ self.scale_up = torch.tensor(self.image_size * 2, device=device)
87
+
88
+ self.anchors, self.scaler = make_anchor(self.image_size, self.strides, device)
89
+
90
+ self.cls = BCELoss()
91
+ self.dfl = DFLoss(self.anchors, self.scaler, self.reg_max)
92
+ self.iou = BoxLoss()
93
+
94
+ self.matcher = BoxMatcher(cfg.hyper.train.matcher, self.class_num, self.anchors)
95
+
96
+ def parse_predicts(self, predicts: List[Tensor]) -> Tensor:
97
+ """
98
+ args:
99
+ [B x AnchorClass x h1 x w1, B x AnchorClass x h2 x w2, B x AnchorClass x h3 x w3] // AnchorClass = 4 * 16 + 80
100
+ return:
101
+ [B x HW x ClassBbox] // HW = h1*w1 + h2*w2 + h3*w3, ClassBox = 80 + 4 (xyXY)
102
+ """
103
+ preds = []
104
+ for pred in predicts:
105
+ preds.append(rearrange(pred, "B AC h w -> B (h w) AC")) # B x AC x h x w-> B x hw x AC
106
+ preds = torch.concat(preds, dim=1) # -> B x (H W) x AC
107
+
108
+ preds_anc, preds_cls = torch.split(preds, (self.reg_max * 4, self.class_num), dim=-1)
109
+ preds_anc = rearrange(preds_anc, "B hw (P R)-> B hw P R", P=4)
110
+
111
+ pred_LTRB = preds_anc.softmax(dim=-1) @ self.reverse_reg * self.scaler.view(1, -1, 1)
112
+
113
+ lt, rb = pred_LTRB.chunk(2, dim=-1)
114
+ pred_minXY = self.anchors - lt
115
+ pred_maxXY = self.anchors + rb
116
+ predicts = torch.cat([preds_cls, pred_minXY, pred_maxXY], dim=-1)
117
+
118
+ return predicts, preds_anc
119
+
120
+ def parse_targets(self, targets: Tensor, batch_size: int = 16) -> List[Tensor]:
121
+ """
122
+ return List:
123
+ """
124
+ targets[:, 2:] = transform_bbox(targets[:, 2:], "xycwh -> xyxy") * self.scale_up
125
+ bbox_num = targets[:, 0].int().bincount()
126
+ batch_targets = torch.zeros(batch_size, bbox_num.max(), 5, device=targets.device)
127
+ for instance_idx, bbox_num in enumerate(bbox_num):
128
+ instance_targets = targets[targets[:, 0] == instance_idx]
129
+ batch_targets[instance_idx, :bbox_num] = instance_targets[:, 1:].detach()
130
+ return batch_targets
131
+
132
+ def separate_anchor(self, anchors):
133
+ """
134
+ separate anchor and bbouding box
135
+ """
136
+ anchors_cls, anchors_box = torch.split(anchors, (self.class_num, 4), dim=-1)
137
+ anchors_box = anchors_box / self.scaler[None, :, None]
138
+ return anchors_cls, anchors_box
139
+
140
+ @torch.autocast("cuda")
141
+ def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tensor:
142
+ # Batch_Size x (Anchor + Class) x H x W
143
+ tlist = [time.time()]
144
+ # TODO: check datatype, why targets has a little bit error with origin version
145
+ predicts, predicts_anc = self.parse_predicts(predicts[0])
146
+ targets = self.parse_targets(targets)
147
+
148
+ align_targets, valid_masks = self.matcher(targets, predicts)
149
+ # calculate loss between with instance and predict
150
+
151
+ targets_cls, targets_bbox = self.separate_anchor(align_targets)
152
+ predicts_cls, predicts_bbox = self.separate_anchor(predicts)
153
+
154
+ cls_norm = targets_cls.sum()
155
+ box_norm = targets_cls.sum(-1)[valid_masks]
156
+
157
+ ## -- CLS -- ##
158
+ loss_cls = self.cls(predicts_cls, targets_cls, cls_norm)
159
+ ## -- IOU -- ##
160
+ loss_iou = self.iou(predicts_bbox, targets_bbox, valid_masks, box_norm, cls_norm)
161
+ ## -- DFL -- ##
162
+ loss_dfl = self.dfl(predicts_anc, targets_bbox, valid_masks, box_norm, cls_norm)
163
+
164
+ logger.info("Loss IoU: {:.5f}, DFL: {:.5f}, CLS: {:.5f}", loss_iou, loss_dfl, loss_cls)
165
+ tlist.append(time.time())
166
+ logger.info(f"Calculate Loss Run Time {np.diff(np.array(tlist)) * 1e3} ms")
167
+
168
+
169
+ @main(config_path="../config", config_name="config", version_base=None)
170
+ def main(cfg):
171
+ losser = YOLOLoss(cfg)
172
+ targets = torch.load("targets.pt")
173
+ predicts = torch.load("predicts.pt")
174
+ losser(predicts, targets)
175
+
176
+
177
+ if __name__ == "__main__":
178
+ import sys
179
+
180
+ sys.path.append("./")
181
+ from tools.log_helper import custom_logger
182
+
183
+ custom_logger()
184
+ main()