Spaces:
Sleeping
Sleeping
File size: 7,008 Bytes
6e32a75 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import logging
import os
from abc import ABC, abstractmethod

import cv2
import numpy as np
import pandas as pd
import spacy
import torch
from tqdm import tqdm

from modules.utils import generate_heatmap
class BaseTester(ABC):
    """Shared evaluation setup: logging, device placement, output directory and checkpoint loading.

    Subclasses must implement :meth:`test` and :meth:`plot`; inheriting ``ABC`` makes the
    ``@abstractmethod`` markers actually enforced at instantiation time.
    """

    def __init__(self, model, criterion, metric_ftns, args):
        """Configure logging, move ``model`` to the chosen device and load weights from ``args.load``.

        Args:
            model: the ``torch.nn.Module`` to evaluate.
            criterion: loss callable; stored on the instance for subclasses.
            metric_ftns: metric callable; stored on the instance for subclasses.
            args: config namespace; this method reads ``n_gpu``, ``epochs``, ``save_dir`` and ``load``.
        """
        self.args = args

        logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                            datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO)
        self.logger = logging.getLogger(__name__)

        # setup GPU device if available, move model into configured device
        self.device, device_ids = self._prepare_device(args.n_gpu)
        self.model = model.to(self.device)
        if len(device_ids) > 1:
            self.model = torch.nn.DataParallel(model, device_ids=device_ids)

        self.criterion = criterion
        self.metric_ftns = metric_ftns

        self.epochs = self.args.epochs
        self.save_dir = self.args.save_dir
        # exist_ok avoids the race between an exists() check and makedirs()
        os.makedirs(self.save_dir, exist_ok=True)

        self._load_checkpoint(args.load)

    @abstractmethod
    def test(self):
        """Evaluate the model on the test data."""
        raise NotImplementedError

    @abstractmethod
    def plot(self):
        """Produce visualisations for the test data."""
        raise NotImplementedError

    def _prepare_device(self, n_gpu_use):
        """Choose the torch device and GPU ids, clamping the request to the GPUs actually present.

        Args:
            n_gpu_use: number of GPUs requested.

        Returns:
            ``(device, list_ids)``: ``cuda:0`` when at least one GPU will be used, otherwise
            ``cpu``; and ``list(range(n))`` for the ``n`` GPUs that will be used.
        """
        n_gpu = torch.cuda.device_count()
        if n_gpu_use > 0 and n_gpu == 0:
            self.logger.warning(
                "Warning: There\'s no GPU available on this machine, "
                "training will be performed on CPU.")
            n_gpu_use = 0
        if n_gpu_use > n_gpu:
            self.logger.warning(
                "Warning: The number of GPU\'s configured to use is {}, but only {} are available "
                "on this machine.".format(n_gpu_use, n_gpu))
            n_gpu_use = n_gpu
        device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
        list_ids = list(range(n_gpu_use))
        return device, list_ids

    def _load_checkpoint(self, load_path):
        """Load weights from ``load_path`` into ``self.model``.

        Accepts either a bare state dict or a dict that stores it under ``'state_dict'``.
        Tensors are mapped onto ``self.device`` so a checkpoint saved on a GPU can still be
        loaded on a CPU-only machine.
        """
        load_path = str(load_path)
        self.logger.info("Loading checkpoint: {} ...".format(load_path))
        checkpoint = torch.load(load_path, map_location=self.device)
        if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
            checkpoint = checkpoint["state_dict"]
        self.model.load_state_dict(checkpoint)
