henry000 commited on
Commit
97e9dcb
·
1 Parent(s): 6777bd1

🔧 [Update] all structure of yaml, more readability

Browse files
yolo/config/config.py CHANGED
@@ -23,8 +23,9 @@ class BlockConfig:
23
 
24
 
25
  @dataclass
26
- class Model:
27
  anchor: AnchorConfig
 
28
  model: Dict[str, BlockConfig]
29
 
30
 
@@ -50,7 +51,10 @@ class DataConfig:
50
  shuffle: bool
51
  batch_size: int
52
  pin_memory: bool
 
 
53
  data_augment: Dict[str, int]
 
54
 
55
 
56
  @dataclass
@@ -92,18 +96,6 @@ class EMAConfig:
92
  decay: float
93
 
94
 
95
- @dataclass
96
- class TrainConfig:
97
- task: str
98
- dataset: DatasetConfig
99
- epoch: int
100
- data: DataConfig
101
- optimizer: OptimizerConfig
102
- loss: LossConfig
103
- scheduler: SchedulerConfig
104
- ema: EMAConfig
105
-
106
-
107
  @dataclass
108
  class NMSConfig:
109
  min_confidence: int
@@ -113,15 +105,35 @@ class NMSConfig:
113
  @dataclass
114
  class InferenceConfig:
115
  task: str
116
- source: Union[str, int]
117
  nms: NMSConfig
 
118
  fast_inference: Optional[None]
119
 
120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  @dataclass
122
  class Config:
123
- task: Union[TrainConfig, InferenceConfig]
124
- model: Model
 
125
  name: str
126
 
127
  device: Union[str, int, List[int]]
 
23
 
24
 
25
  @dataclass
26
+ class ModelConfig:
27
  anchor: AnchorConfig
28
+ class_num: int
29
  model: Dict[str, BlockConfig]
30
 
31
 
 
51
  shuffle: bool
52
  batch_size: int
53
  pin_memory: bool
54
+ cpu_num: int
55
+ image_size: List[int]
56
  data_augment: Dict[str, int]
57
+ source: Optional[Union[str, int]]
58
 
59
 
60
  @dataclass
 
96
  decay: float
97
 
98
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  @dataclass
100
  class NMSConfig:
101
  min_confidence: int
 
105
  @dataclass
106
  class InferenceConfig:
107
  task: str
 
108
  nms: NMSConfig
109
+ data: DataConfig
110
  fast_inference: Optional[None]
111
 
112
 
113
+ @dataclass
114
+ class ValidationConfig:
115
+ task: str
116
+ nms: NMSConfig
117
+ data: DataConfig
118
+
119
+
120
+ @dataclass
121
+ class TrainConfig:
122
+ task: str
123
+ epoch: int
124
+ data: DataConfig
125
+ optimizer: OptimizerConfig
126
+ loss: LossConfig
127
+ scheduler: SchedulerConfig
128
+ ema: EMAConfig
129
+ validation: ValidationConfig
130
+
131
+
132
  @dataclass
133
  class Config:
134
+ task: Union[TrainConfig, InferenceConfig, ValidationConfig]
135
+ dataset: DatasetConfig
136
+ model: ModelConfig
137
  name: str
138
 
139
  device: Union[str, int, List[int]]
yolo/config/config.yaml CHANGED
@@ -7,6 +7,7 @@ name: v9-dev
7
  defaults:
8
  - _self_
9
  - task: train
 
10
  - model: v9-c
11
  - general
12
 
 
7
  defaults:
8
  - _self_
9
  - task: train
10
+ - dataset: coco
11
  - model: v9-c
12
  - general
13
 
yolo/config/{task/dataset → dataset}/coco.yaml RENAMED
@@ -1,6 +1,6 @@
1
  path: data/coco
2
  train: train2017
3
-
4
 
5
  auto_download:
6
  images:
 
1
  path: data/coco
2
  train: train2017
3
+ validation: val2017
4
 
5
  auto_download:
6
  images:
