henry000 committed
Commit 0a3c9de · 1 Parent(s): 360a2c0

🔧 [Move] model class num config out of model yaml

yolo/config/config.py CHANGED
@@ -25,7 +25,6 @@ class BlockConfig:
 @dataclass
 class ModelConfig:
     anchor: AnchorConfig
-    class_num: int
     model: Dict[str, BlockConfig]
 
 
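With class_num removed from ModelConfig, the call sites below read it from the top-level config instead (cfg.class_num). A hypothetical sketch of how the root dataclass might now look — the root config is not shown in this commit, so everything beyond the class_num, model, and weight fields attested below is an assumption:

    from dataclasses import dataclass

    @dataclass
    class Config:                        # hypothetical root config; not part of this diff
        model: "ModelConfig"             # still carries anchor + blocks, no class count
        class_num: int = 80              # moved up from ModelConfig.class_num
        weight: str = "weights/v9-c.pt"  # consumed as weight_path at call sites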
yolo/config/model/v9-c.yaml CHANGED
@@ -2,8 +2,6 @@ anchor:
   reg_max: 16
   strides: [8, 16, 32]
 
-class_num: ${class_num}
-
 model:
   backbone:
     - Conv:
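Previously the model yaml pulled its class count in through OmegaConf interpolation, which required class_num to exist on the root config at compose time; after this change the model yaml carries no class information at all. A minimal sketch of the interpolation mechanism being removed (values assumed):

    from omegaconf import OmegaConf

    # ${class_num} resolves against the root config at access time:
    root = OmegaConf.create({"class_num": 80, "model": {"class_num": "${class_num}"}})
    print(root.model.class_num)  # 80 -- this indirection is what the commit deletes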
yolo/lazy.py CHANGED
@@ -25,7 +25,7 @@ def main(cfg: Config):
         model = FastModelLoader(cfg).load_model()
         device = torch.device(cfg.device)
     else:
-        model = create_model(cfg.model, cfg.weight).to(device)
+        model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight).to(device)
 
     if cfg.task.task == "train":
         trainer = ModelTrainer(cfg, model, save_path, device)
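The switch to keyword arguments is load-bearing: with the new signature create_model(model_cfg, class_num=80, weight_path=...), keeping the old positional call create_model(cfg.model, cfg.weight) would silently bind the weight path to class_num. A small usage sketch (the yaml path exists in this repo; the weight path is the new default):

    from omegaconf import OmegaConf
    from yolo.model.yolo import create_model

    model_cfg = OmegaConf.load("yolo/config/model/v9-c.yaml")
    model = create_model(model_cfg, class_num=80, weight_path="weights/v9-c.pt")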
yolo/model/module.py CHANGED
@@ -93,13 +93,13 @@ class MultiheadDetection(nn.Module):
 
 
 class Anchor2Box(nn.Module):
-    def __init__(self, reg_max, strides) -> None:
+    def __init__(self, reg_max, strides, num_classes: int) -> None:
         super().__init__()
         self.reg_max = reg_max
         self.strides = strides
         # TODO: read by cfg!
         image_size = [640, 640]
-        self.class_num = 80
+        self.num_classes = num_classes
         self.anchors, self.scaler = generate_anchors(image_size, self.strides)
         reverse_reg = torch.arange(self.reg_max, dtype=torch.float32)
         self.reverse_reg = nn.Parameter(reverse_reg, requires_grad=False)
@@ -117,7 +117,7 @@ class Anchor2Box(nn.Module):
         for pred in predicts:
             preds.append(rearrange(pred, "B AC h w -> B (h w) AC"))  # B x AC x h x w-> B x hw x AC
         preds = torch.concat(preds, dim=1)  # -> B x (H W) x AC
-        preds_anc, preds_cls = torch.split(preds, (self.reg_max * 4, self.class_num), dim=-1)
+        preds_anc, preds_cls = torch.split(preds, (self.reg_max * 4, self.num_classes), dim=-1)
         preds_anc = rearrange(preds_anc, "B hw (P R)-> B hw P R", P=4)
 
         pred_LTRB = preds_anc.softmax(dim=-1) @ self.reverse_reg * self.scaler.view(1, -1, 1)
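The split in the second hunk relies on each flattened prediction vector laying out 4 * reg_max box-distribution channels followed by num_classes class channels; a standalone shape check of that arithmetic (batch and anchor counts assumed):

    import torch

    reg_max, num_classes, anchors = 16, 80, 8400   # assumed sizes
    preds = torch.randn(2, anchors, 4 * reg_max + num_classes)  # B x hw x AC

    # Same split as Anchor2Box.forward: box distribution first, classes after.
    preds_anc, preds_cls = torch.split(preds, (reg_max * 4, num_classes), dim=-1)
    assert preds_anc.shape == (2, anchors, 64)
    assert preds_cls.shape == (2, anchors, 80)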
yolo/model/yolo.py CHANGED
@@ -22,9 +22,9 @@ class YOLO(nn.Module):
     parameters, and any other relevant configuration details.
     """
 
-    def __init__(self, model_cfg: ModelConfig):
+    def __init__(self, model_cfg: ModelConfig, class_num: int = 80):
         super(YOLO, self).__init__()
-        self.num_classes = model_cfg.class_num
+        self.num_classes = class_num
         self.layer_map = get_layer_map()  # Get the map Dict[str: Module]
         self.model: List[YOLOLayer] = nn.ModuleList()
         self.build_model(model_cfg.model)
@@ -47,6 +47,7 @@ class YOLO(nn.Module):
                 layer_args["in_channels"] = output_dim[source]
             if "Detection" in layer_type:
                 layer_args["in_channels"] = [output_dim[idx] for idx in source]
+            if "Detection" in layer_type or "Anchor2Box" in layer_type:
                 layer_args["num_classes"] = self.num_classes
 
             # create layers
@@ -116,7 +117,7 @@ class YOLO(nn.Module):
         raise ValueError(f"Unsupported layer type: {layer_type}")
 
 
-def create_model(model_cfg: ModelConfig, weight_path: str) -> YOLO:
+def create_model(model_cfg: ModelConfig, class_num: int = 80, weight_path: str = "weights/v9-c.pt") -> YOLO:
     """Constructs and returns a model from a Dictionary configuration file.
 
     Args:
@@ -126,7 +127,7 @@ def create_model(model_cfg: ModelConfig, weight_path: str) -> YOLO:
         YOLO: An instance of the model defined by the given configuration.
     """
     OmegaConf.set_struct(model_cfg, False)
-    model = YOLO(model_cfg)
+    model = YOLO(model_cfg, class_num)
     logger.info("✅ Success load model")
     if weight_path:
         if os.path.exists(weight_path):
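In build_model, num_classes is injected into layer_args before a layer is instantiated, and the new condition extends that injection from Detection heads to Anchor2Box as well. A stripped-down sketch of the dispatch pattern (the layer-map lookup is elided):

    layer_type = "Anchor2Box"
    layer_args = {"reg_max": 16, "strides": [8, 16, 32]}
    num_classes = 80

    # Mirrors the updated condition in YOLO.build_model:
    if "Detection" in layer_type or "Anchor2Box" in layer_type:
        layer_args["num_classes"] = num_classes
    # layer = self.layer_map[layer_type](**layer_args)  # as done in build_model
    print(layer_args)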
yolo/tools/loss_functions.py CHANGED
@@ -70,7 +70,7 @@ class DFLoss(nn.Module):
 class YOLOLoss:
     def __init__(self, cfg: Config) -> None:
         self.reg_max = cfg.model.anchor.reg_max
-        self.class_num = cfg.model.class_num
+        self.class_num = cfg.class_num
         self.image_size = list(cfg.image_size)
         self.strides = cfg.model.anchor.strides
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
yolo/utils/deploy_utils.py CHANGED
@@ -28,7 +28,7 @@ class FastModelLoader:
             return self._load_onnx_model()
         elif self.compiler == "trt":
             return self._load_trt_model()
-        return create_model(self.cfg)
+        return create_model(self.cfg.model, class_num=self.cfg.class_num, weight_path=self.cfg.weight)
 
     def _load_onnx_model(self):
         from onnxruntime import InferenceSession
@@ -53,7 +53,7 @@ class FastModelLoader:
         from onnxruntime import InferenceSession
         from torch.onnx import export
 
-        model = create_model(self.cfg).eval()
+        model = create_model(self.cfg.model, class_num=self.cfg.class_num, weight_path=self.cfg.weight).eval()
         dummy_input = torch.ones((1, 3, *self.cfg.image_size))
         export(
             model,
@@ -81,7 +81,7 @@ class FastModelLoader:
     def _create_trt_model(self):
         from torch2trt import torch2trt
 
-        model = create_model(self.cfg).eval()
+        model = create_model(self.cfg.model, class_num=self.cfg.class_num, weight_path=self.cfg.weight).eval()
         dummy_input = torch.ones((1, 3, *self.cfg.image_size))
         logger.info(f"♻️ Creating TensorRT model")
         model_trt = torch2trt(model, [dummy_input])
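Worth noting: the old calls handed the whole Config to create_model, which expects a ModelConfig, and omitted the then-required weight_path, so this is a fix as much as a plumbing change; all three sites now match lazy.py. For the export path, a minimal standalone sketch of the same torch.onnx.export entry point (the model and output name are stand-ins, not this repo's):

    import torch
    from torch import nn
    from torch.onnx import export

    model = nn.Conv2d(3, 16, 3).eval()           # stand-in for create_model(...)
    dummy_input = torch.ones((1, 3, 640, 640))   # (1, 3, *image_size), as above
    export(model, dummy_input, "model.onnx")     # writes the ONNX graph to disk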