piyushgrover commited on
Commit
91788b1
·
1 Parent(s): d246538

deleted main.py

Browse files
Files changed (2) hide show
  1. app.py +0 -7
  2. main.py +0 -86
app.py CHANGED
@@ -2,21 +2,14 @@ import gradio as gr
2
  from typing import List
3
  import cv2
4
  import torch
5
- from torchvision import transforms
6
  import numpy as np
7
- from PIL import Image
8
- from pytorch_grad_cam import GradCAM
9
- from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
10
  from pytorch_grad_cam.utils.image import show_cam_on_image
11
- import io
12
  from models import YoloV3Lightning
13
  from utils import load_model_from_checkpoint
14
  import utils
15
  import config as cfg
16
  import matplotlib.pyplot as plt
17
  import matplotlib.patches as patches
18
- from dataset import YOLODataset
19
- from torch.utils.data import Dataset, DataLoader
20
  from grad_cam import YoloGradCAM
21
 
22
  device = torch.device('cpu')
 
2
  from typing import List
3
  import cv2
4
  import torch
 
5
  import numpy as np
 
 
 
6
  from pytorch_grad_cam.utils.image import show_cam_on_image
 
7
  from models import YoloV3Lightning
8
  from utils import load_model_from_checkpoint
9
  import utils
10
  import config as cfg
11
  import matplotlib.pyplot as plt
12
  import matplotlib.patches as patches
 
 
13
  from grad_cam import YoloGradCAM
14
 
15
  device = torch.device('cpu')
main.py DELETED
@@ -1,86 +0,0 @@
1
- from dataset import *
2
- from models.YoloV3Lightning import *
3
- import utils
4
-
5
- def init(model, basic_sanity_check=True, find_max_lr=True, train=True, **kwargs):
6
- if basic_sanity_check:
7
- validate_dataset()
8
- sanity_check(model)
9
- print("Set basic_sanity_check to False to proceed")
10
- else:
11
- if find_max_lr:
12
- optimizer = kwargs.get('optimizer')
13
- criterion = kwargs.get('criterion')
14
- train_loader = kwargs.get('train_loader')
15
- utils.find_lr(model, optimizer, criterion, train_loader)
16
- print("Set find_max_lr to False to proceed further")
17
- else:
18
-
19
- train_loader = kwargs.get('train_loader')
20
- val_loader = kwargs.get('test_loader')
21
-
22
- if train:
23
- trainer = pl.Trainer(
24
- precision=16,
25
- max_epochs=cfg.NUM_EPOCHS,
26
- accelerator='gpu'
27
- )
28
-
29
- cargs = {}
30
- if cfg.LOAD_MODEL:
31
- cargs = dict(ckpt_path=cfg.CHECKPOINT_FILE)
32
-
33
- trainer.fit(model, train_loader, val_loader, **cargs)
34
- else:
35
- ckpt_file = kwargs.get('ckpt_file')
36
- if ckpt_file:
37
- checkpoint = utils.load_model_from_checkpoint(cfg.DEVICE, file_name=ckpt_file)
38
- model.load_state_dict(checkpoint['model'], strict=False)
39
-
40
- #-- Printing samples
41
- model.to(cfg.DEVICE)
42
- model.eval()
43
- cfg.IMG_DIR = cfg.DATASET + "/images/"
44
- cfg.LABEL_DIR = cfg.DATASET + "/labels/"
45
- eval_dataset = YOLODataset(
46
- cfg.DATASET + "/test.csv",
47
- transform=cfg.test_transforms,
48
- S=[cfg.IMAGE_SIZE // 32, cfg.IMAGE_SIZE // 16, cfg.IMAGE_SIZE // 8],
49
- img_dir=cfg.IMG_DIR,
50
- label_dir=cfg.LABEL_DIR,
51
- anchors=cfg.ANCHORS,
52
- mosaic=False
53
- )
54
- eval_loader = DataLoader(
55
- dataset=eval_dataset,
56
- batch_size=cfg.BATCH_SIZE,
57
- num_workers=cfg.NUM_WORKERS,
58
- pin_memory=cfg.PIN_MEMORY,
59
- shuffle=True,
60
- drop_last=False,
61
- )
62
-
63
- scaled_anchors = (
64
- torch.tensor(cfg.ANCHORS)
65
- * torch.tensor(cfg.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
66
- )
67
- scaled_anchors = scaled_anchors.to(cfg.DEVICE)
68
-
69
- utils.plot_examples(model, eval_loader, 0.5, 0.6, scaled_anchors)
70
-
71
- # -- Printing MAP
72
- pred_boxes, true_boxes = utils.get_evaluation_bboxes(
73
- eval_loader,
74
- model,
75
- iou_threshold=cfg.NMS_IOU_THRESH,
76
- anchors=cfg.ANCHORS,
77
- threshold=cfg.CONF_THRESHOLD,
78
- )
79
- mapval = utils.mean_average_precision(
80
- pred_boxes,
81
- true_boxes,
82
- iou_threshold=cfg.MAP_IOU_THRESH,
83
- box_format="midpoint",
84
- num_classes=cfg.NUM_CLASSES,
85
- )
86
- print(f"MAP: {mapval.item()}")