import logging
import os
from abc import abstractmethod

import cv2
import numpy as np
import pandas as pd
import spacy
import torch
from tqdm import tqdm

from modules.utils import generate_heatmap


class BaseTester(object):
    """Base class for model evaluation.

    Handles device selection, model placement (optionally wrapped in
    DataParallel), checkpoint loading, and output-directory creation.
    Subclasses implement ``test`` and ``plot``.
    """

    def __init__(self, model, criterion, metric_ftns, args):
        """
        Args:
            model: the network to evaluate; moved to the selected device.
            criterion: loss function (stored; not used in this base class).
            metric_ftns: callable computing evaluation metrics from
                ``{id: [ground_truth]}`` and ``{id: [result]}`` dicts.
            args: namespace providing ``n_gpu``, ``epochs``, ``save_dir``
                and ``load`` (checkpoint path).
        """
        self.args = args

        logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                            datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO)
        self.logger = logging.getLogger(__name__)

        # setup GPU device if available, move model into configured device
        self.device, device_ids = self._prepare_device(args.n_gpu)
        self.model = model.to(self.device)
        if len(device_ids) > 1:
            # .to() moves nn.Module parameters in place, so wrapping the
            # original `model` reference is equivalent to wrapping self.model.
            self.model = torch.nn.DataParallel(model, device_ids=device_ids)

        self.criterion = criterion
        self.metric_ftns = metric_ftns

        self.epochs = self.args.epochs
        self.save_dir = self.args.save_dir
        # exist_ok=True avoids the check-then-create race of the previous
        # `if not os.path.exists(...)` pattern.
        os.makedirs(self.save_dir, exist_ok=True)

        self._load_checkpoint(args.load)

    @abstractmethod
    def test(self):
        raise NotImplementedError

    @abstractmethod
    def plot(self):
        raise NotImplementedError

    def _prepare_device(self, n_gpu_use):
        """Select the torch device and usable GPU ids.

        Args:
            n_gpu_use: number of GPUs requested via ``args.n_gpu``.

        Returns:
            (device, list_ids): the primary ``torch.device`` and the list of
            GPU indices to use (empty on CPU).
        """
        n_gpu = torch.cuda.device_count()
        if n_gpu_use > 0 and n_gpu == 0:
            self.logger.warning(
                "Warning: There's no GPU available on this machine, "
                "training will be performed on CPU.")
            n_gpu_use = 0
        if n_gpu_use > n_gpu:
            self.logger.warning(
                "Warning: The number of GPU's configured to use is {}, but only {} are available "
                "on this machine.".format(n_gpu_use, n_gpu))
            n_gpu_use = n_gpu
        device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
        list_ids = list(range(n_gpu_use))
        return device, list_ids

    def _load_checkpoint(self, load_path):
        """Load model weights from ``load_path`` into ``self.model``.

        NOTE(review): the checkpoint is assumed to be a bare state_dict
        (not wrapped in e.g. ``{'state_dict': ...}``) — confirm against the
        training-side save code.
        """
        load_path = str(load_path)
        self.logger.info("Loading checkpoint: {} ...".format(load_path))
        # map_location lets GPU-saved checkpoints load on a CPU-only machine,
        # matching the CPU fallback performed in _prepare_device.
        checkpoint = torch.load(load_path, map_location=self.device)
        self.model.load_state_dict(checkpoint)


