henry000 commited on
Commit
682daad
Β·
2 Parent(s): 56b217b d58a9b6

πŸ”€ [Merge] branch 'SETUP' into MODELv2

Browse files
yolo/lazy.py CHANGED
@@ -13,13 +13,12 @@ from yolo.tools.data_loader import create_dataloader
13
  from yolo.tools.solver import ModelTester, ModelTrainer
14
  from yolo.utils.bounding_box_utils import Vec2Box
15
  from yolo.utils.deploy_utils import FastModelLoader
16
- from yolo.utils.logging_utils import custom_logger, validate_log_directory
17
 
18
 
19
  @hydra.main(config_path="config", config_name="config", version_base=None)
20
  def main(cfg: Config):
21
- custom_logger()
22
- save_path = validate_log_directory(cfg, exp_name=cfg.name)
23
  dataloader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)
24
  device = torch.device(cfg.device)
25
  if getattr(cfg.task, "fast_inference", False):
@@ -31,11 +30,11 @@ def main(cfg: Config):
31
  vec2box = Vec2Box(model, cfg.image_size, device)
32
 
33
  if cfg.task.task == "train":
34
- trainer = ModelTrainer(cfg, model, vec2box, save_path, device)
35
  trainer.solve(dataloader)
36
 
37
  if cfg.task.task == "inference":
38
- tester = ModelTester(cfg, model, vec2box, save_path, device)
39
  tester.solve(dataloader)
40
 
41
 
 
13
  from yolo.tools.solver import ModelTester, ModelTrainer
14
  from yolo.utils.bounding_box_utils import Vec2Box
15
  from yolo.utils.deploy_utils import FastModelLoader
16
+ from yolo.utils.logging_utils import ProgressLogger
17
 
18
 
19
  @hydra.main(config_path="config", config_name="config", version_base=None)
20
  def main(cfg: Config):
21
+ progress = ProgressLogger(cfg, exp_name=cfg.name)
 
22
  dataloader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)
23
  device = torch.device(cfg.device)
24
  if getattr(cfg.task, "fast_inference", False):
 
30
  vec2box = Vec2Box(model, cfg.image_size, device)
31
 
32
  if cfg.task.task == "train":
33
+ trainer = ModelTrainer(cfg, model, vec2box, progress, device)
34
  trainer.solve(dataloader)
35
 
36
  if cfg.task.task == "inference":
37
+ tester = ModelTester(cfg, model, vec2box, progress, device)
38
  tester.solve(dataloader)
39
 
40
 
yolo/tools/solver.py CHANGED
@@ -14,7 +14,7 @@ from yolo.tools.data_loader import StreamDataLoader, create_dataloader
14
  from yolo.tools.drawer import draw_bboxes
15
  from yolo.tools.loss_functions import create_loss_function
16
  from yolo.utils.bounding_box_utils import Vec2Box, bbox_nms, calculate_map
17
- from yolo.utils.logging_utils import ProgressTracker
18
  from yolo.utils.model_utils import (
19
  ExponentialMovingAverage,
20
  create_optimizer,
@@ -23,7 +23,7 @@ from yolo.utils.model_utils import (
23
 
24
 
25
  class ModelTrainer:
26
- def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, save_path: str, device):
27
  train_cfg: TrainConfig = cfg.task
28
  self.model = model
29
  self.vec2box = vec2box
@@ -31,11 +31,11 @@ class ModelTrainer:
31
  self.optimizer = create_optimizer(model, train_cfg.optimizer)
32
  self.scheduler = create_scheduler(self.optimizer, train_cfg.scheduler)
33
  self.loss_fn = create_loss_function(cfg, vec2box)
34
- self.progress = ProgressTracker(cfg.name, save_path, cfg.use_wandb)
35
  self.num_epochs = cfg.task.epoch
36
 
37
  self.validation_dataloader = create_dataloader(cfg.task.validation.data, cfg.dataset, cfg.task.validation.task)
38
- self.validator = ModelValidator(cfg.task.validation, model, vec2box, save_path, device, self.progress)
39
 
40
  if getattr(train_cfg.ema, "enabled", False):
41
  self.ema = ExponentialMovingAverage(model, decay=train_cfg.ema.decay)
@@ -102,14 +102,15 @@ class ModelTrainer:
102
 
103
 
104
  class ModelTester:
105
- def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, save_path: str, device):
106
  self.model = model
107
  self.device = device
108
  self.vec2box = vec2box
109
- self.progress = ProgressTracker(cfg, save_path, cfg.use_wandb)
110
 
