🔧 [Move] model class num config out of modelyaml
Browse files- yolo/config/config.py +0 -1
- yolo/config/model/v9-c.yaml +0 -2
- yolo/lazy.py +1 -1
- yolo/model/module.py +3 -3
- yolo/model/yolo.py +5 -4
- yolo/tools/loss_functions.py +1 -1
- yolo/utils/deploy_utils.py +3 -3
yolo/config/config.py
CHANGED
@@ -25,7 +25,6 @@ class BlockConfig:
|
|
25 |
@dataclass
|
26 |
class ModelConfig:
|
27 |
anchor: AnchorConfig
|
28 |
-
class_num: int
|
29 |
model: Dict[str, BlockConfig]
|
30 |
|
31 |
|
|
|
25 |
@dataclass
|
26 |
class ModelConfig:
|
27 |
anchor: AnchorConfig
|
|
|
28 |
model: Dict[str, BlockConfig]
|
29 |
|
30 |
|
yolo/config/model/v9-c.yaml
CHANGED
@@ -2,8 +2,6 @@ anchor:
|
|
2 |
reg_max: 16
|
3 |
strides: [8, 16, 32]
|
4 |
|
5 |
-
class_num: ${class_num}
|
6 |
-
|
7 |
model:
|
8 |
backbone:
|
9 |
- Conv:
|
|
|
2 |
reg_max: 16
|
3 |
strides: [8, 16, 32]
|
4 |
|
|
|
|
|
5 |
model:
|
6 |
backbone:
|
7 |
- Conv:
|
yolo/lazy.py
CHANGED
@@ -25,7 +25,7 @@ def main(cfg: Config):
|
|
25 |
model = FastModelLoader(cfg).load_model()
|
26 |
device = torch.device(cfg.device)
|
27 |
else:
|
28 |
-
model = create_model(cfg.model, cfg.weight).to(device)
|
29 |
|
30 |
if cfg.task.task == "train":
|
31 |
trainer = ModelTrainer(cfg, model, save_path, device)
|
|
|
25 |
model = FastModelLoader(cfg).load_model()
|
26 |
device = torch.device(cfg.device)
|
27 |
else:
|
28 |
+
model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight).to(device)
|
29 |
|
30 |
if cfg.task.task == "train":
|
31 |
trainer = ModelTrainer(cfg, model, save_path, device)
|
yolo/model/module.py
CHANGED
@@ -93,13 +93,13 @@ class MultiheadDetection(nn.Module):
|
|
93 |
|
94 |
|
95 |
class Anchor2Box(nn.Module):
|
96 |
-
def __init__(self, reg_max, strides) -> None:
|
97 |
super().__init__()
|
98 |
self.reg_max = reg_max
|
99 |
self.strides = strides
|
100 |
# TODO: read by cfg!
|
101 |
image_size = [640, 640]
|
102 |
-
self.
|
103 |
self.anchors, self.scaler = generate_anchors(image_size, self.strides)
|
104 |
reverse_reg = torch.arange(self.reg_max, dtype=torch.float32)
|
105 |
self.reverse_reg = nn.Parameter(reverse_reg, requires_grad=False)
|
@@ -117,7 +117,7 @@ class Anchor2Box(nn.Module):
|
|
117 |
for pred in predicts:
|
118 |
preds.append(rearrange(pred, "B AC h w -> B (h w) AC")) # B x AC x h x w-> B x hw x AC
|
119 |
preds = torch.concat(preds, dim=1) # -> B x (H W) x AC
|
120 |
-
preds_anc, preds_cls = torch.split(preds, (self.reg_max * 4, self.
|
121 |
preds_anc = rearrange(preds_anc, "B hw (P R)-> B hw P R", P=4)
|
122 |
|
123 |
pred_LTRB = preds_anc.softmax(dim=-1) @ self.reverse_reg * self.scaler.view(1, -1, 1)
|
|
|
93 |
|
94 |
|
95 |
class Anchor2Box(nn.Module):
|
96 |
+
def __init__(self, reg_max, strides, num_classes: int) -> None:
|
97 |
super().__init__()
|
98 |
self.reg_max = reg_max
|
99 |
self.strides = strides
|
100 |
# TODO: read by cfg!
|
101 |
image_size = [640, 640]
|
102 |
+
self.num_classes = num_classes
|
103 |
self.anchors, self.scaler = generate_anchors(image_size, self.strides)
|
104 |
reverse_reg = torch.arange(self.reg_max, dtype=torch.float32)
|
105 |
self.reverse_reg = nn.Parameter(reverse_reg, requires_grad=False)
|
|
|
117 |
for pred in predicts:
|
118 |
preds.append(rearrange(pred, "B AC h w -> B (h w) AC")) # B x AC x h x w-> B x hw x AC
|
119 |
preds = torch.concat(preds, dim=1) # -> B x (H W) x AC
|
120 |
+
preds_anc, preds_cls = torch.split(preds, (self.reg_max * 4, self.num_classes), dim=-1)
|
121 |
preds_anc = rearrange(preds_anc, "B hw (P R)-> B hw P R", P=4)
|
122 |
|
123 |
pred_LTRB = preds_anc.softmax(dim=-1) @ self.reverse_reg * self.scaler.view(1, -1, 1)
|
yolo/model/yolo.py
CHANGED
@@ -22,9 +22,9 @@ class YOLO(nn.Module):
|
|
22 |
parameters, and any other relevant configuration details.
|
23 |
"""
|
24 |
|
25 |
-
def __init__(self, model_cfg: ModelConfig):
|
26 |
super(YOLO, self).__init__()
|
27 |
-
self.num_classes =
|
28 |
self.layer_map = get_layer_map() # Get the map Dict[str: Module]
|
29 |
self.model: List[YOLOLayer] = nn.ModuleList()
|
30 |
self.build_model(model_cfg.model)
|
@@ -47,6 +47,7 @@ class YOLO(nn.Module):
|
|
47 |
layer_args["in_channels"] = output_dim[source]
|
48 |
if "Detection" in layer_type:
|
49 |
layer_args["in_channels"] = [output_dim[idx] for idx in source]
|
|
|
50 |
layer_args["num_classes"] = self.num_classes
|
51 |
|
52 |
# create layers
|
@@ -116,7 +117,7 @@ class YOLO(nn.Module):
|
|
116 |
raise ValueError(f"Unsupported layer type: {layer_type}")
|
117 |
|
118 |
|
119 |
-
def create_model(model_cfg: ModelConfig, weight_path: str) -> YOLO:
|
120 |
"""Constructs and returns a model from a Dictionary configuration file.
|
121 |
|
122 |
Args:
|
@@ -126,7 +127,7 @@ def create_model(model_cfg: ModelConfig, weight_path: str) -> YOLO:
|
|
126 |
YOLO: An instance of the model defined by the given configuration.
|
127 |
"""
|
128 |
OmegaConf.set_struct(model_cfg, False)
|
129 |
-
model = YOLO(model_cfg)
|
130 |
logger.info("✅ Success load model")
|
131 |
if weight_path:
|
132 |
if os.path.exists(weight_path):
|
|
|
22 |
parameters, and any other relevant configuration details.
|
23 |
"""
|
24 |
|
25 |
+
def __init__(self, model_cfg: ModelConfig, class_num: int = 80):
|
26 |
super(YOLO, self).__init__()
|
27 |
+
self.num_classes = class_num
|
28 |
self.layer_map = get_layer_map() # Get the map Dict[str: Module]
|
29 |
self.model: List[YOLOLayer] = nn.ModuleList()
|
30 |
self.build_model(model_cfg.model)
|
|
|
47 |
layer_args["in_channels"] = output_dim[source]
|
48 |
if "Detection" in layer_type:
|
49 |
layer_args["in_channels"] = [output_dim[idx] for idx in source]
|
50 |
+
if "Detection" in layer_type or "Anchor2Box" in layer_type:
|
51 |
layer_args["num_classes"] = self.num_classes
|
52 |
|
53 |
# create layers
|
|
|
117 |
raise ValueError(f"Unsupported layer type: {layer_type}")
|
118 |
|
119 |
|
120 |
+
def create_model(model_cfg: ModelConfig, class_num: int = 80, weight_path: str = "weights/v9-c.pt") -> YOLO:
|
121 |
"""Constructs and returns a model from a Dictionary configuration file.
|
122 |
|
123 |
Args:
|
|
|
127 |
YOLO: An instance of the model defined by the given configuration.
|
128 |
"""
|
129 |
OmegaConf.set_struct(model_cfg, False)
|
130 |
+
model = YOLO(model_cfg, class_num)
|
131 |
logger.info("✅ Success load model")
|
132 |
if weight_path:
|
133 |
if os.path.exists(weight_path):
|
yolo/tools/loss_functions.py
CHANGED
@@ -70,7 +70,7 @@ class DFLoss(nn.Module):
|
|
70 |
class YOLOLoss:
|
71 |
def __init__(self, cfg: Config) -> None:
|
72 |
self.reg_max = cfg.model.anchor.reg_max
|
73 |
-
self.class_num = cfg.
|
74 |
self.image_size = list(cfg.image_size)
|
75 |
self.strides = cfg.model.anchor.strides
|
76 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
70 |
class YOLOLoss:
|
71 |
def __init__(self, cfg: Config) -> None:
|
72 |
self.reg_max = cfg.model.anchor.reg_max
|
73 |
+
self.class_num = cfg.class_num
|
74 |
self.image_size = list(cfg.image_size)
|
75 |
self.strides = cfg.model.anchor.strides
|
76 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
yolo/utils/deploy_utils.py
CHANGED
@@ -28,7 +28,7 @@ class FastModelLoader:
|
|
28 |
return self._load_onnx_model()
|
29 |
elif self.compiler == "trt":
|
30 |
return self._load_trt_model()
|
31 |
-
return create_model(self.cfg)
|
32 |
|
33 |
def _load_onnx_model(self):
|
34 |
from onnxruntime import InferenceSession
|
@@ -53,7 +53,7 @@ class FastModelLoader:
|
|
53 |
from onnxruntime import InferenceSession
|
54 |
from torch.onnx import export
|
55 |
|
56 |
-
model = create_model(self.cfg).eval()
|
57 |
dummy_input = torch.ones((1, 3, *self.cfg.image_size))
|
58 |
export(
|
59 |
model,
|
@@ -81,7 +81,7 @@ class FastModelLoader:
|
|
81 |
def _create_trt_model(self):
|
82 |
from torch2trt import torch2trt
|
83 |
|
84 |
-
model = create_model(self.cfg).eval()
|
85 |
dummy_input = torch.ones((1, 3, *self.cfg.image_size))
|
86 |
logger.info(f"♻️ Creating TensorRT model")
|
87 |
model_trt = torch2trt(model, [dummy_input])
|
|
|
28 |
return self._load_onnx_model()
|
29 |
elif self.compiler == "trt":
|
30 |
return self._load_trt_model()
|
31 |
+
return create_model(self.cfg.model, class_num=self.cfg.class_num, weight_path=self.cfg.weight)
|
32 |
|
33 |
def _load_onnx_model(self):
|
34 |
from onnxruntime import InferenceSession
|
|
|
53 |
from onnxruntime import InferenceSession
|
54 |
from torch.onnx import export
|
55 |
|
56 |
+
model = create_model(self.cfg.model, class_num=self.cfg.class_num, weight_path=self.cfg.weight).eval()
|
57 |
dummy_input = torch.ones((1, 3, *self.cfg.image_size))
|
58 |
export(
|
59 |
model,
|
|
|
81 |
def _create_trt_model(self):
|
82 |
from torch2trt import torch2trt
|
83 |
|
84 |
+
model = create_model(self.cfg.model, class_num=self.cfg.class_num, weight_path=self.cfg.weight).eval()
|
85 |
dummy_input = torch.ones((1, 3, *self.cfg.image_size))
|
86 |
logger.info(f"♻️ Creating TensorRT model")
|
87 |
model_trt = torch2trt(model, [dummy_input])
|