yolo/config/dataset/dev.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ path: data/dev
2
+ train: train
3
+ validation: test
4
+
5
+ auto_download:
yolo/config/model/v9-c.yaml CHANGED
@@ -2,6 +2,8 @@ anchor:
2
  reg_max: 16
3
  strides: [8, 16, 32]
4
 
 
 
5
  model:
6
  backbone:
7
  - Conv:
 
2
  reg_max: 16
3
  strides: [8, 16, 32]
4
 
5
+ class_num: ${class_num}
6
+
7
  model:
8
  backbone:
9
  - Conv:
yolo/config/task/inference.yaml CHANGED
@@ -1,10 +1,9 @@
1
  task: inference
2
- source: demo/images/inference/image.png
3
  fast_inference: # onnx, trt or Empty
4
  data:
5
- batch_size: 16
6
- shuffle: False
7
- pin_memory: True
8
  data_augment: {}
9
  nms:
10
  min_confidence: 0.5
 
1
  task: inference
2
+
3
  fast_inference: # onnx, trt or Empty
4
  data:
5
+ source: demo/images/inference/image.png
6
+ image_size: ${image_size}
 
7
  data_augment: {}
8
  nms:
9
  min_confidence: 0.5
yolo/config/task/train.yaml CHANGED
@@ -1,11 +1,14 @@
1
  task: train
 
2
  defaults:
3
- - dataset: coco
4
 
5
  epoch: 500
6
 
7
  data:
8
  batch_size: 16
 
 
9
  shuffle: True
10
  pin_memory: True
11
  data_augment:
 
1
  task: train
2
+
3
  defaults:
4
+ - validation: ../validation
5
 
6
  epoch: 500
7
 
8
  data:
9
  batch_size: 16
10
+ image_size: ${image_size}
11
+ cpu_num: ${cpu_num}
12
  shuffle: True
13
  pin_memory: True
14
  data_augment:
yolo/config/task/validation.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ task: validation
2
+
3
+ data:
4
+ batch_size: 16
5
+ image_size: ${image_size}
6
+ cpu_num: ${cpu_num}
7
+ shuffle: False
8
+ pin_memory: True
9
+ data_augment: {}
10
+ nms:
11
+ min_confidence: 0.001
12
+ min_iou: 0.7
yolo/lazy.py CHANGED
@@ -18,10 +18,10 @@ from yolo.utils.logging_utils import custom_logger, validate_log_directory
18
  @hydra.main(config_path="config", config_name="config", version_base=None)
19
  def main(cfg: Config):
20
  custom_logger()
21
- save_path = validate_log_directory(cfg, cfg.name)
22
- dataloader = create_dataloader(cfg)
23
  device = torch.device(cfg.device)
24
- if cfg.task.fast_inference:
25
  model = FastModelLoader(cfg).load_model()
26
  device = torch.device(cfg.device)
27
  else:
 
18
  @hydra.main(config_path="config", config_name="config", version_base=None)
19
  def main(cfg: Config):
20
  custom_logger()
21
+ save_path = validate_log_directory(cfg, exp_name=cfg.name)
22
+ dataloader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)
23
  device = torch.device(cfg.device)
24
+ if getattr(cfg.task, "fast_inference", False):
25
  model = FastModelLoader(cfg).load_model()
26
  device = torch.device(cfg.device)
27
  else:
yolo/model/yolo.py CHANGED
@@ -6,7 +6,7 @@ import torch.nn as nn
6
  from loguru import logger
7
  from omegaconf import ListConfig, OmegaConf
8
 
9
- from yolo.config.config import Config, Model, YOLOLayer
10
  from yolo.tools.dataset_preparation import prepare_weight
11
  from yolo.tools.drawer import draw_model
12
  from yolo.utils.logging_utils import log_model_structure
@@ -22,9 +22,9 @@ class YOLO(nn.Module):
22
  parameters, and any other relevant configuration details.
