🔧 [Update] all structure of yaml, more readability
Browse files- yolo/config/config.py +28 -16
- yolo/config/config.yaml +1 -0
- yolo/config/{task/dataset → dataset}/coco.yaml +1 -1
- yolo/config/dataset/dev.yaml +5 -0
- yolo/config/model/v9-c.yaml +2 -0
- yolo/config/task/inference.yaml +3 -4
- yolo/config/task/train.yaml +4 -1
- yolo/config/task/validation.yaml +12 -0
- yolo/lazy.py +3 -3
- yolo/model/yolo.py +4 -4
- yolo/tools/data_loader.py +22 -37
- yolo/tools/dataset_preparation.py +3 -3
- yolo/tools/loss_functions.py +2 -6
- yolo/tools/solver.py +4 -4
- yolo/utils/bounding_box_utils.py +8 -10
- yolo/utils/logging_utils.py +1 -1
- yolo/utils/model_utils.py +1 -1
yolo/config/config.py
CHANGED
@@ -23,8 +23,9 @@ class BlockConfig:
|
|
23 |
|
24 |
|
25 |
@dataclass
|
26 |
-
class
|
27 |
anchor: AnchorConfig
|
|
|
28 |
model: Dict[str, BlockConfig]
|
29 |
|
30 |
|
@@ -50,7 +51,10 @@ class DataConfig:
|
|
50 |
shuffle: bool
|
51 |
batch_size: int
|
52 |
pin_memory: bool
|
|
|
|
|
53 |
data_augment: Dict[str, int]
|
|
|
54 |
|
55 |
|
56 |
@dataclass
|
@@ -92,18 +96,6 @@ class EMAConfig:
|
|
92 |
decay: float
|
93 |
|
94 |
|
95 |
-
@dataclass
|
96 |
-
class TrainConfig:
|
97 |
-
task: str
|
98 |
-
dataset: DatasetConfig
|
99 |
-
epoch: int
|
100 |
-
data: DataConfig
|
101 |
-
optimizer: OptimizerConfig
|
102 |
-
loss: LossConfig
|
103 |
-
scheduler: SchedulerConfig
|
104 |
-
ema: EMAConfig
|
105 |
-
|
106 |
-
|
107 |
@dataclass
|
108 |
class NMSConfig:
|
109 |
min_confidence: int
|
@@ -113,15 +105,35 @@ class NMSConfig:
|
|
113 |
@dataclass
|
114 |
class InferenceConfig:
|
115 |
task: str
|
116 |
-
source: Union[str, int]
|
117 |
nms: NMSConfig
|
|
|
118 |
fast_inference: Optional[None]
|
119 |
|
120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
@dataclass
|
122 |
class Config:
|
123 |
-
task: Union[TrainConfig, InferenceConfig]
|
124 |
-
|
|
|
125 |
name: str
|
126 |
|
127 |
device: Union[str, int, List[int]]
|
|
|
23 |
|
24 |
|
25 |
@dataclass
|
26 |
+
class ModelConfig:
|
27 |
anchor: AnchorConfig
|
28 |
+
class_num: int
|
29 |
model: Dict[str, BlockConfig]
|
30 |
|
31 |
|
|
|
51 |
shuffle: bool
|
52 |
batch_size: int
|
53 |
pin_memory: bool
|
54 |
+
cpu_num: int
|
55 |
+
image_size: List[int]
|
56 |
data_augment: Dict[str, int]
|
57 |
+
source: Optional[Union[str, int]]
|
58 |
|
59 |
|
60 |
@dataclass
|
|
|
96 |
decay: float
|
97 |
|
98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
@dataclass
|
100 |
class NMSConfig:
|
101 |
min_confidence: int
|
|
|
105 |
@dataclass
|
106 |
class InferenceConfig:
|
107 |
task: str
|
|
|
108 |
nms: NMSConfig
|
109 |
+
data: DataConfig
|
110 |
fast_inference: Optional[None]
|
111 |
|
112 |
|
113 |
+
@dataclass
|
114 |
+
class ValidationConfig:
|
115 |
+
task: str
|
116 |
+
nms: NMSConfig
|
117 |
+
data: DataConfig
|
118 |
+
|
119 |
+
|
120 |
+
@dataclass
|
121 |
+
class TrainConfig:
|
122 |
+
task: str
|
123 |
+
epoch: int
|
124 |
+
data: DataConfig
|
125 |
+
optimizer: OptimizerConfig
|
126 |
+
loss: LossConfig
|
127 |
+
scheduler: SchedulerConfig
|
128 |
+
ema: EMAConfig
|
129 |
+
validation: ValidationConfig
|
130 |
+
|
131 |
+
|
132 |
@dataclass
|
133 |
class Config:
|
134 |
+
task: Union[TrainConfig, InferenceConfig, ValidationConfig]
|
135 |
+
dataset: DatasetConfig
|
136 |
+
model: ModelConfig
|
137 |
name: str
|
138 |
|
139 |
device: Union[str, int, List[int]]
|
yolo/config/config.yaml
CHANGED
@@ -7,6 +7,7 @@ name: v9-dev
|
|
7 |
defaults:
|
8 |
- _self_
|
9 |
- task: train
|
|
|
10 |
- model: v9-c
|
11 |
- general
|
12 |
|
|
|
7 |
defaults:
|
8 |
- _self_
|
9 |
- task: train
|
10 |
+
- dataset: coco
|
11 |
- model: v9-c
|
12 |
- general
|
13 |
|
yolo/config/{task/dataset → dataset}/coco.yaml
RENAMED
@@ -1,6 +1,6 @@
|
|
1 |
path: data/coco
|
2 |
train: train2017
|
3 |
-
|
4 |
|
5 |
auto_download:
|
6 |
images:
|
|
|
1 |
path: data/coco
|
2 |
train: train2017
|
3 |
+
validation: val2017
|
4 |
|
5 |
auto_download:
|
6 |
images:
|
yolo/config/dataset/dev.yaml
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
path: data/dev
|
2 |
+
train: train
|
3 |
+
validation: test
|
4 |
+
|
5 |
+
auto_download:
|
yolo/config/model/v9-c.yaml
CHANGED
@@ -2,6 +2,8 @@ anchor:
|
|
2 |
reg_max: 16
|
3 |
strides: [8, 16, 32]
|
4 |
|
|
|
|
|
5 |
model:
|
6 |
backbone:
|
7 |
- Conv:
|
|
|
2 |
reg_max: 16
|
3 |
strides: [8, 16, 32]
|
4 |
|
5 |
+
class_num: ${class_num}
|
6 |
+
|
7 |
model:
|
8 |
backbone:
|
9 |
- Conv:
|
yolo/config/task/inference.yaml
CHANGED
@@ -1,10 +1,9 @@
|
|
1 |
task: inference
|
2 |
-
|
3 |
fast_inference: # onnx, trt or Empty
|
4 |
data:
|
5 |
-
|
6 |
-
|
7 |
-
pin_memory: True
|
8 |
data_augment: {}
|
9 |
nms:
|
10 |
min_confidence: 0.5
|
|
|
1 |
task: inference
|
2 |
+
|
3 |
fast_inference: # onnx, trt or Empty
|
4 |
data:
|
5 |
+
source: demo/images/inference/image.png
|
6 |
+
image_size: ${image_size}
|
|
|
7 |
data_augment: {}
|
8 |
nms:
|
9 |
min_confidence: 0.5
|
yolo/config/task/train.yaml
CHANGED
@@ -1,11 +1,14 @@
|
|
1 |
task: train
|
|
|
2 |
defaults:
|
3 |
-
-
|
4 |
|
5 |
epoch: 500
|
6 |
|
7 |
data:
|
8 |
batch_size: 16
|
|
|
|
|
9 |
shuffle: True
|
10 |
pin_memory: True
|
11 |
data_augment:
|
|
|
1 |
task: train
|
2 |
+
|
3 |
defaults:
|
4 |
+
- validation: ../validation
|
5 |
|
6 |
epoch: 500
|
7 |
|
8 |
data:
|
9 |
batch_size: 16
|
10 |
+
image_size: ${image_size}
|
11 |
+
cpu_num: ${cpu_num}
|
12 |
shuffle: True
|
13 |
pin_memory: True
|
14 |
data_augment:
|
yolo/config/task/validation.yaml
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
task: validation
|
2 |
+
|
3 |
+
data:
|
4 |
+
batch_size: 16
|
5 |
+
image_size: ${image_size}
|
6 |
+
cpu_num: ${cpu_num}
|
7 |
+
shuffle: False
|
8 |
+
pin_memory: True
|
9 |
+
data_augment: {}
|
10 |
+
nms:
|
11 |
+
min_confidence: 0.001
|
12 |
+
min_iou: 0.7
|
yolo/lazy.py
CHANGED
@@ -18,10 +18,10 @@ from yolo.utils.logging_utils import custom_logger, validate_log_directory
|
|
18 |
@hydra.main(config_path="config", config_name="config", version_base=None)
|
19 |
def main(cfg: Config):
|
20 |
custom_logger()
|
21 |
-
save_path = validate_log_directory(cfg, cfg.name)
|
22 |
-
dataloader = create_dataloader(cfg)
|
23 |
device = torch.device(cfg.device)
|
24 |
-
if cfg.task
|
25 |
model = FastModelLoader(cfg).load_model()
|
26 |
device = torch.device(cfg.device)
|
27 |
else:
|
|
|
18 |
@hydra.main(config_path="config", config_name="config", version_base=None)
|
19 |
def main(cfg: Config):
|
20 |
custom_logger()
|
21 |
+
save_path = validate_log_directory(cfg, exp_name=cfg.name)
|
22 |
+
dataloader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)
|
23 |
device = torch.device(cfg.device)
|
24 |
+
if getattr(cfg.task, "fast_inference", False):
|
25 |
model = FastModelLoader(cfg).load_model()
|
26 |
device = torch.device(cfg.device)
|
27 |
else:
|
yolo/model/yolo.py
CHANGED
@@ -6,7 +6,7 @@ import torch.nn as nn
|
|
6 |
from loguru import logger
|
7 |
from omegaconf import ListConfig, OmegaConf
|
8 |
|
9 |
-
from yolo.config.config import Config,
|
10 |
from yolo.tools.dataset_preparation import prepare_weight
|
11 |
from yolo.tools.drawer import draw_model
|
12 |
from yolo.utils.logging_utils import log_model_structure
|
@@ -22,9 +22,9 @@ class YOLO(nn.Module):
|
|
22 |
parameters, and any other relevant configuration details.
|
23 |
"""
|
24 |
|
25 |
-
def __init__(self, model_cfg:
|
26 |
super(YOLO, self).__init__()
|
27 |
-
self.num_classes =
|
28 |
self.layer_map = get_layer_map() # Get the map Dict[str: Module]
|
29 |
self.model: List[YOLOLayer] = nn.ModuleList()
|
30 |
self.build_model(model_cfg.model)
|
@@ -126,7 +126,7 @@ def create_model(cfg: Config) -> YOLO:
|
|
126 |
YOLO: An instance of the model defined by the given configuration.
|
127 |
"""
|
128 |
OmegaConf.set_struct(cfg.model, False)
|
129 |
-
model = YOLO(cfg.model
|
130 |
logger.info("✅ Success load model")
|
131 |
if cfg.weight:
|
132 |
if os.path.exists(cfg.weight):
|
|
|
6 |
from loguru import logger
|
7 |
from omegaconf import ListConfig, OmegaConf
|
8 |
|
9 |
+
from yolo.config.config import Config, ModelConfig, YOLOLayer
|
10 |
from yolo.tools.dataset_preparation import prepare_weight
|
11 |
from yolo.tools.drawer import draw_model
|
12 |
from yolo.utils.logging_utils import log_model_structure
|
|
|
22 |
parameters, and any other relevant configuration details.
|
23 |
"""
|
24 |
|
25 |
+
def __init__(self, model_cfg: ModelConfig):
|
26 |
super(YOLO, self).__init__()
|
27 |
+
self.num_classes = model_cfg.class_num
|
28 |
self.layer_map = get_layer_map() # Get the map Dict[str: Module]
|
29 |
self.model: List[YOLOLayer] = nn.ModuleList()
|
30 |
self.build_model(model_cfg.model)
|
|
|
126 |
YOLO: An instance of the model defined by the given configuration.
|
127 |
"""
|
128 |
OmegaConf.set_struct(cfg.model, False)
|
129 |
+
model = YOLO(cfg.model)
|
130 |
logger.info("✅ Success load model")
|
131 |
if cfg.weight:
|
132 |
if os.path.exists(cfg.weight):
|
yolo/tools/data_loader.py
CHANGED
@@ -2,7 +2,7 @@ import os
|
|
2 |
from os import path
|
3 |
from queue import Empty, Queue
|
4 |
from threading import Event, Thread
|
5 |
-
from typing import Generator, List,
|
6 |
|
7 |
import cv2
|
8 |
import hydra
|
@@ -15,7 +15,7 @@ from torch import Tensor
|
|
15 |
from torch.utils.data import DataLoader, Dataset
|
16 |
from torchvision.transforms import functional as TF
|
17 |
|
18 |
-
from yolo.config.config import
|
19 |
from yolo.tools.data_augmentation import (
|
20 |
AugmentationComposer,
|
21 |
HorizontalFlip,
|
@@ -33,15 +33,15 @@ from yolo.utils.dataset_utils import (
|
|
33 |
|
34 |
|
35 |
class YoloDataset(Dataset):
|
36 |
-
def __init__(self,
|
37 |
-
augment_cfg =
|
38 |
-
|
39 |
-
|
40 |
|
41 |
transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
|
42 |
self.transform = AugmentationComposer(transforms, self.image_size)
|
43 |
self.transform.get_more_data = self.get_more_data
|
44 |
-
self.data = self.load_data(
|
45 |
|
46 |
def load_data(self, dataset_path, phase_name):
|
47 |
"""
|
@@ -161,16 +161,15 @@ class YoloDataset(Dataset):
|
|
161 |
|
162 |
|
163 |
class YoloDataLoader(DataLoader):
|
164 |
-
def __init__(self,
|
165 |
"""Initializes the YoloDataLoader with hydra-config files."""
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
super().__init__(
|
170 |
dataset,
|
171 |
batch_size=data_cfg.batch_size,
|
172 |
shuffle=data_cfg.shuffle,
|
173 |
-
num_workers=
|
174 |
pin_memory=data_cfg.pin_memory,
|
175 |
collate_fn=self.collate_fn,
|
176 |
)
|
@@ -193,31 +192,33 @@ class YoloDataLoader(DataLoader):
|
|
193 |
target_sizes = [item[1].size(0) for item in batch]
|
194 |
# TODO: Improve readability of these proccess
|
195 |
batch_targets = torch.zeros(batch_size, max(target_sizes), 5)
|
|
|
196 |
for idx, target_size in enumerate(target_sizes):
|
197 |
batch_targets[idx, :target_size] = batch[idx][1]
|
|
|
198 |
|
199 |
batch_images = torch.stack([item[0] for item in batch])
|
200 |
|
201 |
return batch_images, batch_targets
|
202 |
|
203 |
|
204 |
-
def create_dataloader(
|
205 |
-
if
|
206 |
-
return StreamDataLoader(
|
207 |
|
208 |
-
if
|
209 |
-
prepare_dataset(
|
210 |
|
211 |
-
return YoloDataLoader(
|
212 |
|
213 |
|
214 |
class StreamDataLoader:
|
215 |
-
def __init__(self,
|
216 |
-
self.source =
|
217 |
self.running = True
|
218 |
self.is_stream = isinstance(self.source, int) or self.source.lower().startswith("rtmp://")
|
219 |
|
220 |
-
self.transform = AugmentationComposer([],
|
221 |
self.stop_event = Event()
|
222 |
|
223 |
if self.is_stream:
|
@@ -301,19 +302,3 @@ class StreamDataLoader:
|
|
301 |
|
302 |
def __len__(self):
|
303 |
return self.queue.qsize() if not self.is_stream else 0
|
304 |
-
|
305 |
-
|
306 |
-
@hydra.main(config_path="../config", config_name="config", version_base=None)
|
307 |
-
def main(cfg):
|
308 |
-
dataloader = create_dataloader(cfg)
|
309 |
-
draw_bboxes(*next(iter(dataloader)))
|
310 |
-
|
311 |
-
|
312 |
-
if __name__ == "__main__":
|
313 |
-
import sys
|
314 |
-
|
315 |
-
sys.path.append("./")
|
316 |
-
from utils.logging_utils import custom_logger
|
317 |
-
|
318 |
-
custom_logger()
|
319 |
-
main()
|
|
|
2 |
from os import path
|
3 |
from queue import Empty, Queue
|
4 |
from threading import Event, Thread
|
5 |
+
from typing import Generator, List, Tuple, Union
|
6 |
|
7 |
import cv2
|
8 |
import hydra
|
|
|
15 |
from torch.utils.data import DataLoader, Dataset
|
16 |
from torchvision.transforms import functional as TF
|
17 |
|
18 |
+
from yolo.config.config import DataConfig, DatasetConfig
|
19 |
from yolo.tools.data_augmentation import (
|
20 |
AugmentationComposer,
|
21 |
HorizontalFlip,
|
|
|
33 |
|
34 |
|
35 |
class YoloDataset(Dataset):
|
36 |
+
def __init__(self, data_cfg: DataConfig, dataset_cfg: DatasetConfig, phase: str = "train2017"):
|
37 |
+
augment_cfg = data_cfg.data_augment
|
38 |
+
self.image_size = data_cfg.image_size
|
39 |
+
phase_name = dataset_cfg.get(phase, phase)
|
40 |
|
41 |
transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
|
42 |
self.transform = AugmentationComposer(transforms, self.image_size)
|
43 |
self.transform.get_more_data = self.get_more_data
|
44 |
+
self.data = self.load_data(dataset_cfg.path, phase_name)
|
45 |
|
46 |
def load_data(self, dataset_path, phase_name):
|
47 |
"""
|
|
|
161 |
|
162 |
|
163 |
class YoloDataLoader(DataLoader):
|
164 |
+
def __init__(self, data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: str = "train"):
|
165 |
"""Initializes the YoloDataLoader with hydra-config files."""
|
166 |
+
dataset = YoloDataset(data_cfg, dataset_cfg, task)
|
167 |
+
self.image_size = data_cfg.image_size[0]
|
|
|
168 |
super().__init__(
|
169 |
dataset,
|
170 |
batch_size=data_cfg.batch_size,
|
171 |
shuffle=data_cfg.shuffle,
|
172 |
+
num_workers=data_cfg.cpu_num,
|
173 |
pin_memory=data_cfg.pin_memory,
|
174 |
collate_fn=self.collate_fn,
|
175 |
)
|
|
|
192 |
target_sizes = [item[1].size(0) for item in batch]
|
193 |
# TODO: Improve readability of these proccess
|
194 |
batch_targets = torch.zeros(batch_size, max(target_sizes), 5)
|
195 |
+
batch_targets[:, :, 0] = -1
|
196 |
for idx, target_size in enumerate(target_sizes):
|
197 |
batch_targets[idx, :target_size] = batch[idx][1]
|
198 |
+
batch_targets[:, :, 1:] *= self.image_size
|
199 |
|
200 |
batch_images = torch.stack([item[0] for item in batch])
|
201 |
|
202 |
return batch_images, batch_targets
|
203 |
|
204 |
|
205 |
+
def create_dataloader(data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: str = "train"):
|
206 |
+
if task == "inference":
|
207 |
+
return StreamDataLoader(data_cfg)
|
208 |
|
209 |
+
if dataset_cfg.auto_download:
|
210 |
+
prepare_dataset(dataset_cfg)
|
211 |
|
212 |
+
return YoloDataLoader(data_cfg, dataset_cfg, task)
|
213 |
|
214 |
|
215 |
class StreamDataLoader:
|
216 |
+
def __init__(self, data_cfg: DataConfig):
|
217 |
+
self.source = data_cfg.source
|
218 |
self.running = True
|
219 |
self.is_stream = isinstance(self.source, int) or self.source.lower().startswith("rtmp://")
|
220 |
|
221 |
+
self.transform = AugmentationComposer([], data_cfg.image_size)
|
222 |
self.stop_event = Event()
|
223 |
|
224 |
if self.is_stream:
|
|
|
302 |
|
303 |
def __len__(self):
|
304 |
return self.queue.qsize() if not self.is_stream else 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
yolo/tools/dataset_preparation.py
CHANGED
@@ -52,12 +52,12 @@ def check_files(directory, expected_count=None):
|
|
52 |
return len(files) == expected_count if expected_count is not None else bool(files)
|
53 |
|
54 |
|
55 |
-
def prepare_dataset(
|
56 |
"""
|
57 |
Prepares dataset by downloading and unzipping if necessary.
|
58 |
"""
|
59 |
-
data_dir =
|
60 |
-
for data_type, settings in
|
61 |
base_url = settings["base_url"]
|
62 |
for dataset_type, dataset_args in settings.items():
|
63 |
if dataset_type == "base_url":
|
|
|
52 |
return len(files) == expected_count if expected_count is not None else bool(files)
|
53 |
|
54 |
|
55 |
+
def prepare_dataset(dataset_cfg: DatasetConfig):
|
56 |
"""
|
57 |
Prepares dataset by downloading and unzipping if necessary.
|
58 |
"""
|
59 |
+
data_dir = dataset_cfg.path
|
60 |
+
for data_type, settings in dataset_cfg.auto_download.items():
|
61 |
base_url = settings["base_url"]
|
62 |
for dataset_type, dataset_args in settings.items():
|
63 |
if dataset_type == "base_url":
|
yolo/tools/loss_functions.py
CHANGED
@@ -75,14 +75,11 @@ class DFLoss(nn.Module):
|
|
75 |
class YOLOLoss:
|
76 |
def __init__(self, cfg: Config) -> None:
|
77 |
self.reg_max = cfg.model.anchor.reg_max
|
78 |
-
self.class_num = cfg.class_num
|
79 |
self.image_size = list(cfg.image_size)
|
80 |
self.strides = cfg.model.anchor.strides
|
81 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
82 |
|
83 |
-
self.reverse_reg = torch.arange(self.reg_max, dtype=torch.float32, device=device)
|
84 |
-
self.scale_up = torch.tensor(self.image_size * 2, device=device)
|
85 |
-
|
86 |
self.anchors, self.scaler = generate_anchors(self.image_size, self.strides, device)
|
87 |
|
88 |
self.cls = BCELoss()
|
@@ -90,7 +87,7 @@ class YOLOLoss:
|
|
90 |
self.iou = BoxLoss()
|
91 |
|
92 |
self.matcher = BoxMatcher(cfg.task.loss.matcher, self.class_num, self.anchors)
|
93 |
-
self.box_converter = AnchorBoxConverter(cfg, device)
|
94 |
|
95 |
def separate_anchor(self, anchors):
|
96 |
"""
|
@@ -134,7 +131,6 @@ class DualLoss:
|
|
134 |
self.cls_rate = cfg.task.loss.objective["BCELoss"]
|
135 |
|
136 |
def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]:
|
137 |
-
targets[:, :, 1:] = targets[:, :, 1:] * self.loss.scale_up
|
138 |
|
139 |
# TODO: Need Refactor this region, make it flexible!
|
140 |
predicts = divide_into_chunks(predicts[0], 2)
|
|
|
75 |
class YOLOLoss:
|
76 |
def __init__(self, cfg: Config) -> None:
|
77 |
self.reg_max = cfg.model.anchor.reg_max
|
78 |
+
self.class_num = cfg.model.class_num
|
79 |
self.image_size = list(cfg.image_size)
|
80 |
self.strides = cfg.model.anchor.strides
|
81 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
82 |
|
|
|
|
|
|
|
83 |
self.anchors, self.scaler = generate_anchors(self.image_size, self.strides, device)
|
84 |
|
85 |
self.cls = BCELoss()
|
|
|
87 |
self.iou = BoxLoss()
|
88 |
|
89 |
self.matcher = BoxMatcher(cfg.task.loss.matcher, self.class_num, self.anchors)
|
90 |
+
self.box_converter = AnchorBoxConverter(cfg.model, self.image_size, device)
|
91 |
|
92 |
def separate_anchor(self, anchors):
|
93 |
"""
|
|
|
131 |
self.cls_rate = cfg.task.loss.objective["BCELoss"]
|
132 |
|
133 |
def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]:
|
|
|
134 |
|
135 |
# TODO: Need Refactor this region, make it flexible!
|
136 |
predicts = divide_into_chunks(predicts[0], 2)
|
yolo/tools/solver.py
CHANGED
@@ -5,12 +5,12 @@ from torch import Tensor
|
|
5 |
# TODO: We may can't use CUDA?
|
6 |
from torch.cuda.amp import GradScaler, autocast
|
7 |
|
8 |
-
from yolo.config.config import Config, TrainConfig
|
9 |
from yolo.model.yolo import YOLO
|
10 |
-
from yolo.tools.data_loader import StreamDataLoader
|
11 |
from yolo.tools.drawer import draw_bboxes
|
12 |
from yolo.tools.loss_functions import get_loss_function
|
13 |
-
from yolo.utils.bounding_box_utils import AnchorBoxConverter, bbox_nms
|
14 |
from yolo.utils.logging_utils import ProgressTracker
|
15 |
from yolo.utils.model_utils import (
|
16 |
ExponentialMovingAverage,
|
@@ -27,7 +27,7 @@ class ModelTrainer:
|
|
27 |
self.optimizer = create_optimizer(model, train_cfg.optimizer)
|
28 |
self.scheduler = create_scheduler(self.optimizer, train_cfg.scheduler)
|
29 |
self.loss_fn = get_loss_function(cfg)
|
30 |
-
self.progress = ProgressTracker(cfg, save_path, cfg.use_wandb)
|
31 |
self.num_epochs = cfg.task.epoch
|
32 |
|
33 |
if getattr(train_cfg.ema, "enabled", False):
|
|
|
5 |
# TODO: We may can't use CUDA?
|
6 |
from torch.cuda.amp import GradScaler, autocast
|
7 |
|
8 |
+
from yolo.config.config import Config, TrainConfig, ValidationConfig
|
9 |
from yolo.model.yolo import YOLO
|
10 |
+
from yolo.tools.data_loader import StreamDataLoader, create_dataloader
|
11 |
from yolo.tools.drawer import draw_bboxes
|
12 |
from yolo.tools.loss_functions import get_loss_function
|
13 |
+
from yolo.utils.bounding_box_utils import AnchorBoxConverter, bbox_nms, calculate_map
|
14 |
from yolo.utils.logging_utils import ProgressTracker
|
15 |
from yolo.utils.model_utils import (
|
16 |
ExponentialMovingAverage,
|
|
|
27 |
self.optimizer = create_optimizer(model, train_cfg.optimizer)
|
28 |
self.scheduler = create_scheduler(self.optimizer, train_cfg.scheduler)
|
29 |
self.loss_fn = get_loss_function(cfg)
|
30 |
+
self.progress = ProgressTracker(cfg.name, save_path, cfg.use_wandb)
|
31 |
self.num_epochs = cfg.task.epoch
|
32 |
|
33 |
if getattr(train_cfg.ema, "enabled", False):
|
yolo/utils/bounding_box_utils.py
CHANGED
@@ -7,7 +7,7 @@ from einops import rearrange
|
|
7 |
from torch import Tensor
|
8 |
from torchvision.ops import batched_nms
|
9 |
|
10 |
-
from yolo.config.config import
|
11 |
|
12 |
|
13 |
def calculate_iou(bbox1, bbox2, metrics="iou") -> Tensor:
|
@@ -125,14 +125,12 @@ def generate_anchors(image_size: List[int], strides: List[int], device):
|
|
125 |
|
126 |
|
127 |
class AnchorBoxConverter:
|
128 |
-
def __init__(self,
|
129 |
-
self.reg_max =
|
130 |
-
self.class_num =
|
131 |
-
self.
|
132 |
-
|
133 |
-
|
134 |
-
self.scale_up = torch.tensor(self.image_size * 2, device=device)
|
135 |
-
self.anchors, self.scaler = generate_anchors(self.image_size, self.strides, device)
|
136 |
self.reverse_reg = torch.arange(self.reg_max, dtype=torch.float32, device=device)
|
137 |
|
138 |
def __call__(self, predicts: List[Tensor], with_logits=False) -> Tensor:
|
@@ -255,7 +253,7 @@ class BoxMatcher:
|
|
255 |
"""
|
256 |
predict_cls, predict_bbox = predict.split(self.class_num, dim=-1) # B, HW x (C B) -> B x HW x C, B x HW x B
|
257 |
target_cls, target_bbox = target.split([1, 4], dim=-1) # B x N x (C B) -> B x N x C, B x N x B
|
258 |
-
target_cls = target_cls.long()
|
259 |
|
260 |
# get valid matrix (each gt appear in which anchor grid)
|
261 |
grid_mask = self.get_valid_matrix(target_bbox)
|
|
|
7 |
from torch import Tensor
|
8 |
from torchvision.ops import batched_nms
|
9 |
|
10 |
+
from yolo.config.config import MatcherConfig, ModelConfig, NMSConfig
|
11 |
|
12 |
|
13 |
def calculate_iou(bbox1, bbox2, metrics="iou") -> Tensor:
|
|
|
125 |
|
126 |
|
127 |
class AnchorBoxConverter:
|
128 |
+
def __init__(self, model_cfg: ModelConfig, image_size: List[int], device: torch.device) -> None:
|
129 |
+
self.reg_max = model_cfg.anchor.reg_max
|
130 |
+
self.class_num = model_cfg.class_num
|
131 |
+
self.strides = model_cfg.anchor.strides
|
132 |
+
|
133 |
+
self.anchors, self.scaler = generate_anchors(image_size, self.strides, device)
|
|
|
|
|
134 |
self.reverse_reg = torch.arange(self.reg_max, dtype=torch.float32, device=device)
|
135 |
|
136 |
def __call__(self, predicts: List[Tensor], with_logits=False) -> Tensor:
|
|
|
253 |
"""
|
254 |
predict_cls, predict_bbox = predict.split(self.class_num, dim=-1) # B, HW x (C B) -> B x HW x C, B x HW x B
|
255 |
target_cls, target_bbox = target.split([1, 4], dim=-1) # B x N x (C B) -> B x N x C, B x N x B
|
256 |
+
target_cls = target_cls.long().clamp(0)
|
257 |
|
258 |
# get valid matrix (each gt appear in which anchor grid)
|
259 |
grid_mask = self.get_valid_matrix(target_bbox)
|
yolo/utils/logging_utils.py
CHANGED
@@ -39,7 +39,7 @@ def custom_logger(quite: bool = False):
|
|
39 |
|
40 |
|
41 |
class ProgressTracker:
|
42 |
-
def __init__(self,
|
43 |
self.progress = Progress(
|
44 |
TextColumn("[progress.description]{task.description}"),
|
45 |
BarColumn(bar_width=None),
|
|
|
39 |
|
40 |
|
41 |
class ProgressTracker:
|
42 |
+
def __init__(self, exp_name: str, save_path: str, use_wandb: bool = False):
|
43 |
self.progress = Progress(
|
44 |
TextColumn("[progress.description]{task.description}"),
|
45 |
BarColumn(bar_width=None),
|
yolo/utils/model_utils.py
CHANGED
@@ -63,7 +63,7 @@ def create_scheduler(optimizer: Optimizer, schedule_cfg: SchedulerConfig) -> _LR
|
|
63 |
if hasattr(schedule_cfg, "warmup"):
|
64 |
wepoch = schedule_cfg.warmup.epochs
|
65 |
lambda1 = lambda epoch: 0.1 + 0.9 * (epoch + 1 / wepoch) if epoch < wepoch else 1
|
66 |
-
lambda2 = lambda epoch: 10 - 9 * (epoch
|
67 |
warmup_schedule = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2, lambda1])
|
68 |
schedule = SequentialLR(optimizer, schedulers=[warmup_schedule, schedule], milestones=[2])
|
69 |
return schedule
|
|
|
63 |
if hasattr(schedule_cfg, "warmup"):
|
64 |
wepoch = schedule_cfg.warmup.epochs
|
65 |
lambda1 = lambda epoch: 0.1 + 0.9 * (epoch + 1 / wepoch) if epoch < wepoch else 1
|
66 |
+
lambda2 = lambda epoch: 10 - 9 * (epoch / wepoch) if epoch < wepoch else 1
|
67 |
warmup_schedule = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2, lambda1])
|
68 |
schedule = SequentialLR(optimizer, schedulers=[warmup_schedule, schedule], milestones=[2])
|
69 |
return schedule
|