File size: 3,168 Bytes
5bfab10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from dataset import *
from models.YoloV3Lightning import *
import utils

def init(model, basic_sanity_check=True, find_max_lr=True, train=True, **kwargs):
    """Run one stage of the YOLOv3 workflow, selected by the boolean flags.

    The flags are checked in order and only the first enabled stage runs:

    1. ``basic_sanity_check`` -- call ``validate_dataset()`` and
       ``sanity_check(model)``, print a hint, and return.
    2. ``find_max_lr`` -- call ``utils.find_lr`` with the ``optimizer``,
       ``criterion`` and ``train_loader`` taken from ``kwargs``, print a hint,
       and return.
    3. Otherwise, either fit the model with Lightning (``train=True``) or load
       weights from ``kwargs['ckpt_file']`` (``train=False``). Then plot sample
       predictions and print the mean average precision computed on
       ``cfg.DATASET + "/test.csv"``.

    Args:
        model: The model to check, train and/or evaluate.
        basic_sanity_check: If True, run only the dataset/model sanity checks.
        find_max_lr: If True, run only the learning-rate finder.
        train: If True, fit with ``pl.Trainer`` before evaluating. If False,
            optionally load a checkpoint and evaluate.
        **kwargs: Optional inputs read by the stages above:
            ``optimizer``, ``criterion`` -- forwarded to ``utils.find_lr``;
            ``train_loader`` -- used by the LR finder and by training;
            ``test_loader`` -- validation loader passed to ``trainer.fit``;
            ``ckpt_file`` -- checkpoint loaded when ``train=False``;
            ``precision`` (default ``16``) and ``accelerator``
            (default ``'gpu'``) -- forwarded to ``pl.Trainer``.

    Returns:
        None. Results are printed and plotted.
    """
    if basic_sanity_check:
        validate_dataset()
        sanity_check(model)
        print("Set basic_sanity_check to False to proceed")
        return

    if find_max_lr:
        utils.find_lr(
            model,
            kwargs.get('optimizer'),
            kwargs.get('criterion'),
            kwargs.get('train_loader'),
        )
        print("Set find_max_lr to False to proceed further")
        return

    if train:
        _fit(
            model,
            kwargs.get('train_loader'),
            kwargs.get('test_loader'),
            precision=kwargs.get('precision', 16),
            accelerator=kwargs.get('accelerator', 'gpu'),
        )
    else:
        _load_weights(model, kwargs.get('ckpt_file'))

    _report(model)


def _fit(model, train_loader, val_loader, precision=16, accelerator='gpu'):
    """Fit ``model`` for ``cfg.NUM_EPOCHS`` epochs with a Lightning Trainer.

    When ``cfg.LOAD_MODEL`` is set, training resumes from
    ``cfg.CHECKPOINT_FILE`` via ``ckpt_path``.
    """
    trainer = pl.Trainer(
        precision=precision,
        max_epochs=cfg.NUM_EPOCHS,
        accelerator=accelerator,
    )

    fit_kwargs = {}
    if cfg.LOAD_MODEL:
        fit_kwargs['ckpt_path'] = cfg.CHECKPOINT_FILE

    trainer.fit(model, train_loader, val_loader, **fit_kwargs)


def _load_weights(model, ckpt_file):
    """Load the ``'model'`` state dict stored in ``ckpt_file`` into ``model``."""
    if not ckpt_file:
        # Evaluation still runs afterwards, on whatever weights the model
        # currently holds; make that visible instead of silently proceeding.
        print("No ckpt_file given; evaluating the model with its current weights")
        return
    checkpoint = utils.load_model_from_checkpoint(cfg.DEVICE, file_name=ckpt_file)
    # strict=False: do not fail on keys that exist in only one of the
    # checkpoint state dict and the model.
    model.load_state_dict(checkpoint['model'], strict=False)


def _build_eval_loader():
    """Build a shuffled DataLoader over ``cfg.DATASET + "/test.csv"``.

    Side effect: sets ``cfg.IMG_DIR`` and ``cfg.LABEL_DIR`` under ``cfg.DATASET``.
    """
    cfg.IMG_DIR = cfg.DATASET + "/images/"
    cfg.LABEL_DIR = cfg.DATASET + "/labels/"
    eval_dataset = YOLODataset(
        cfg.DATASET + "/test.csv",
        transform=cfg.test_transforms,
        # Grid sizes for the three detection strides (32, 16, 8).
        S=[cfg.IMAGE_SIZE // 32, cfg.IMAGE_SIZE // 16, cfg.IMAGE_SIZE // 8],
        img_dir=cfg.IMG_DIR,
        label_dir=cfg.LABEL_DIR,
        anchors=cfg.ANCHORS,
        mosaic=False,
    )
    return DataLoader(
        dataset=eval_dataset,
        batch_size=cfg.BATCH_SIZE,
        num_workers=cfg.NUM_WORKERS,
        pin_memory=cfg.PIN_MEMORY,
        shuffle=True,  # so plotted examples differ between runs
        drop_last=False,
    )


def _scaled_anchors():
    """Return ``cfg.ANCHORS`` multiplied by each scale's grid size, on ``cfg.DEVICE``.

    The multiplier is built with shape ``(len(cfg.S), 3, 2)``.
    Assumes ``cfg.ANCHORS`` broadcasts against that shape
    (presumably 3 anchors per scale, each a (w, h) pair) -- TODO confirm.
    """
    grid_sizes = torch.tensor(cfg.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
    return (torch.tensor(cfg.ANCHORS) * grid_sizes).to(cfg.DEVICE)


def _report(model):
    """Plot sample predictions and print the mean average precision of ``model``."""
    model.to(cfg.DEVICE)
    model.eval()

    eval_loader = _build_eval_loader()

    # Positional thresholds 0.5 / 0.6 are kept from the original call;
    # their exact meaning is defined by utils.plot_examples.
    utils.plot_examples(model, eval_loader, 0.5, 0.6, _scaled_anchors())

    pred_boxes, true_boxes = utils.get_evaluation_bboxes(
        eval_loader,
        model,
        iou_threshold=cfg.NMS_IOU_THRESH,
        anchors=cfg.ANCHORS,  # unscaled anchors, as in the original call
        threshold=cfg.CONF_THRESHOLD,
    )
    mapval = utils.mean_average_precision(
        pred_boxes,
        true_boxes,
        iou_threshold=cfg.MAP_IOU_THRESH,
        box_format="midpoint",
        num_classes=cfg.NUM_CLASSES,
    )
    print(f"MAP: {mapval.item()}")