Spaces:
Paused
Paused
File size: 8,331 Bytes
8b7211f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 |
# Ultralytics YOLO 🚀, GPL-3.0 license
from pathlib import Path
from ultralytics import yolo # noqa
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight
from ultralytics.yolo.configs import get_config
from ultralytics.yolo.engine.exporter import Exporter
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, yaml_load
from ultralytics.yolo.utils.checks import check_imgsz, check_yaml
from ultralytics.yolo.utils.torch_utils import guess_task_from_head, smart_inference_mode
# Task name -> [model class, trainer path, validator path, predictor path].
# The first element is the model class itself; the remaining three are dotted
# attribute paths (rooted at the imported `yolo` package) in which the literal
# "TYPE" placeholder is replaced with the model type/version (e.g. "v8")
# before the path is resolved to a class.
MODEL_MAP = {
    "classify": [
        ClassificationModel,
        'yolo.TYPE.classify.ClassificationTrainer',
        'yolo.TYPE.classify.ClassificationValidator',
        'yolo.TYPE.classify.ClassificationPredictor',
    ],
    "detect": [
        DetectionModel,
        'yolo.TYPE.detect.DetectionTrainer',
        'yolo.TYPE.detect.DetectionValidator',
        'yolo.TYPE.detect.DetectionPredictor',
    ],
    "segment": [
        SegmentationModel,
        'yolo.TYPE.segment.SegmentationTrainer',
        'yolo.TYPE.segment.SegmentationValidator',
        'yolo.TYPE.segment.SegmentationPredictor',
    ],
}
class YOLO:
    """
    YOLO

    A python interface which emulates a model-like behaviour by wrapping trainers.

    The object is built either from a `*.yaml` model definition (a new model is
    created) or from a `*.pt` checkpoint (an existing model is loaded). The task
    is derived from the definition/checkpoint and used to select the matching
    model, trainer, validator and predictor classes from `MODEL_MAP`.
    """

    def __init__(self, model='yolov8n.yaml', type="v8") -> None:
        """
        Initializes the YOLO object.

        Args:
            model (str, Path): model to load or create. A `.pt` suffix loads a checkpoint,
                a `.yaml` suffix creates a new model from a configuration file.
            type (str): Type/version of models to use. Defaults to "v8".

        Raises:
            KeyError: if the suffix of `model` is neither `.pt` nor `.yaml`.
        """
        self.type = type
        self.ModelClass = None  # model class
        self.TrainerClass = None  # trainer class
        self.ValidatorClass = None  # validator class
        self.PredictorClass = None  # predictor class
        self.model = None  # model object
        self.trainer = None  # trainer object
        self.task = None  # task type
        self.ckpt = None  # if loaded from *.pt
        self.cfg = None  # if loaded from *.yaml
        self.ckpt_path = None
        self.overrides = {}  # overrides for trainer object

        # Load or create new YOLO model
        loaders = {'.pt': self._load, '.yaml': self._new}
        suffix = Path(model).suffix
        if suffix not in loaders:
            # KeyError is what the plain dict lookup raised before; the type is kept so
            # existing callers still catch it, but the message now says what went wrong.
            raise KeyError(
                f"Unsupported model suffix '{suffix}' for model '{model}'; expected one of {sorted(loaders)}.")
        loaders[suffix](model)

    def __call__(self, source, **kwargs):
        """Shorthand for `predict(source, **kwargs)`."""
        return self.predict(source, **kwargs)

    def _new(self, cfg: str, verbose=True):
        """
        Initializes a new model and infers the task type from the model definitions.

        Args:
            cfg (str): model configuration file
            verbose (bool): display model info on load
        """
        cfg = check_yaml(cfg)  # check YAML
        cfg_dict = yaml_load(cfg, append_filename=True)  # model dict
        # The task is guessed from the second-to-last field of the last `head` entry
        # (presumably the module name of the final layer; confirm against the YAML schema).
        self.task = guess_task_from_head(cfg_dict["head"][-1][-2])
        self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
            self._guess_ops_from_task(self.task)
        self.model = self.ModelClass(cfg_dict, verbose=verbose)  # initialize
        self.cfg = cfg

    def _load(self, weights: str):
        """
        Loads a model from a checkpoint and reads the task type from the arguments stored with it.

        Args:
            weights (str): model checkpoint to be loaded
        """
        self.model, self.ckpt = attempt_load_one_weight(weights)
        self.ckpt_path = weights
        self.task = self.model.args["task"]
        # NOTE: `overrides` is the same dict object as `self.model.args`, so the keys
        # removed by `_reset_ckpt_args` are removed from the model's args as well.
        self.overrides = self.model.args
        self._reset_ckpt_args(self.overrides)
        self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
            self._guess_ops_from_task(self.task)

    def reset(self):
        """
        Resets the model modules.

        Calls `reset_parameters()` on every module that defines it and marks every
        parameter as requiring gradients.
        """
        for m in self.model.modules():
            if hasattr(m, 'reset_parameters'):
                m.reset_parameters()
        for p in self.model.parameters():
            p.requires_grad = True

    def info(self, verbose=False):
        """
        Logs model info.

        Args:
            verbose (bool): Controls verbosity.
        """
        self.model.info(verbose=verbose)

    def fuse(self):
        """Calls `fuse()` on the wrapped model."""
        self.model.fuse()

    @smart_inference_mode()
    def predict(self, source, **kwargs):
        """
        Runs prediction on the given source.

        Args:
            source (str): Accepts all source types accepted by yolo
            **kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs

        Returns:
            Whatever calling the predictor instance returns.
        """
        overrides = self.overrides.copy()
        overrides["conf"] = 0.25  # default confidence; a `conf` in kwargs replaces it below
        overrides.update(kwargs)
        overrides["mode"] = "predict"
        overrides["save"] = kwargs.get("save", False)  # not save files by default
        predictor = self.PredictorClass(overrides=overrides)

        predictor.args.imgsz = check_imgsz(predictor.args.imgsz, min_dim=2)  # check image size
        predictor.setup(model=self.model, source=source)
        return predictor()

    @smart_inference_mode()
    def val(self, data=None, **kwargs):
        """
        Validate a model on a given dataset.

        Args:
            data (str): The dataset to validate on. Accepts all formats accepted by yolo.
                When falsy, the `data` value from the resolved configuration is used.
            **kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
        """
        overrides = self.overrides.copy()
        overrides.update(kwargs)
        overrides["mode"] = "val"
        args = get_config(config=DEFAULT_CONFIG, overrides=overrides)
        args.data = data or args.data
        args.task = self.task

        validator = self.ValidatorClass(args=args)
        validator(model=self.model)

    @smart_inference_mode()
    def export(self, **kwargs):
        """
        Export model.

        Args:
            **kwargs : Any other args accepted by the exporter. To see all args check 'configuration' section in docs
        """
        overrides = self.overrides.copy()
        overrides.update(kwargs)
        args = get_config(config=DEFAULT_CONFIG, overrides=overrides)
        args.task = self.task

        exporter = Exporter(overrides=args)
        exporter(model=self.model)

    def train(self, **kwargs):
        """
        Trains the model on a given dataset.

        Args:
            **kwargs (Any): Any number of arguments representing the training configuration. List of all args can be
                found in 'config' section. You can pass all arguments as a yaml file in `cfg`. Other args are ignored
                if `cfg` file is passed.

        Raises:
            AttributeError: if no `data` entry is present in the resulting configuration.
        """
        overrides = self.overrides.copy()
        overrides.update(kwargs)
        if kwargs.get("cfg"):
            LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.")
            # The cfg file replaces (does not merge with) every override collected so far.
            overrides = yaml_load(check_yaml(kwargs["cfg"]), append_filename=True)
        overrides["task"] = self.task
        overrides["mode"] = "train"
        if not overrides.get("data"):
            raise AttributeError("dataset not provided! Please define `data` in config.yaml or pass as an argument.")
        if overrides.get("resume"):
            # A truthy `resume` is replaced by the path of the checkpoint this object was loaded from.
            overrides["resume"] = self.ckpt_path

        self.trainer = self.TrainerClass(overrides=overrides)
        if not overrides.get("resume"):  # manually set model only if not resuming
            # The current model is passed as weights only when it came from a checkpoint.
            self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
            self.model = self.trainer.model
        self.trainer.train()

    def to(self, device):
        """
        Sends the model to the given device.

        Args:
            device (str): device
        """
        self.model.to(device)

    def _resolve_op(self, template):
        """
        Resolves a dotted path template from `MODEL_MAP` to the object it names.

        The "TYPE" placeholder is replaced with `self.type`, the first path component is
        looked up among this module's globals and the remaining components are followed
        with `getattr`. This yields the same object as evaluating the dotted expression,
        but cannot execute arbitrary code if `self.type` holds something unexpected.

        Args:
            template (str): dotted path such as 'yolo.TYPE.detect.DetectionTrainer'

        Returns:
            The object found at the end of the dotted path.

        Raises:
            KeyError: if the first path component is not a module-level name.
            AttributeError: if any later component does not exist.
        """
        root_name, *attr_names = template.replace("TYPE", f"{self.type}").split(".")
        target = globals()[root_name]
        for attr_name in attr_names:
            target = getattr(target, attr_name)
        return target

    def _guess_ops_from_task(self, task):
        """
        Looks up the model, trainer, validator and predictor classes for a task.

        Args:
            task (str): a key of `MODEL_MAP`

        Returns:
            tuple: (model class, trainer class, validator class, predictor class)

        Raises:
            KeyError: if `task` is not a key of `MODEL_MAP`.
        """
        model_class, train_lit, val_lit, pred_lit = MODEL_MAP[task]
        # Resolved by attribute lookup rather than `eval`, since `self.type` is caller-supplied.
        trainer_class = self._resolve_op(train_lit)
        validator_class = self._resolve_op(val_lit)
        predictor_class = self._resolve_op(pred_lit)

        return model_class, trainer_class, validator_class, predictor_class

    @staticmethod
    def _reset_ckpt_args(args):
        """
        Removes a fixed set of keys from a checkpoint's argument dict, in place.

        Args:
            args (dict): arguments stored with a checkpoint; keys that are absent are ignored.
        """
        for key in ("device", "project", "name", "batch", "epochs", "cache", "save_json"):
            args.pop(key, None)
|