class Tester(BaseTester):
    """Runs report generation over a test dataloader and plots attention maps."""

    def __init__(self, model, criterion, metric_ftns, args, test_dataloader):
        super(Tester, self).__init__(model, criterion, metric_ftns, args)
        self.test_dataloader = test_dataloader

    def test(self):
        """Generate reports for the test set, score them, and dump CSVs.

        Returns:
            dict mapping ``'test_<metric>'`` to the metric value.
        """
        self.logger.info('Start to evaluate in the test set.')
        self.model.eval()
        log = dict()
        with torch.no_grad():
            test_gts, test_res = [], []
            for batch_idx, (images_id, images, reports_ids, reports_masks, align_ids,
                            align_masks) in enumerate(self.test_dataloader):
                images, reports_ids, reports_masks, align_ids, align_masks = \
                    images.to(self.device), reports_ids.to(self.device), \
                    reports_masks.to(self.device), align_ids.to(self.device), \
                    align_masks.to(self.device)
                output = self.model(reports_ids, align_ids, align_masks, images, mode='sample')
                reports = self.model.tokenizer.decode_batch(output.cpu().numpy())
                # Drop the leading token (presumably BOS) from the references.
                ground_truths = self.model.tokenizer.decode_batch(reports_ids[:, 1:].cpu().numpy())
                test_res.extend(reports)
                test_gts.extend(ground_truths)
            test_met = self.metric_ftns({i: [gt] for i, gt in enumerate(test_gts)},
                                        {i: [re] for i, re in enumerate(test_res)})
            log.update(**{'test_' + k: v for k, v in test_met.items()})
            # Use the configured logger instead of a bare print for consistency.
            self.logger.info(log)
            test_res, test_gts = pd.DataFrame(test_res), pd.DataFrame(test_gts)
            test_res.to_csv(os.path.join(self.save_dir, "res.csv"), index=False, header=False)
            test_gts.to_csv(os.path.join(self.save_dir, "gts.csv"), index=False, header=False)
        return log

    def plot(self):
        """Render per-word and per-entity attention heatmaps over test images.

        Requires batch_size == 1 and beam_size == 1 so that attention weights
        align one-to-one with generated words.

        NOTE(review): this method unpacks 4-tuples from the dataloader and
        calls ``self.model(images, mode='sample')`` returning two values,
        whereas ``test`` unpacks 6-tuples and gets a single return — these
        appear to target different model/dataloader configurations; verify
        before using both paths with the same setup.
        """
        assert self.args.batch_size == 1 and self.args.beam_size == 1
        self.logger.info('Start to plot attention weights in the test set.')
        os.makedirs(os.path.join(self.save_dir, "attentions"), exist_ok=True)
        os.makedirs(os.path.join(self.save_dir, "attentions_entities"), exist_ok=True)
        ner = spacy.load("en_core_sci_sm")
        # ImageNet normalization constants, reshaped for (C, H, W) broadcast,
        # used to un-normalize the image back to displayable pixel values.
        mean = torch.tensor((0.485, 0.456, 0.406))
        std = torch.tensor((0.229, 0.224, 0.225))
        mean = mean[:, None, None]
        std = std[:, None, None]

        self.model.eval()
        with torch.no_grad():
            for batch_idx, (images_id, images, reports_ids, reports_masks) in \
                    tqdm(enumerate(self.test_dataloader)):
                images, reports_ids, reports_masks = images.to(self.device), reports_ids.to(
                    self.device), reports_masks.to(self.device)
                output, _ = self.model(images, mode='sample')
                image = torch.clamp((images[0].cpu() * std + mean) * 255, 0, 255).int().cpu().numpy()
                report = self.model.tokenizer.decode_batch(output.cpu().numpy())[0].split()
                # Map each character position of the joined report string to
                # its word index (len(word) chars + 1 for the joining space).
                char2word = [idx for word_idx, word in enumerate(report)
                             for idx in [word_idx] * (len(word) + 1)][:-1]
                # Drop the final entry (presumably for the EOS step) so the
                # remaining weights align with the decoded words.
                attention_weights = self.model.encoder_decoder.attention_weights[:-1]
                assert len(attention_weights) == len(report)
                # One heatmap per generated word, per decoder layer.
                for word_idx, (attns, word) in enumerate(zip(attention_weights, report)):
                    for layer_idx, attn in enumerate(attns):
                        os.makedirs(os.path.join(self.save_dir, "attentions",
                                                 "{:04d}".format(batch_idx),
                                                 "layer_{}".format(layer_idx)), exist_ok=True)
                        # attn axis 1 is assumed to be the head dimension — TODO confirm.
                        heatmap = generate_heatmap(image, attn.mean(1).squeeze())
                        cv2.imwrite(os.path.join(self.save_dir, "attentions",
                                                 "{:04d}".format(batch_idx),
                                                 "layer_{}".format(layer_idx),
                                                 "{:04d}_{}.png".format(word_idx, word)),
                                    heatmap)
                # One heatmap per named entity, averaging attention across the
                # words the entity spans.
                for ne_idx, ne in enumerate(ner(" ".join(report)).ents):
                    for layer_idx in range(len(attention_weights[0])):
                        os.makedirs(os.path.join(self.save_dir, "attentions_entities",
                                                 "{:04d}".format(batch_idx),
                                                 "layer_{}".format(layer_idx)), exist_ok=True)
                        attn = [attns[layer_idx] for attns in
                                attention_weights[char2word[ne.start_char]:char2word[ne.end_char] + 1]]
                        attn = np.concatenate(attn, axis=2)
                        heatmap = generate_heatmap(image, attn.mean(1).mean(1).squeeze())
                        cv2.imwrite(os.path.join(self.save_dir, "attentions_entities",
                                                 "{:04d}".format(batch_idx),
                                                 "layer_{}".format(layer_idx),
                                                 "{:04d}_{}.png".format(ne_idx, ne)),
                                    heatmap)