class Tester(BaseTester):
    """Generates reports on the test split, scores them, and optionally plots attention heatmaps."""

    def __init__(self, model, criterion, metric_ftns, args, test_dataloader):
        super(Tester, self).__init__(model, criterion, metric_ftns, args)
        self.test_dataloader = test_dataloader

    def _unwrapped_model(self):
        """Return the wrapped module when ``self.model`` is a ``torch.nn.DataParallel``, else the model.

        ``DataParallel`` does not forward custom attributes such as ``tokenizer`` or
        ``encoder_decoder``, so they must be read from ``.module`` on multi-GPU runs.
        """
        if isinstance(self.model, torch.nn.DataParallel):
            return self.model.module
        return self.model

    def _sample(self, batch):
        """Move one test batch to ``self.device`` and run the model in ``'sample'`` mode.

        Expects the batch layout consumed by :meth:`test`:
        ``(images_id, images, reports_ids, reports_masks, align_ids, align_masks)``.

        Returns:
            ``(output, images, reports_ids)`` with the inputs already on ``self.device``.
        """
        _, images, reports_ids, _, align_ids, align_masks = batch
        images = images.to(self.device)
        reports_ids = reports_ids.to(self.device)
        align_ids = align_ids.to(self.device)
        align_masks = align_masks.to(self.device)
        output = self.model(reports_ids, align_ids, align_masks, images, mode='sample')
        # NOTE(review): an earlier plot() unpacked ``output, _``; accept a tuple for safety
        if isinstance(output, (tuple, list)):
            output = output[0]
        return output, images, reports_ids

    @staticmethod
    def _heatmap_path(root, batch_idx, layer_idx, item_idx, label):
        """Create ``root/<batch>/layer_<layer>`` if needed and return the PNG path for one item.

        Path separators in ``label`` are replaced with ``_`` so a token such as ``and/or``
        cannot point the file into a non-existent sub-directory.
        """
        layer_dir = os.path.join(root, "{:04d}".format(batch_idx), "layer_{}".format(layer_idx))
        os.makedirs(layer_dir, exist_ok=True)
        safe_label = str(label).replace(os.sep, "_")
        if os.altsep:
            safe_label = safe_label.replace(os.altsep, "_")
        return os.path.join(layer_dir, "{:04d}_{}.png".format(item_idx, safe_label))

    def _write_heatmap(self, path, heatmap):
        """Write ``heatmap`` to ``path``, logging instead of silently ignoring a failed write."""
        if not cv2.imwrite(path, heatmap):
            self.logger.warning("Failed to write heatmap to {}".format(path))

    def test(self):
        """Generate reports for the whole test set, score them, and save them as CSV.

        Writes ``res.csv`` (generated reports) and ``gts.csv`` (references) into
        ``save_dir``, one report per row.

        Returns:
            dict mapping ``'test_' + metric_name`` to the values returned by ``metric_ftns``.
        """
        self.logger.info('Start to evaluate in the test set.')
        self.model.eval()
        tokenizer = self._unwrapped_model().tokenizer
        test_gts, test_res = [], []
        with torch.no_grad():
            for batch in self.test_dataloader:
                output, _, reports_ids = self._sample(batch)
                test_res.extend(tokenizer.decode_batch(output.cpu().numpy()))
                # the first reference token is skipped (presumably a start token — confirm)
                test_gts.extend(tokenizer.decode_batch(reports_ids[:, 1:].cpu().numpy()))
            test_met = self.metric_ftns({i: [gt] for i, gt in enumerate(test_gts)},
                                        {i: [res] for i, res in enumerate(test_res)})
        log = {'test_' + k: v for k, v in test_met.items()}
        print(log)
        pd.DataFrame(test_res).to_csv(os.path.join(self.save_dir, "res.csv"), index=False, header=False)
        pd.DataFrame(test_gts).to_csv(os.path.join(self.save_dir, "gts.csv"), index=False, header=False)
        return log

    def plot(self):
        """Save per-word and per-entity attention heatmaps for every test image.

        Outputs go to ``save_dir/attentions/<batch>/layer_<l>/<word_idx>_<word>.png`` and
        ``save_dir/attentions_entities/<batch>/layer_<l>/<entity_idx>_<entity>.png``.

        Raises:
            ValueError: if ``args.batch_size`` or ``args.beam_size`` is not 1.
        """
        if self.args.batch_size != 1 or self.args.beam_size != 1:
            raise ValueError(
                "plot() requires batch_size == 1 and beam_size == 1 "
                "(got batch_size={}, beam_size={})".format(self.args.batch_size, self.args.beam_size))
        self.logger.info('Start to plot attention weights in the test set.')
        word_root = os.path.join(self.save_dir, "attentions")
        entity_root = os.path.join(self.save_dir, "attentions_entities")
        os.makedirs(word_root, exist_ok=True)
        os.makedirs(entity_root, exist_ok=True)
        ner = spacy.load("en_core_sci_sm")
        # presumably the ImageNet normalisation used by the data transforms; shaped (C, 1, 1) to broadcast
        mean = torch.tensor((0.485, 0.456, 0.406))[:, None, None]
        std = torch.tensor((0.229, 0.224, 0.225))[:, None, None]
        model = self._unwrapped_model()
        self.model.eval()
        with torch.no_grad():
            for batch_idx, batch in enumerate(tqdm(self.test_dataloader)):
                output, images, _ = self._sample(batch)
                # undo normalisation and convert the first (only) image to an int array in [0, 255]
                image = torch.clamp((images[0].cpu() * std + mean) * 255, 0, 255).int().numpy()
                report = model.tokenizer.decode_batch(output.cpu().numpy())[0].split()
                # char2word[c] = index of the word owning character c of " ".join(report);
                # each separating space is attributed to the word before it
                char2word = [word_idx for word_idx, word in enumerate(report) for _ in range(len(word) + 1)][:-1]
                # the final recorded step is dropped so there is one entry per word (asserted below)
                attention_weights = model.encoder_decoder.attention_weights[:-1]
                assert len(attention_weights) == len(report)

                for word_idx, (attns, word) in enumerate(zip(attention_weights, report)):
                    for layer_idx, attn in enumerate(attns):
                        heatmap = generate_heatmap(image, attn.mean(1).squeeze())
                        path = self._heatmap_path(word_root, batch_idx, layer_idx, word_idx, word)
                        self._write_heatmap(path, heatmap)

                for ne_idx, ne in enumerate(ner(" ".join(report)).ents):
                    # spaCy's end_char is exclusive: index the entity's last character, otherwise an
                    # entity that ends the report runs one past the end of char2word
                    first_word = char2word[ne.start_char]
                    last_word = char2word[ne.end_char - 1]
                    for layer_idx in range(len(attention_weights[0])):
                        attn = np.concatenate(
                            [attns[layer_idx] for attns in attention_weights[first_word:last_word + 1]], axis=2)
                        heatmap = generate_heatmap(image, attn.mean(1).mean(1).squeeze())
                        path = self._heatmap_path(entity_root, batch_idx, layer_idx, ne_idx, ne)
                        self._write_heatmap(path, heatmap)
|