23
  """
24
 
25
- def __init__(self, model_cfg: Model, num_classes: int):
26
  super(YOLO, self).__init__()
27
- self.num_classes = num_classes
28
  self.layer_map = get_layer_map() # Get the map Dict[str: Module]
29
  self.model: List[YOLOLayer] = nn.ModuleList()
30
  self.build_model(model_cfg.model)
@@ -126,7 +126,7 @@ def create_model(cfg: Config) -> YOLO:
126
  YOLO: An instance of the model defined by the given configuration.
127
  """
128
  OmegaConf.set_struct(cfg.model, False)
129
- model = YOLO(cfg.model, cfg.class_num)
130
  logger.info("✅ Success load model")
131
  if cfg.weight:
132
  if os.path.exists(cfg.weight):
 
6
  from loguru import logger
7
  from omegaconf import ListConfig, OmegaConf
8
 
9
+ from yolo.config.config import Config, ModelConfig, YOLOLayer
10
  from yolo.tools.dataset_preparation import prepare_weight
11
  from yolo.tools.drawer import draw_model
12
  from yolo.utils.logging_utils import log_model_structure
 
22
  parameters, and any other relevant configuration details.
23
  """
24
 
25
+ def __init__(self, model_cfg: ModelConfig):
26
  super(YOLO, self).__init__()
27
+ self.num_classes = model_cfg.class_num
28
  self.layer_map = get_layer_map() # Get the map Dict[str: Module]
29
  self.model: List[YOLOLayer] = nn.ModuleList()
30
  self.build_model(model_cfg.model)
 
126
  YOLO: An instance of the model defined by the given configuration.
127
  """
128
  OmegaConf.set_struct(cfg.model, False)
129
+ model = YOLO(cfg.model)
130
  logger.info("✅ Success load model")
131
  if cfg.weight:
132
  if os.path.exists(cfg.weight):
yolo/tools/data_loader.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  from os import path
3
  from queue import Empty, Queue
4
  from threading import Event, Thread
5
- from typing import Generator, List, Optional, Tuple, Union
6
 
7
  import cv2
8
  import hydra
@@ -15,7 +15,7 @@ from torch import Tensor
15
  from torch.utils.data import DataLoader, Dataset
16
  from torchvision.transforms import functional as TF
17
 
