Spaces:
Sleeping
Sleeping
from dataset import * | |
from models.YoloV3Lightning import * | |
import utils | |
def init(model, basic_sanity_check=True, find_max_lr=True, train=True, **kwargs):
    """Run one stage of the YOLOv3 workflow, selected by the boolean flags.

    The stages are mutually exclusive and are checked in this order:

    1. ``basic_sanity_check`` -- validate the dataset and sanity-check
       ``model``, then stop.
    2. ``find_max_lr`` -- run the learning-rate finder, then stop.
    3. ``train`` -- fit ``model`` with a PyTorch Lightning ``Trainer``.
    4. otherwise -- optionally load weights from ``ckpt_file``, plot sample
       predictions on the test split and print the mean average precision.

    Args:
        model: The model to check, train or evaluate. For stage 3 it must be
            accepted by ``pl.Trainer.fit``.
        basic_sanity_check: Run stage 1 and return.
        find_max_lr: Run stage 2 and return.
        train: Run stage 3 instead of stage 4.
        **kwargs: Stage-specific inputs. All are optional and a missing key
            is passed on as ``None``:
            ``optimizer``, ``criterion``, ``train_loader`` -- stage 2;
            ``train_loader``, ``test_loader`` (used for validation) -- stage 3;
            ``ckpt_file`` -- stage 4 (weights are loaded only if truthy).

    Returns:
        None. Progress and results are reported through ``print`` and the
        plotting helpers in ``utils``.
    """
    if basic_sanity_check:
        validate_dataset()
        sanity_check(model)
        print("Set basic_sanity_check to False to proceed")
        return

    if find_max_lr:
        utils.find_lr(
            model,
            kwargs.get('optimizer'),
            kwargs.get('criterion'),
            kwargs.get('train_loader'),
        )
        print("Set find_max_lr to False to proceed further")
        return

    if train:
        _train_model(model, kwargs.get('train_loader'), kwargs.get('test_loader'))
    else:
        _evaluate_model(model, kwargs.get('ckpt_file'))


def _train_model(model, train_loader, val_loader):
    """Fit ``model`` with Lightning, resuming from ``cfg.CHECKPOINT_FILE`` when ``cfg.LOAD_MODEL`` is set."""
    trainer = pl.Trainer(
        precision=16,  # 16-bit (mixed) precision training
        max_epochs=cfg.NUM_EPOCHS,
        accelerator='gpu',
    )
    # ckpt_path makes Lightning resume training from that checkpoint.
    fit_kwargs = {'ckpt_path': cfg.CHECKPOINT_FILE} if cfg.LOAD_MODEL else {}
    trainer.fit(model, train_loader, val_loader, **fit_kwargs)


def _evaluate_model(model, ckpt_file=None):
    """Optionally load ``ckpt_file`` into ``model``, then plot samples and print mAP on the test split."""
    if ckpt_file:
        checkpoint = utils.load_model_from_checkpoint(cfg.DEVICE, file_name=ckpt_file)
        # strict=False tolerates key mismatches; surface them instead of
        # discarding the result so a partially-loaded model is noticed.
        load_result = model.load_state_dict(checkpoint['model'], strict=False)
        if load_result.missing_keys or load_result.unexpected_keys:
            print(
                f"Warning: checkpoint '{ckpt_file}' did not match the model exactly "
                f"(missing keys: {load_result.missing_keys}, "
                f"unexpected keys: {load_result.unexpected_keys})"
            )

    model.to(cfg.DEVICE)
    model.eval()

    # Grid sizes of the three output scales (input downsampled by 32, 16, 8).
    # The same list builds the dataset targets AND scales the anchors, so
    # the two can never disagree.
    grid_sizes = [cfg.IMAGE_SIZE // 32, cfg.IMAGE_SIZE // 16, cfg.IMAGE_SIZE // 8]
    eval_loader = _build_eval_loader(grid_sizes)
    scaled_anchors = _scale_anchors(cfg.ANCHORS, grid_sizes).to(cfg.DEVICE)

    # -- Printing samples
    # NOTE(review): 0.5 / 0.6 are thresholds forwarded positionally to
    # utils.plot_examples (presumably confidence and NMS IoU -- confirm).
    utils.plot_examples(model, eval_loader, 0.5, 0.6, scaled_anchors)

    # -- Printing MAP
    pred_boxes, true_boxes = utils.get_evaluation_bboxes(
        eval_loader,
        model,
        iou_threshold=cfg.NMS_IOU_THRESH,
        anchors=cfg.ANCHORS,
        threshold=cfg.CONF_THRESHOLD,
    )
    mapval = utils.mean_average_precision(
        pred_boxes,
        true_boxes,
        iou_threshold=cfg.MAP_IOU_THRESH,
        box_format="midpoint",
        num_classes=cfg.NUM_CLASSES,
    )
    # float() accepts a 0-d tensor as well as a plain Python number.
    print(f"MAP: {float(mapval)}")


def _build_eval_loader(grid_sizes):
    """Build a DataLoader over ``cfg.DATASET + "/test.csv"`` whose targets use ``grid_sizes``.

    Side effect: sets ``cfg.IMG_DIR`` and ``cfg.LABEL_DIR`` to the dataset's
    ``/images/`` and ``/labels/`` subdirectories.
    """
    cfg.IMG_DIR = cfg.DATASET + "/images/"
    cfg.LABEL_DIR = cfg.DATASET + "/labels/"
    eval_dataset = YOLODataset(
        cfg.DATASET + "/test.csv",
        transform=cfg.test_transforms,
        S=grid_sizes,
        img_dir=cfg.IMG_DIR,
        label_dir=cfg.LABEL_DIR,
        anchors=cfg.ANCHORS,
        mosaic=False,
    )
    return DataLoader(
        dataset=eval_dataset,
        batch_size=cfg.BATCH_SIZE,
        num_workers=cfg.NUM_WORKERS,
        pin_memory=cfg.PIN_MEMORY,
        # NOTE(review): shuffling presumably makes the plotted samples vary
        # between runs; confirm the mAP helpers do not depend on order.
        shuffle=True,
        drop_last=False,
    )


def _scale_anchors(anchors, grid_sizes):
    """Multiply each scale's anchors by that scale's grid size.

    The ``repeat(1, 3, 2)`` below produces a ``(len(grid_sizes), 3, 2)``
    tensor, so ``anchors`` must broadcast against that shape (one row per
    scale, 3 anchors per scale, 2 values per anchor).
    """
    grid = torch.tensor(grid_sizes).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
    return torch.tensor(anchors) * grid