111
  self.nms = cfg.task.nms
112
- self.save_path = save_path
 
113
  self.save_predict = getattr(cfg.task, "save_predict", None)
114
  self.idx2label = cfg.class_list
115
 
@@ -164,16 +165,13 @@ class ModelValidator:
164
  validation_cfg: ValidationConfig,
165
  model: YOLO,
166
  vec2box: Vec2Box,
167
- save_path: str,
168
  device,
169
- # TODO: think Progress?
170
- progress: ProgressTracker,
171
  ):
172
  self.model = model
173
  self.vec2box = vec2box
174
  self.device = device
175
  self.progress = progress
176
- self.save_path = save_path
177
 
178
  self.nms = validation_cfg.nms
179
 
 
14
  from yolo.tools.drawer import draw_bboxes
15
  from yolo.tools.loss_functions import create_loss_function
16
  from yolo.utils.bounding_box_utils import Vec2Box, bbox_nms, calculate_map
17
+ from yolo.utils.logging_utils import ProgressLogger
18
  from yolo.utils.model_utils import (
19
  ExponentialMovingAverage,
20
  create_optimizer,
 
23
 
24
 
25
  class ModelTrainer:
26
+ def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, progress: ProgressLogger, device):
27
  train_cfg: TrainConfig = cfg.task
28
  self.model = model
29
  self.vec2box = vec2box
 
31
  self.optimizer = create_optimizer(model, train_cfg.optimizer)
32
  self.scheduler = create_scheduler(self.optimizer, train_cfg.scheduler)
33
  self.loss_fn = create_loss_function(cfg, vec2box)
34
+ self.progress = progress
35
  self.num_epochs = cfg.task.epoch
36
 
37
  self.validation_dataloader = create_dataloader(cfg.task.validation.data, cfg.dataset, cfg.task.validation.task)
38
+ self.validator = ModelValidator(cfg.task.validation, model, vec2box, progress, device, self.progress)
39
 
40
  if getattr(train_cfg.ema, "enabled", False):
41
  self.ema = ExponentialMovingAverage(model, decay=train_cfg.ema.decay)
 
102
 
103
 
104
  class ModelTester:
105
+ def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, progress: ProgressLogger, device):
106
  self.model = model
107
  self.device = device
108
  self.vec2box = vec2box
109
+ self.progress = progress
110
 
111
  self.nms = cfg.task.nms
112
+ self.save_path = os.path.join(progress.save_path, "images")
113
+ os.makedirs(self.save_path, exist_ok=True)
114
  self.save_predict = getattr(cfg.task, "save_predict", None)
115
  self.idx2label = cfg.class_list
116
 
 
165
  validation_cfg: ValidationConfig,
166
  model: YOLO,
167
  vec2box: Vec2Box,
 
168
  device,
169
+ progress: ProgressLogger,
 
170
  ):
171
  self.model = model
172
  self.vec2box = vec2box
173
  self.device = device
174
  self.progress = progress
 
175
 
176
  self.nms = validation_cfg.nms
177
 
yolo/utils/logging_utils.py CHANGED
@@ -38,15 +38,18 @@ def custom_logger(quite: bool = False):
38
  )
39
 
40
 
41
- class ProgressTracker:
42
- def __init__(self, exp_name: str, save_path: str, use_wandb: bool = False):
 
 
 
43
  self.progress = Progress(
44
  TextColumn("[progress.description]{task.description}"),
45
  BarColumn(bar_width=None),
46
  TextColumn("{task.completed:.0f}/{task.total:.0f}"),
47
  TimeRemainingColumn(),
48
  )
49
- self.use_wandb = use_wandb
50
  if self.use_wandb:
51
  wandb.errors.term._log = custom_wandb_log
52
  self.wandb = wandb.init(
 
38
  )
39
 
40
 
41
+ class ProgressLogger:
42
+ def __init__(self, cfg: Config, exp_name: str):
43
+ custom_logger(getattr(cfg, "quite", False))
44
+ self.save_path = validate_log_directory(cfg, exp_name=cfg.name)
45
+
46
  self.progress = Progress(
47
  TextColumn("[progress.description]{task.description}"),
48
  BarColumn(bar_width=None),
49
  TextColumn("{task.completed:.0f}/{task.total:.0f}"),
50
  TimeRemainingColumn(),
51
  )
52
+ self.use_wandb = cfg.use_wandb
53
  if self.use_wandb:
54
  wandb.errors.term._log = custom_wandb_log
55
  self.wandb = wandb.init(