18
- from yolo.config.config import Config, TrainConfig
19
  from yolo.tools.data_augmentation import (
20
  AugmentationComposer,
21
  HorizontalFlip,
@@ -33,15 +33,15 @@ from yolo.utils.dataset_utils import (
33
 
34
 
35
  class YoloDataset(Dataset):
36
- def __init__(self, config: TrainConfig, phase: str = "train2017", image_size: int = 640):
37
- augment_cfg = config.data.data_augment
38
- phase_name = config.dataset.get(phase, phase)
39
- self.image_size = image_size
40
 
41
  transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
42
  self.transform = AugmentationComposer(transforms, self.image_size)
43
  self.transform.get_more_data = self.get_more_data
44
- self.data = self.load_data(config.dataset.path, phase_name)
45
 
46
  def load_data(self, dataset_path, phase_name):
47
  """
@@ -161,16 +161,15 @@ class YoloDataset(Dataset):
161
 
162
 
163
  class YoloDataLoader(DataLoader):
164
- def __init__(self, config: Config):
165
  """Initializes the YoloDataLoader with hydra-config files."""
166
- data_cfg = config.task.data
167
- dataset = YoloDataset(config.task, config.task.task)
168
-
169
  super().__init__(
170
  dataset,
171
  batch_size=data_cfg.batch_size,
172
  shuffle=data_cfg.shuffle,
173
- num_workers=config.cpu_num,
174
  pin_memory=data_cfg.pin_memory,
175
  collate_fn=self.collate_fn,
176
  )
@@ -193,31 +192,33 @@ class YoloDataLoader(DataLoader):
193
  target_sizes = [item[1].size(0) for item in batch]
194
  # TODO: Improve readability of these proccess
195
  batch_targets = torch.zeros(batch_size, max(target_sizes), 5)
 
196
  for idx, target_size in enumerate(target_sizes):
197
  batch_targets[idx, :target_size] = batch[idx][1]
 
198
 
199
  batch_images = torch.stack([item[0] for item in batch])
200
 
201
  return batch_images, batch_targets
202
 
203
 
204
- def create_dataloader(config: Config):
205
- if config.task.task == "inference":
206
- return StreamDataLoader(config)
207
 
208
- if config.task.dataset.auto_download:
209
- prepare_dataset(config.task.dataset)
210
 
211
- return YoloDataLoader(config)
212
 
213
 
214
  class StreamDataLoader:
215
- def __init__(self, config: Config):
216
- self.source = config.task.source
217
  self.running = True
218
  self.is_stream = isinstance(self.source, int) or self.source.lower().startswith("rtmp://")
219
 
220
- self.transform = AugmentationComposer([], config.image_size[0])
221
  self.stop_event = Event()
222
 
223
  if self.is_stream:
@@ -301,19 +302,3 @@ class StreamDataLoader:
301
 
302
  def __len__(self):
303
  return self.queue.qsize() if not self.is_stream else 0
304
-
305
-
306
- @hydra.main(config_path="../config", config_name="config", version_base=None)
307
- def main(cfg):
308
- dataloader = create_dataloader(cfg)
309
- draw_bboxes(*next(iter(dataloader)))
310
-
311
-
312
- if __name__ == "__main__":
313
- import sys
314
-
315
- sys.path.append("./")
316
- from utils.logging_utils import custom_logger
317
-
318
- custom_logger()
319
- main()
 
2
  from os import path
3
  from queue import Empty, Queue
4
  from threading import Event, Thread
5
+ from typing import Generator, List, Tuple, Union
6
 
7
  import cv2
8
  import hydra
 
15
  from torch.utils.data import DataLoader, Dataset
16
  from torchvision.transforms import functional as TF
17
 
18
+ from yolo.config.config import DataConfig, DatasetConfig
19
  from yolo.tools.data_augmentation import (
20
  AugmentationComposer,
21
  HorizontalFlip,
 
33
 
34
 
35
  class YoloDataset(Dataset):
36
+ def __init__(self, data_cfg: DataConfig, dataset_cfg: DatasetConfig, phase: str = "train2017"):
37
+ augment_cfg = data_cfg.data_augment
38
+ self.image_size = data_cfg.image_size
39
+ phase_name = dataset_cfg.get(phase, phase)
40
 
41
  transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
42
  self.transform = AugmentationComposer(transforms, self.image_size)
43
  self.transform.get_more_data = self.get_more_data
44
+ self.data = self.load_data(dataset_cfg.path, phase_name)
45
 
46
  def load_data(self, dataset_path, phase_name):
47
  """
 
161
 
162
 
163
  class YoloDataLoader(DataLoader):
164
+ def __init__(self, data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: str = "train"):
165
  """Initializes the YoloDataLoader with hydra-config files."""
166
+ dataset = YoloDataset(data_cfg, dataset_cfg, task)
167
+ self.image_size = data_cfg.image_size[0]
 
168
  super().__init__(
169
  dataset,
170
  batch_size=data_cfg.batch_size,
171
  shuffle=data_cfg.shuffle,
172
+ num_workers=data_cfg.cpu_num,
173
  pin_memory=data_cfg.pin_memory,
174
  collate_fn=self.collate_fn,
175
  )
 
192
  target_sizes = [item[1].size(0) for item in batch]
193
  # TODO: Improve readability of these proccess
194
  batch_targets = torch.zeros(batch_size, max(target_sizes), 5)
195
+ batch_targets[:, :, 0] = -1
196
  for idx, target_size in enumerate(target_sizes):
197
  batch_targets[idx, :target_size] = batch[idx][1]
198
+ batch_targets[:, :, 1:] *= self.image_size
199
 
200
  batch_images = torch.stack([item[0] for item in batch])
201
 
202
  return batch_images, batch_targets
203
 
204
 
205
+ def create_dataloader(data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: str = "train"):
206
+ if task == "inference":
207
+ return StreamDataLoader(data_cfg)
208
 
209
+ if dataset_cfg.auto_download:
210
+ prepare_dataset(dataset_cfg)
211
 
212
+ return YoloDataLoader(data_cfg, dataset_cfg, task)
213
 
214
 
215
  class StreamDataLoader:
216
+ def __init__(self, data_cfg: DataConfig):
217
+ self.source = data_cfg.source
218
  self.running = True
219
  self.is_stream = isinstance(self.source, int) or self.source.lower().startswith("rtmp://")
220
 
221
+ self.transform = AugmentationComposer([], data_cfg.image_size)
222
  self.stop_event = Event()
223
 
224
  if self.is_stream:
 
302
 
303
  def __len__(self):
304
  return self.queue.qsize() if not self.is_stream else 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
yolo/tools/dataset_preparation.py CHANGED
@@ -52,12 +52,12 @@ def check_files(directory, expected_count=None):
52
  return len(files) == expected_count if expected_count is not None else bool(files)
53
 
54
 
55
- def prepare_dataset(cfg: DatasetConfig):
56
  """
57
  Prepares dataset by downloading and unzipping if necessary.
58
  """
59
- data_dir = cfg.path
60
- for data_type, settings in cfg.auto_download.items():
61
  base_url = settings["base_url"]
62
  for dataset_type, dataset_args in settings.items():
63
  if dataset_type == "base_url":
 
52
  return len(files) == expected_count if expected_count is not None else bool(files)
53
 
54
 
55
+ def prepare_dataset(dataset_cfg: DatasetConfig):
56
  """
57
  Prepares dataset by downloading and unzipping if necessary.
58
  """
59
+ data_dir = dataset_cfg.path
60
+ for data_type, settings in dataset_cfg.auto_download.items():
61
  base_url = settings["base_url"]
62
  for dataset_type, dataset_args in settings.items():
63
  if dataset_type == "base_url":
yolo/tools/loss_functions.py CHANGED
@@ -75,14 +75,11 @@ class DFLoss(nn.Module):
75
  class YOLOLoss:
76
  def __init__(self, cfg: Config) -> None:
77
  self.reg_max = cfg.model.anchor.reg_max
78
- self.class_num = cfg.class_num
79
  self.image_size = list(cfg.image_size)
80
  self.strides = cfg.model.anchor.strides
81
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
82
 
83
- self.reverse_reg = torch.arange(self.reg_max, dtype=torch.float32, device=device)
84
- self.scale_up = torch.tensor(self.image_size * 2, device=device)
85
-
86
  self.anchors, self.scaler = generate_anchors(self.image_size, self.strides, device)
87
 
88
  self.cls = BCELoss()
@@ -90,7 +87,7 @@ class YOLOLoss:
90
  self.iou = BoxLoss()
91
 
92
  self.matcher = BoxMatcher(cfg.task.loss.matcher, self.class_num, self.anchors)
93
- self.box_converter = AnchorBoxConverter(cfg, device)
94
 
95
  def separate_anchor(self, anchors):
96
  """
@@ -134,7 +131,6 @@ class DualLoss:
134
  self.cls_rate = cfg.task.loss.objective["BCELoss"]
135
 
136
  def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]:
137
- targets[:, :, 1:] = targets[:, :, 1:] * self.loss.scale_up
138
 
139
  # TODO: Need Refactor this region, make it flexible!
140
  predicts = divide_into_chunks(predicts[0], 2)
 
75
  class YOLOLoss:
76
  def __init__(self, cfg: Config) -> None:
77
  self.reg_max = cfg.model.anchor.reg_max
78
+ self.class_num = cfg.model.class_num
79
  self.image_size = list(cfg.image_size)
80
  self.strides = cfg.model.anchor.strides
81
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
82
 
 
 
 
83
  self.anchors, self.scaler = generate_anchors(self.image_size, self.strides, device)
84
 
85
  self.cls = BCELoss()
 
87
  self.iou = BoxLoss()
88
 
89
  self.matcher = BoxMatcher(cfg.task.loss.matcher, self.class_num, self.anchors)
90
+ self.box_converter = AnchorBoxConverter(cfg.model, self.image_size, device)
91
 
92
  def separate_anchor(self, anchors):
93
  """
 
131
  self.cls_rate = cfg.task.loss.objective["BCELoss"]
132
 
133
  def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]:
 
134
 
135
  # TODO: Need Refactor this region, make it flexible!
136
  predicts = divide_into_chunks(predicts[0], 2)
yolo/tools/solver.py CHANGED
@@ -5,12 +5,12 @@ from torch import Tensor
5
  # TODO: We may can't use CUDA?
6
  from torch.cuda.amp import GradScaler, autocast
7
 
8
- from yolo.config.config import Config, TrainConfig
9
  from yolo.model.yolo import YOLO
10
- from yolo.tools.data_loader import StreamDataLoader
11
  from yolo.tools.drawer import draw_bboxes
12
  from yolo.tools.loss_functions import get_loss_function
13
- from yolo.utils.bounding_box_utils import AnchorBoxConverter, bbox_nms
14
  from yolo.utils.logging_utils import ProgressTracker
15
  from yolo.utils.model_utils import (
16
  ExponentialMovingAverage,
@@ -27,7 +27,7 @@ class ModelTrainer:
27
  self.optimizer = create_optimizer(model, train_cfg.optimizer)
28
  self.scheduler = create_scheduler(self.optimizer, train_cfg.scheduler)
29
  self.loss_fn = get_loss_function(cfg)
30
- self.progress = ProgressTracker(cfg, save_path, cfg.use_wandb)
31
  self.num_epochs = cfg.task.epoch
32
 
33
  if getattr(train_cfg.ema, "enabled", False):
 
5
  # TODO: We may can't use CUDA?
6
  from torch.cuda.amp import GradScaler, autocast
7
 
8
+ from yolo.config.config import Config, TrainConfig, ValidationConfig
9
  from yolo.model.yolo import YOLO
10
+ from yolo.tools.data_loader import StreamDataLoader, create_dataloader
11
  from yolo.tools.drawer import draw_bboxes
12
  from yolo.tools.loss_functions import get_loss_function
13
+ from yolo.utils.bounding_box_utils import AnchorBoxConverter, bbox_nms, calculate_map
14
  from yolo.utils.logging_utils import ProgressTracker
15
  from yolo.utils.model_utils import (
16
  ExponentialMovingAverage,
 
27
  self.optimizer = create_optimizer(model, train_cfg.optimizer)
28
  self.scheduler = create_scheduler(self.optimizer, train_cfg.scheduler)
29
  self.loss_fn = get_loss_function(cfg)
30
+ self.progress = ProgressTracker(cfg.name, save_path, cfg.use_wandb)
31
  self.num_epochs = cfg.task.epoch
32
 
33
  if getattr(train_cfg.ema, "enabled", False):
yolo/utils/bounding_box_utils.py CHANGED
@@ -7,7 +7,7 @@ from einops import rearrange
7
  from torch import Tensor
8
  from torchvision.ops import batched_nms
9
 
10
- from yolo.config.config import Config, MatcherConfig, NMSConfig
11
 
12
 
13
  def calculate_iou(bbox1, bbox2, metrics="iou") -> Tensor:
@@ -125,14 +125,12 @@ def generate_anchors(image_size: List[int], strides: List[int], device):
125
 
126
 
127
  class AnchorBoxConverter:
128
- def __init__(self, cfg: Config, device: torch.device) -> None:
129
- self.reg_max = cfg.model.anchor.reg_max
130
- self.class_num = cfg.class_num
131
- self.image_size = list(cfg.image_size)
132
- self.strides = cfg.model.anchor.strides
133
-
134
- self.scale_up = torch.tensor(self.image_size * 2, device=device)
135
- self.anchors, self.scaler = generate_anchors(self.image_size, self.strides, device)
136
  self.reverse_reg = torch.arange(self.reg_max, dtype=torch.float32, device=device)
137
 
138
  def __call__(self, predicts: List[Tensor], with_logits=False) -> Tensor:
@@ -255,7 +253,7 @@ class BoxMatcher:
255
  """
256
  predict_cls, predict_bbox = predict.split(self.class_num, dim=-1) # B, HW x (C B) -> B x HW x C, B x HW x B
257
  target_cls, target_bbox = target.split([1, 4], dim=-1) # B x N x (C B) -> B x N x C, B x N x B
258
- target_cls = target_cls.long()
259
 
260
  # get valid matrix (each gt appear in which anchor grid)
261
  grid_mask = self.get_valid_matrix(target_bbox)
 
7
  from torch import Tensor
8
  from torchvision.ops import batched_nms
9
 
10
+ from yolo.config.config import MatcherConfig, ModelConfig, NMSConfig
11
 
12
 
13
  def calculate_iou(bbox1, bbox2, metrics="iou") -> Tensor:
 
125
 
126
 
127
  class AnchorBoxConverter:
128
+ def __init__(self, model_cfg: ModelConfig, image_size: List[int], device: torch.device) -> None:
129
+ self.reg_max = model_cfg.anchor.reg_max
130
+ self.class_num = model_cfg.class_num
131
+ self.strides = model_cfg.anchor.strides
132
+
133
+ self.anchors, self.scaler = generate_anchors(image_size, self.strides, device)
 
 
134
  self.reverse_reg = torch.arange(self.reg_max, dtype=torch.float32, device=device)
135
 
136
  def __call__(self, predicts: List[Tensor], with_logits=False) -> Tensor:
 
253
  """
254
  predict_cls, predict_bbox = predict.split(self.class_num, dim=-1) # B, HW x (C B) -> B x HW x C, B x HW x B
255
  target_cls, target_bbox = target.split([1, 4], dim=-1) # B x N x (C B) -> B x N x C, B x N x B
256
+ target_cls = target_cls.long().clamp(0)
257
 
258
  # get valid matrix (each gt appear in which anchor grid)
259
  grid_mask = self.get_valid_matrix(target_bbox)
yolo/utils/logging_utils.py CHANGED
@@ -39,7 +39,7 @@ def custom_logger(quite: bool = False):
39
 
40
 
41
  class ProgressTracker:
42
- def __init__(self, cfg: Config, save_path: str, use_wandb: bool = False):
43
  self.progress = Progress(
44
  TextColumn("[progress.description]{task.description}"),
45
  BarColumn(bar_width=None),
 
39
 
40
 
41
  class ProgressTracker:
42
+ def __init__(self, exp_name: str, save_path: str, use_wandb: bool = False):
43
  self.progress = Progress(
44
  TextColumn("[progress.description]{task.description}"),
45
  BarColumn(bar_width=None),
yolo/utils/model_utils.py CHANGED
@@ -63,7 +63,7 @@ def create_scheduler(optimizer: Optimizer, schedule_cfg: SchedulerConfig) -> _LR
63
  if hasattr(schedule_cfg, "warmup"):
64
  wepoch = schedule_cfg.warmup.epochs
65
  lambda1 = lambda epoch: 0.1 + 0.9 * (epoch + 1 / wepoch) if epoch < wepoch else 1
66
- lambda2 = lambda epoch: 10 - 9 * (epoch + 1 / wepoch) if epoch < wepoch else 1
67
  warmup_schedule = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2, lambda1])
68
  schedule = SequentialLR(optimizer, schedulers=[warmup_schedule, schedule], milestones=[2])
69
  return schedule
 
63
  if hasattr(schedule_cfg, "warmup"):
64
  wepoch = schedule_cfg.warmup.epochs
65
  lambda1 = lambda epoch: 0.1 + 0.9 * (epoch + 1 / wepoch) if epoch < wepoch else 1
66
+ lambda2 = lambda epoch: 10 - 9 * (epoch / wepoch) if epoch < wepoch else 1
67
  warmup_schedule = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2, lambda1])
68
  schedule = SequentialLR(optimizer, schedulers=[warmup_schedule, schedule], milestones=[2])
69
  return schedule