yolov3 / main.py
piyushgrover's picture
added space app files
5bfab10
raw
history blame
3.17 kB
from dataset import *
from models.YoloV3Lightning import *
import utils
def init(model, basic_sanity_check=True, find_max_lr=True, train=True, **kwargs):
    """Run exactly one stage of the YOLOv3 workflow, selected by the flags.

    The stages are checked in order and the first enabled one runs:

    1. ``basic_sanity_check=True``: validate the dataset and sanity-check
       ``model``, then stop. The caller re-runs with the flag set to False,
       as the printed hint says.
    2. ``find_max_lr=True``: run ``utils.find_lr``, then stop. The caller
       re-runs with the flag set to False, as the printed hint says.
    3. ``train=True``: fit ``model`` with a PyTorch Lightning ``Trainer``.
    4. Otherwise: optionally load weights from a checkpoint, call
       ``utils.plot_examples`` on the test split and print its mAP.

    Args:
        model: The YOLOv3 model (a Lightning module when training).
        basic_sanity_check: Run stage 1 only.
        find_max_lr: Run stage 2 only. Ignored when stage 1 runs.
        train: Choose training (stage 3) or evaluation (stage 4).
        **kwargs: Stage-specific inputs. All are optional and missing keys
            become ``None``.

            - ``optimizer``, ``criterion``, ``train_loader``: LR finder.
            - ``train_loader``, ``test_loader``: training.
              ``test_loader`` is the validation loader.
            - ``precision`` (default ``16``) and ``accelerator``
              (default ``'gpu'``): forwarded to ``pl.Trainer``.
            - ``ckpt_file``: checkpoint to load before evaluation.

    Returns:
        None. Results are reported by printing and via ``utils.plot_examples``.
    """
    if basic_sanity_check:
        validate_dataset()
        sanity_check(model)
        print("Set basic_sanity_check to False to proceed")
        return

    if find_max_lr:
        utils.find_lr(
            model,
            kwargs.get('optimizer'),
            kwargs.get('criterion'),
            kwargs.get('train_loader'),
        )
        print("Set find_max_lr to False to proceed further")
        return

    if train:
        _train_model(
            model,
            kwargs.get('train_loader'),
            kwargs.get('test_loader'),
            precision=kwargs.get('precision', 16),
            accelerator=kwargs.get('accelerator', 'gpu'),
        )
    else:
        _evaluate_model(model, kwargs.get('ckpt_file'))


def _train_model(model, train_loader, val_loader, precision=16, accelerator='gpu'):
    """Fit ``model`` for ``cfg.NUM_EPOCHS`` epochs with a Lightning ``Trainer``.

    When ``cfg.LOAD_MODEL`` is set, ``cfg.CHECKPOINT_FILE`` is passed to
    ``trainer.fit`` as ``ckpt_path`` so that training restores from it.
    """
    trainer = pl.Trainer(
        precision=precision,
        max_epochs=cfg.NUM_EPOCHS,
        accelerator=accelerator,
    )
    fit_kwargs = {'ckpt_path': cfg.CHECKPOINT_FILE} if cfg.LOAD_MODEL else {}
    trainer.fit(model, train_loader, val_loader, **fit_kwargs)


def _build_eval_loader(grid_sizes):
    """Return a DataLoader over ``<cfg.DATASET>/test.csv`` with test transforms.

    Side effect: sets ``cfg.IMG_DIR`` and ``cfg.LABEL_DIR``. This is kept from
    the original code in case other modules read these attributes.
    """
    cfg.IMG_DIR = cfg.DATASET + "/images/"
    cfg.LABEL_DIR = cfg.DATASET + "/labels/"
    eval_dataset = YOLODataset(
        cfg.DATASET + "/test.csv",
        transform=cfg.test_transforms,
        S=grid_sizes,
        img_dir=cfg.IMG_DIR,
        label_dir=cfg.LABEL_DIR,
        anchors=cfg.ANCHORS,
        mosaic=False,
    )
    return DataLoader(
        dataset=eval_dataset,
        batch_size=cfg.BATCH_SIZE,
        num_workers=cfg.NUM_WORKERS,
        pin_memory=cfg.PIN_MEMORY,
        # Presumably shuffled so plot_examples shows varied samples -- confirm.
        shuffle=True,
        drop_last=False,
    )


def _evaluate_model(model, ckpt_file=None):
    """Optionally load ``ckpt_file``, plot sample predictions and print the mAP."""
    if ckpt_file:
        checkpoint = utils.load_model_from_checkpoint(cfg.DEVICE, file_name=ckpt_file)
        # strict=False: missing/unexpected keys in checkpoint['model'] are
        # ignored silently rather than raising.
        model.load_state_dict(checkpoint['model'], strict=False)

    model.to(cfg.DEVICE)
    model.eval()

    # One grid size per prediction scale (strides 32, 16, 8). The same list
    # builds the dataset targets and scales the anchors, so the two cannot
    # disagree. The original used the separate cfg.S for the anchors.
    grid_sizes = [cfg.IMAGE_SIZE // 32, cfg.IMAGE_SIZE // 16, cfg.IMAGE_SIZE // 8]
    eval_loader = _build_eval_loader(grid_sizes)

    # -- Printing samples
    # Multiply each scale's 3 (w, h) anchors by that scale's grid size. The
    # multiplier tensor has shape (num_scales, 3, 2). This assumes cfg.ANCHORS
    # is expressed as fractions of the image -- TODO confirm in config.
    scaled_anchors = (
        torch.tensor(cfg.ANCHORS)
        * torch.tensor(grid_sizes).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
    ).to(cfg.DEVICE)
    # 0.5 / 0.6 are positional thresholds; their meaning is defined by
    # utils.plot_examples.
    utils.plot_examples(model, eval_loader, 0.5, 0.6, scaled_anchors)

    # -- Printing MAP
    pred_boxes, true_boxes = utils.get_evaluation_bboxes(
        eval_loader,
        model,
        iou_threshold=cfg.NMS_IOU_THRESH,
        anchors=cfg.ANCHORS,
        threshold=cfg.CONF_THRESHOLD,
    )
    mapval = utils.mean_average_precision(
        pred_boxes,
        true_boxes,
        iou_threshold=cfg.MAP_IOU_THRESH,
        box_format="midpoint",
        num_classes=cfg.NUM_CLASSES,
    )
    print(f"MAP: {mapval.item()}")