# explain-ViT / ViT_DeiT / baselines / ViT / generate_visualizations.py
# Hosting-page residue kept for provenance: "WwYc's picture", "Upload 707 files",
# commit 3d27aee (verified).
import os
from tqdm import tqdm
import h5py
import argparse
# Import saliency methods and models
from misc_functions import *
from ViT_explanation_generator import Baselines, LRP
from ViT_new import vit_base_patch16_224
from ViT_LRP import vit_base_patch16_224 as vit_LRP
from ViT_orig_LRP import vit_base_patch16_224 as vit_orig_LRP
from torchvision.datasets import ImageNet
def normalize(tensor,
              mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Normalize a 4-D image batch channel-wise, in place.

    Computes ``(tensor - mean) / std`` where ``mean`` and ``std`` hold one
    value per entry of dimension 1 of ``tensor``.

    Args:
        tensor: 4-D tensor; dimension 1 must match ``len(mean)`` and
            ``len(std)`` (the statistics are broadcast as ``[1, C, 1, 1]``).
            The tensor is MODIFIED IN PLACE.
        mean: per-channel values to subtract. Defaults are tuples rather than
            lists so the shared default objects can never be mutated.
        std: per-channel values to divide by.

    Returns:
        The same tensor object that was passed in, now normalized.
    """
    dtype = tensor.dtype
    mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
    std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
    # Index with None to add the batch, height and width axes for broadcasting.
    tensor.sub_(mean[None, :, None, None]).div_(std[None, :, None, None])
    return tensor
def compute_saliency_and_save(args):
    """Compute a saliency map for every batch of ``sample_loader`` and save it.

    Output goes to ``<args.method_dir>/results.hdf5`` (opened in append mode)
    in three datasets that grow along axis 0:

    * ``'vis'``    -- (N, 1, 224, 224) float32, min-max normalized saliency maps
    * ``'image'``  -- (N, 3, 224, 224) float32, inputs as yielded by the loader
      (stored before ``normalize`` is applied)
    * ``'target'`` -- (N,) int32, labels as yielded by the loader

    Besides ``args`` this function reads the module-level names
    ``sample_loader``, ``device``, ``baselines``, ``lrp`` and ``orig_lrp``,
    which the ``__main__`` section of this file creates.

    Args:
        args: namespace providing ``method_dir``, ``method``, ``vis_class``
            and (for the two ``*_last_layer`` methods) ``is_ablation``.

    Raises:
        ValueError: if ``args.method`` is not one of the methods handled below.
            Raised before the output file is touched.
    """
    supported_methods = ('rollout', 'lrp', 'transformer_attribution', 'full_lrp',
                         'lrp_last_layer', 'attn_last_layer', 'attn_gradcam')
    if args.method not in supported_methods:
        # Previously an unmatched method left ``Res`` unbound and failed later
        # with an UnboundLocalError, after the datasets had been created.
        raise ValueError('Unsupported saliency method: {!r}'.format(args.method))

    with h5py.File(os.path.join(args.method_dir, 'results.hdf5'), 'a') as f:
        data_cam = f.create_dataset('vis',
                                    (1, 1, 224, 224),
                                    maxshape=(None, 1, 224, 224),
                                    dtype=np.float32,
                                    compression="gzip")
        data_image = f.create_dataset('image',
                                      (1, 3, 224, 224),
                                      maxshape=(None, 3, 224, 224),
                                      dtype=np.float32,
                                      compression="gzip")
        data_target = f.create_dataset('target',
                                       (1,),
                                       maxshape=(None,),
                                       dtype=np.int32,
                                       compression="gzip")

        # Number of rows written so far. Each dataset starts with one
        # placeholder row; resizing to ``stop`` makes the length equal the
        # number of samples seen, exactly as the former ``first`` flag did.
        start = 0
        for data, target in tqdm(sample_loader):
            batch_size = data.shape[0]
            stop = start + batch_size

            data_cam.resize(stop, axis=0)
            data_image.resize(stop, axis=0)
            data_target.resize(stop, axis=0)

            # Store the inputs before ``normalize`` mutates ``data`` in place.
            data_image[start:stop] = data.detach().cpu().numpy()
            data_target[start:stop] = target.detach().cpu().numpy()

            target = target.to(device)
            data = normalize(data).to(device)
            data.requires_grad_()

            # ``None`` lets the explanation method pick the class itself.
            # NOTE(review): vis_class == 'index' does not forward
            # args.class_id here, so it is handled like 'top' -- confirm intent.
            index = target if args.vis_class == 'target' else None

            if args.method == 'rollout':
                Res = baselines.generate_rollout(data, start_layer=1)
                Res = Res.reshape(batch_size, 1, 14, 14)
            elif args.method == 'lrp':
                Res = lrp.generate_LRP(data, start_layer=1, index=index)
                Res = Res.reshape(batch_size, 1, 14, 14)
            elif args.method == 'transformer_attribution':
                Res = lrp.generate_LRP(data, start_layer=1, method="grad", index=index)
                Res = Res.reshape(batch_size, 1, 14, 14)
            elif args.method == 'full_lrp':
                # The only method that already yields a full-resolution map.
                Res = orig_lrp.generate_LRP(data, method="full", index=index)
                Res = Res.reshape(batch_size, 1, 224, 224)
            elif args.method == 'lrp_last_layer':
                Res = orig_lrp.generate_LRP(data, method="last_layer",
                                            is_ablation=args.is_ablation, index=index)
                Res = Res.reshape(batch_size, 1, 14, 14)
            elif args.method == 'attn_last_layer':
                Res = lrp.generate_LRP(data, method="last_layer_attn",
                                       is_ablation=args.is_ablation)
                Res = Res.reshape(batch_size, 1, 14, 14)
            else:  # 'attn_gradcam' (membership was validated above)
                Res = baselines.generate_cam_attn(data, index=index)
                Res = Res.reshape(batch_size, 1, 14, 14)

            if args.method not in ('full_lrp', 'input_grads'):
                # Upsample the 14x14 patch grid by 16 to 224x224. Use the
                # shared ``device`` instead of a hard-coded ``.cuda()``.
                Res = torch.nn.functional.interpolate(
                    Res, scale_factor=16, mode='bilinear').to(device)

            # Min-max scaling over the whole batch tensor.
            # NOTE(review): a constant map makes the denominator zero.
            Res = (Res - Res.min()) / (Res.max() - Res.min())

            data_cam[start:stop] = Res.detach().cpu().numpy()
            start = stop
if __name__ == "__main__":
    def _str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` is unsuitable for argparse because ``bool("False")`` is
        True: every non-empty string would enable the flag.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError(
            'expected a boolean value, got {!r}'.format(value))

    parser = argparse.ArgumentParser(description='Generate ViT saliency visualizations')
    parser.add_argument('--batch-size', type=int,
                        default=1,
                        help='')
    # The default must be one of ``choices``: argparse does not validate
    # defaults, and the former 'grad_rollout' matched no method branch.
    parser.add_argument('--method', type=str,
                        default='rollout',
                        choices=['rollout', 'lrp', 'transformer_attribution', 'full_lrp', 'lrp_last_layer',
                                 'attn_last_layer', 'attn_gradcam'],
                        help='')
    parser.add_argument('--lmd', type=float,
                        default=10,
                        help='')
    parser.add_argument('--vis-class', type=str,
                        default='top',
                        choices=['top', 'target', 'index'],
                        help='')
    parser.add_argument('--class-id', type=int,
                        default=0,
                        help='')
    parser.add_argument('--cls-agn', action='store_true',
                        default=False,
                        help='')
    parser.add_argument('--no-ia', action='store_true',
                        default=False,
                        help='')
    parser.add_argument('--no-fx', action='store_true',
                        default=False,
                        help='')
    parser.add_argument('--no-fgx', action='store_true',
                        default=False,
                        help='')
    parser.add_argument('--no-m', action='store_true',
                        default=False,
                        help='')
    parser.add_argument('--no-reg', action='store_true',
                        default=False,
                        help='')
    parser.add_argument('--is-ablation', type=_str2bool,
                        default=False,
                        help='')
    parser.add_argument('--imagenet-validation-path', type=str,
                        required=True,
                        help='')
    args = parser.parse_args()

    # PATH variables
    PATH = os.path.dirname(os.path.abspath(__file__)) + '/'
    if args.vis_class == 'index':
        args.method_dir = os.path.join(PATH, 'visualizations/{}/{}_{}'.format(args.method,
                                                                              args.vis_class,
                                                                              args.class_id))
    else:
        ablation_fold = 'ablation' if args.is_ablation else 'not_ablation'
        args.method_dir = os.path.join(PATH, 'visualizations/{}/{}/{}'.format(args.method,
                                                                              args.vis_class,
                                                                              ablation_fold))
    # Creates every missing parent ('visualizations', the method folder, ...).
    os.makedirs(args.method_dir, exist_ok=True)

    # Remove results from a previous run in the directory that is actually
    # written to. The file is later opened in append mode and its datasets are
    # created unconditionally, so a leftover file would make that step fail.
    try:
        os.remove(os.path.join(args.method_dir, 'results.hdf5'))
    except FileNotFoundError:
        pass

    cuda = torch.cuda.is_available()
    device = torch.device("cuda" if cuda else "cpu")

    # Model
    # Models are moved to ``device`` rather than via a hard-coded ``.cuda()``.
    # NOTE(review): the imported explanation code may still assume CUDA.
    model = vit_base_patch16_224(pretrained=True).to(device)
    model.eval()  # inference mode, consistent with the two LRP models below
    baselines = Baselines(model)

    # LRP
    model_LRP = vit_LRP(pretrained=True).to(device)
    model_LRP.eval()
    lrp = LRP(model_LRP)

    # orig LRP
    model_orig_LRP = vit_orig_LRP(pretrained=True).to(device)
    model_orig_LRP.eval()
    orig_lrp = LRP(model_orig_LRP)

    # Dataset loader for sample images
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    imagenet_ds = ImageNet(args.imagenet_validation_path, split='val', download=False, transform=transform)
    sample_loader = torch.utils.data.DataLoader(
        imagenet_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=4
    )

    compute_saliency_and_save(args)