import argparse
import os

import h5py
import numpy as np
import torch
from torchvision import transforms
from torchvision.datasets import ImageNet
from tqdm import tqdm

# Import saliency methods and models
from misc_functions import *
from ViT_explanation_generator import Baselines, LRP
from ViT_new import vit_base_patch16_224
from ViT_LRP import vit_base_patch16_224 as vit_LRP
from ViT_orig_LRP import vit_base_patch16_224 as vit_orig_LRP


def normalize(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Channel-wise normalize ``tensor`` in place: (x - mean) / std.

    NOTE: mutates ``tensor`` (uses ``sub_``/``div_``) and also returns it,
    matching ``torchvision.transforms.Normalize`` semantics but batched.
    Expects an NCHW layout (batch, channel, height, width).
    """
    dtype = tensor.dtype
    mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
    std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
    tensor.sub_(mean[None, :, None, None]).div_(std[None, :, None, None])
    return tensor


def str2bool(value):
    """Parse an explicit true/false CLI value.

    ``argparse`` with ``type=bool`` is broken: ``bool("False")`` is True
    because any non-empty string is truthy. This converter accepts the
    usual spellings instead.
    """
    if isinstance(value, bool):
        return value
    return value.lower() in ('yes', 'true', 't', '1')


def compute_saliency_and_save(args):
    """Run the selected explanation method over the validation loader.

    Appends, per image, the saliency map ('vis'), the un-normalized input
    image ('image') and the ground-truth label ('target') to resizable
    datasets in ``<args.method_dir>/results.hdf5``.

    Relies on module-level globals set up in ``__main__``:
    ``sample_loader``, ``device``, ``baselines``, ``lrp``, ``orig_lrp``.
    """
    first = True
    with h5py.File(os.path.join(args.method_dir, 'results.hdf5'), 'a') as f:
        data_cam = f.create_dataset('vis', (1, 1, 224, 224),
                                    maxshape=(None, 1, 224, 224),
                                    dtype=np.float32, compression="gzip")
        data_image = f.create_dataset('image', (1, 3, 224, 224),
                                      maxshape=(None, 3, 224, 224),
                                      dtype=np.float32, compression="gzip")
        data_target = f.create_dataset('target', (1,),
                                      maxshape=(None,),
                                      dtype=np.int32, compression="gzip")

        for batch_idx, (data, target) in enumerate(tqdm(sample_loader)):
            # The datasets were created with one dummy row, so the first
            # batch grows them by batch_size - 1, every later batch by
            # batch_size.
            grow = data.shape[0] - 1 if first else data.shape[0]
            first = False
            data_cam.resize(data_cam.shape[0] + grow, axis=0)
            data_image.resize(data_image.shape[0] + grow, axis=0)
            data_target.resize(data_target.shape[0] + grow, axis=0)

            # Store the raw (pre-normalization) image and its label.
            data_image[-data.shape[0]:] = data.data.cpu().numpy()
            data_target[-data.shape[0]:] = target.data.cpu().numpy()

            target = target.to(device)
            data = normalize(data)
            data = data.to(device)
            data.requires_grad_()

            # Explain the target class when requested; otherwise the
            # generators fall back to the predicted (top-1) class.
            index = target if args.vis_class == 'target' else None

            if args.method == 'rollout':
                Res = baselines.generate_rollout(
                    data, start_layer=1).reshape(data.shape[0], 1, 14, 14)
            elif args.method == 'lrp':
                Res = lrp.generate_LRP(
                    data, start_layer=1, index=index).reshape(data.shape[0], 1, 14, 14)
            elif args.method == 'transformer_attribution':
                Res = lrp.generate_LRP(
                    data, start_layer=1, method="grad",
                    index=index).reshape(data.shape[0], 1, 14, 14)
            elif args.method == 'full_lrp':
                # Full LRP already produces a pixel-level (224x224) map.
                Res = orig_lrp.generate_LRP(
                    data, method="full", index=index).reshape(data.shape[0], 1, 224, 224)
            elif args.method == 'lrp_last_layer':
                Res = orig_lrp.generate_LRP(
                    data, method="last_layer", is_ablation=args.is_ablation,
                    index=index).reshape(data.shape[0], 1, 14, 14)
            elif args.method == 'attn_last_layer':
                Res = lrp.generate_LRP(
                    data, method="last_layer_attn",
                    is_ablation=args.is_ablation).reshape(data.shape[0], 1, 14, 14)
            elif args.method == 'attn_gradcam':
                Res = baselines.generate_cam_attn(
                    data, index=index).reshape(data.shape[0], 1, 14, 14)
            else:
                # Previously an unknown method fell through to a confusing
                # NameError on Res; fail loudly instead.
                raise ValueError('unknown method: {}'.format(args.method))

            # Token-level maps (14x14 patch grid) are upsampled to the
            # 224x224 input resolution; 'full_lrp' is already pixel-level.
            if args.method != 'full_lrp' and args.method != 'input_grads':
                Res = torch.nn.functional.interpolate(
                    Res, scale_factor=16, mode='bilinear').to(device)
            # Min-max normalize to [0, 1] over the whole batch.
            Res = (Res - Res.min()) / (Res.max() - Res.min())

            data_cam[-data.shape[0]:] = Res.data.cpu().numpy()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Train a segmentation')
    parser.add_argument('--batch-size', type=int,
                        default=1,
                        help='')
    # Default was 'grad_rollout', which is not in choices and not handled
    # by the dispatch above (argparse does not validate defaults against
    # choices, so it silently crashed later with NameError).
    parser.add_argument('--method', type=str,
                        default='transformer_attribution',
                        choices=['rollout', 'lrp', 'transformer_attribution',
                                 'full_lrp', 'lrp_last_layer', 'attn_last_layer',
                                 'attn_gradcam'],
                        help='')
    parser.add_argument('--lmd', type=float,
                        default=10,
                        help='')
    parser.add_argument('--vis-class', type=str,
                        default='top',
                        choices=['top', 'target', 'index'],
                        help='')
    parser.add_argument('--class-id', type=int,
                        default=0,
                        help='')
    parser.add_argument('--cls-agn', action='store_true',
                        default=False,
                        help='')
    parser.add_argument('--no-ia', action='store_true',
                        default=False,
                        help='')
    parser.add_argument('--no-fx', action='store_true',
                        default=False,
                        help='')
    parser.add_argument('--no-fgx', action='store_true',
                        default=False,
                        help='')
    parser.add_argument('--no-m', action='store_true',
                        default=False,
                        help='')
    parser.add_argument('--no-reg', action='store_true',
                        default=False,
                        help='')
    # Was type=bool, which treats ANY non-empty value (even "False") as
    # True. str2bool keeps both spellings working: a bare --is-ablation
    # flag (const=True) and an explicit --is-ablation true/false.
    parser.add_argument('--is-ablation', type=str2bool,
                        nargs='?', const=True, default=False,
                        help='')
    parser.add_argument('--imagenet-validation-path', type=str,
                        required=True,
                        help='')
    args = parser.parse_args()

    # PATH variables
    PATH = os.path.dirname(os.path.abspath(__file__)) + '/'
    os.makedirs(os.path.join(PATH, 'visualizations'), exist_ok=True)
    os.makedirs(os.path.join(PATH, 'visualizations/{}'.format(args.method)), exist_ok=True)

    if args.vis_class == 'index':
        args.method_dir = os.path.join(
            PATH, 'visualizations/{}/{}_{}'.format(args.method, args.vis_class,
                                                   args.class_id))
    else:
        ablation_fold = 'ablation' if args.is_ablation else 'not_ablation'
        args.method_dir = os.path.join(
            PATH, 'visualizations/{}/{}/{}'.format(args.method, args.vis_class,
                                                   ablation_fold))
    os.makedirs(args.method_dir, exist_ok=True)

    # Remove any stale results file so h5py re-creates the datasets from
    # scratch. (Previously this removed
    # 'visualizations/{method}/{vis_class}/results.hdf5', which is not the
    # path the results are actually written to.)
    try:
        os.remove(os.path.join(args.method_dir, 'results.hdf5'))
    except OSError:
        pass

    cuda = torch.cuda.is_available()
    device = torch.device("cuda" if cuda else "cpu")

    # Model used by the attention-based baselines (rollout / attn GradCAM).
    model = vit_base_patch16_224(pretrained=True).to(device)
    model.eval()  # was missing: dropout must be off for deterministic maps
    baselines = Baselines(model)

    # LRP
    model_LRP = vit_LRP(pretrained=True).to(device)
    model_LRP.eval()
    lrp = LRP(model_LRP)

    # orig LRP
    model_orig_LRP = vit_orig_LRP(pretrained=True).to(device)
    model_orig_LRP.eval()
    orig_lrp = LRP(model_orig_LRP)

    # Dataset loader for sample images. Images are stored un-normalized;
    # normalize() is applied per-batch inside compute_saliency_and_save.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    imagenet_ds = ImageNet(args.imagenet_validation_path, split='val',
                           download=False, transform=transform)
    sample_loader = torch.utils.data.DataLoader(
        imagenet_ds, batch_size=args.batch_size, shuffle=False, num_workers=4
    )

    compute_saliency_and_save(args)