explain-LXMERT / lxmert /perturbation.py
WwYc's picture
Upload 61 files
08d7644 verified
raw
history blame
11.5 kB
from lxmert.lxmert.src.tasks import vqa_data
from lxmert.lxmert.src.modeling_frcnn import GeneralizedRCNN
import lxmert.lxmert.src.vqa_utils as utils
from lxmert.lxmert.src.processing_image import Preprocess
from transformers import LxmertTokenizer
from lxmert.lxmert.src.huggingface_lxmert import LxmertForQuestionAnswering
from lxmert.lxmert.src.lxmert_lrp import LxmertForQuestionAnswering as LxmertForQuestionAnsweringLRP
from tqdm import tqdm
from lxmert.lxmert.src.ExplanationGenerator import GeneratorOurs, GeneratorBaselines, GeneratorOursAblationNoAggregation
import random
from lxmert.lxmert.src.param import args
# Vocabulary file URLs from the airsplay py-bottom-up-attention / LXMERT repos.
# OBJ_URL and ATTR_URL are not referenced elsewhere in this module (presumably
# kept for visualising detected objects/attributes -- confirm before removing).
OBJ_URL = "https://raw.githubusercontent.com/airsplay/py-bottom-up-attention/master/demo/data/genome/1600-400-20/objects_vocab.txt"
ATTR_URL = "https://raw.githubusercontent.com/airsplay/py-bottom-up-attention/master/demo/data/genome/1600-400-20/attributes_vocab.txt"
# Answer vocabulary loaded by ModelPert via utils.get_data; it is indexed with the
# argmax of the model's answer scores to obtain the predicted answer string.
VQA_URL = "https://raw.githubusercontent.com/airsplay/lxmert/master/data/vqa/trainval_label2ans.json"
class ModelPert:
    """Run LXMERT VQA on COCO val images and measure perturbation accuracy.

    ``forward`` runs one unperturbed prediction and caches intermediate values
    on ``self`` (image path/id, question tokens, box count, boxes, output).
    ``perturbation_image`` / ``perturbation_text`` re-run the model while keeping
    only the highest-scoring image boxes / question tokens (lowest-scoring for a
    positive perturbation test). They add the item's label score for the
    predicted answer to the running per-step sums in ``self.pert_acc``.
    """

    # Device for all LXMERT inputs and for the Faster R-CNN config.
    _DEVICE = "cuda"

    def __init__(self, COCO_val_path, use_lrp=False):
        """Load Faster R-CNN, the LXMERT tokenizer/model and the VQA valid split.

        Args:
            COCO_val_path: prefix prepended verbatim to ``item['img_id'] + '.jpg'``
                to build image paths, so it normally ends with a path separator.
            use_lrp: if True, load ``LxmertForQuestionAnsweringLRP`` instead of
                the plain ``LxmertForQuestionAnswering``.
        """
        self.COCO_VAL_PATH = COCO_val_path
        # Answer vocabulary, indexed by the argmax of the model's answer scores.
        self.vqa_answers = utils.get_data(VQA_URL)

        # load models and model components
        self.frcnn_cfg = utils.Config.from_pretrained("unc-nlp/frcnn-vg-finetuned")
        self.frcnn_cfg.MODEL.DEVICE = self._DEVICE
        self.frcnn = GeneralizedRCNN.from_pretrained("unc-nlp/frcnn-vg-finetuned", config=self.frcnn_cfg)
        self.image_preprocess = Preprocess(self.frcnn_cfg)
        self.lxmert_tokenizer = LxmertTokenizer.from_pretrained("unc-nlp/lxmert-base-uncased")

        qa_model_cls = LxmertForQuestionAnsweringLRP if use_lrp else LxmertForQuestionAnswering
        self.lxmert_vqa = qa_model_cls.from_pretrained("unc-nlp/lxmert-vqa-uncased").to(self._DEVICE)
        self.lxmert_vqa.eval()
        self.model = self.lxmert_vqa

        self.vqa_dataset = vqa_data.VQADataset(splits="valid")
        # Fraction of boxes/tokens removed at each perturbation step.
        self.pert_steps = [0, 0.25, 0.5, 0.75, 0.8, 0.85, 0.9, 0.95, 1]
        # Running sum of per-step scores, accumulated across calls (callers divide).
        self.pert_acc = [0] * len(self.pert_steps)

    # ------------------------------------------------------------------ helpers

    def _image_path(self, item):
        """Return the image path for ``item`` (plain string concatenation)."""
        return self.COCO_VAL_PATH + item['img_id'] + '.jpg'

    def _extract_visual_features(self, image_file_path):
        """Preprocess one image and run Faster R-CNN on it; return its output dict."""
        images, sizes, scales_yx = self.image_preprocess(image_file_path)
        return self.frcnn(
            images,
            sizes,
            scales_yx=scales_yx,
            padding="max_detections",
            max_detections=self.frcnn_cfg.max_detections,
            return_tensors="pt",
        )

    def _tokenize(self, sentence):
        """Tokenize a question into LXMERT inputs, including special tokens."""
        return self.lxmert_tokenizer(
            sentence,
            truncation=True,
            return_token_type_ids=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )

    def _run_lxmert(self, input_ids, attention_mask, token_type_ids, visual_feats, visual_pos):
        """Run the VQA model with every input moved to ``self._DEVICE``."""
        device = self._DEVICE
        return self.lxmert_vqa(
            input_ids=input_ids.to(device),
            attention_mask=attention_mask.to(device),
            visual_feats=visual_feats.to(device),
            visual_pos=visual_pos.to(device),
            token_type_ids=token_type_ids.to(device),
            return_dict=True,
            output_attentions=False,
        )

    def _accumulate(self, step_idx, item, output):
        """Add the label score of the predicted answer to ``self.pert_acc[step_idx]``."""
        answer = self.vqa_answers[int(output.question_answering_score.argmax())]
        self.pert_acc[step_idx] += item["label"].get(answer, 0)

    @staticmethod
    def _num_kept(step, total):
        """Number of elements kept when a fraction ``step`` of ``total`` is removed.

        Computes ``floor((1 - step) * total)``. The product is rounded first, so
        binary floating-point error (e.g. ``(1 - 0.8) * 10 == 1.9999999999999996``)
        does not silently drop one extra element.
        """
        return int(round((1 - step) * total, 6))

    # ------------------------------------------------------------------ public API

    def forward(self, item):
        """Run an unperturbed prediction for ``item`` and cache intermediates.

        Sets ``image_file_path``, ``image_id``, ``question_tokens``, ``text_len``,
        ``image_boxes_len``, ``bboxes`` and ``output`` on ``self``.

        Returns:
            The LXMERT output object (``return_dict=True``).
        """
        image_file_path = self._image_path(item)
        self.image_file_path = image_file_path
        self.image_id = item['img_id']

        # run frcnn
        output_dict = self._extract_visual_features(image_file_path)
        inputs = self._tokenize(item['sent'])
        self.question_tokens = self.lxmert_tokenizer.convert_ids_to_tokens(inputs.input_ids.flatten())
        self.text_len = len(self.question_tokens)

        # Very important that the boxes are normalized
        normalized_boxes = output_dict.get("normalized_boxes")
        features = output_dict.get("roi_features")
        self.image_boxes_len = features.shape[1]
        self.bboxes = output_dict.get("boxes")

        self.output = self._run_lxmert(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            token_type_ids=inputs.token_type_ids,
            visual_feats=features,
            visual_pos=normalized_boxes,
        )
        return self.output

    def perturbation_image(self, item, cam_image, cam_text, is_positive_pert=False):
        """Accumulate accuracy while image boxes are progressively removed.

        For each fraction in ``self.pert_steps``, this keeps the ``(1 - step)``
        share of boxes with the highest ``cam_image`` score. It then re-runs
        LXMERT with the unchanged question and adds the predicted answer's
        label score to ``self.pert_acc``.

        Args:
            item: dataset item; its 'img_id', 'sent' and 'label' entries are read.
            cam_image: per-box relevance scores (presumably a 1-D tensor with one
                entry per detected box -- TODO confirm).
            cam_text: unused here (the signature mirrors ``perturbation_text``).
            is_positive_pert: if True, the scores are negated so the most relevant
                boxes are removed first.

        Returns:
            ``self.pert_acc``, the running per-step sums (mutated in place).
        """
        if is_positive_pert:
            cam_image = cam_image * (-1)
        # run frcnn
        output_dict = self._extract_visual_features(self._image_path(item))
        inputs = self._tokenize(item['sent'])
        # Very important that the boxes are normalized
        normalized_boxes = output_dict.get("normalized_boxes")
        features = output_dict.get("roi_features")
        # Use this image's box count directly instead of relying on a prior
        # forward() call having set self.image_boxes_len.
        num_boxes = features.shape[1]
        for step_idx, step in enumerate(self.pert_steps):
            # find top step boxes
            num_kept = self._num_kept(step, num_boxes)
            _, top_box_indices = cam_image.topk(k=num_kept, dim=-1)
            top_box_indices = top_box_indices.detach().cpu().numpy()
            output = self._run_lxmert(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                token_type_ids=inputs.token_type_ids,
                visual_feats=features[:, top_box_indices, :],
                visual_pos=normalized_boxes[:, top_box_indices, :],
            )
            self._accumulate(step_idx, item, output)
        return self.pert_acc

    def perturbation_text(self, item, cam_image, cam_text, is_positive_pert=False):
        """Accumulate accuracy while question tokens are progressively removed.

        The first and last tokens ([CLS] and [SEP]) are always kept. For each
        fraction in ``self.pert_steps``, this keeps the ``(1 - step)`` share of
        the remaining tokens with the highest ``cam_text`` score, in their
        original order. It then re-runs LXMERT on the full image features and
        adds the predicted answer's label score to ``self.pert_acc``.

        Args:
            item: dataset item; its 'img_id', 'sent' and 'label' entries are read.
            cam_image: unused here (the signature mirrors ``perturbation_image``).
            cam_text: per-token relevance scores aligned with the tokenized
                question (presumably a 1-D tensor -- TODO confirm).
            is_positive_pert: if True, the scores are negated so the most relevant
                tokens are removed first.

        Returns:
            ``self.pert_acc``, the running per-step sums (mutated in place).
        """
        if is_positive_pert:
            cam_text = cam_text * (-1)
        # run frcnn
        output_dict = self._extract_visual_features(self._image_path(item))
        inputs = self._tokenize(item['sent'])
        # Very important that the boxes are normalized
        normalized_boxes = output_dict.get("normalized_boxes")
        features = output_dict.get("roi_features")

        # we must keep the [CLS] token in order to have the classification;
        # we also keep the [SEP] token, so only the tokens between them are ranked
        cam_pure_text = cam_text[1:-1]
        num_pure_tokens = cam_pure_text.shape[0]
        last_index = cam_text.shape[0] - 1
        for step_idx, step in enumerate(self.pert_steps):
            num_kept = self._num_kept(step, num_pure_tokens)
            _, top_token_indices = cam_pure_text.topk(k=num_kept, dim=-1)
            # shift by one to undo the dropped [CLS], then add back [CLS]/[SEP];
            # text tokens must be sorted for positional embedding to work
            token_indices = sorted(
                [0, last_index] + [idx + 1 for idx in top_token_indices.detach().cpu().tolist()]
            )
            output = self._run_lxmert(
                input_ids=inputs.input_ids[:, token_indices],
                attention_mask=inputs.attention_mask[:, token_indices],
                token_type_ids=inputs.token_type_ids[:, token_indices],
                visual_feats=features,
                visual_pos=normalized_boxes,
            )
            self._accumulate(step_idx, item, output)
        return self.pert_acc
# --method value -> (generator key, generator method name, extra keyword arguments).
# Generator keys refer to the explanation generators constructed in ``main``.
_PERT_METHODS = {
    "transformer_att": ("baselines", "generate_transformer_attr", {}),
    "attn_gradcam": ("baselines", "generate_attn_gradcam", {}),
    "partial_lrp": ("baselines", "generate_partial_lrp", {}),
    "raw_attn": ("baselines", "generate_raw_attn", {}),
    "rollout": ("baselines", "generate_rollout", {}),
    "ours_with_lrp_no_normalization": ("ours", "generate_ours", {"normalize_self_attention": False}),
    "ours_no_lrp": ("ours", "generate_ours", {"use_lrp": False}),
    "ours_no_lrp_no_norm": ("ours", "generate_ours", {"use_lrp": False, "normalize_self_attention": False}),
    "ours_with_lrp": ("ours", "generate_ours", {"use_lrp": True}),
    "ablation_no_self_in_10": ("ours", "generate_ours", {"use_lrp": False, "apply_self_in_rule_10": False}),
    "ablation_no_aggregation": (
        "ours_no_agg",
        "generate_ours_no_agg",
        {"use_lrp": False, "normalize_self_attention": False},
    ),
}


def _min_max_normalize(cam):
    """Rescale ``cam`` to [0, 1]; a constant map becomes all zeros instead of NaN."""
    shifted = cam - cam.min()
    value_range = shifted.max()
    if value_range == 0:
        return shifted
    return shifted / value_range


def main(args):
    """Run one perturbation test and report running per-step accuracy.

    A fixed-seed random sample of ``args.num_samples`` VQA validation items is
    drawn. For each item, relevance maps are generated with ``args.method`` and
    min-max normalized. The text or image perturbation test is then run, and the
    tqdm bar shows the running mean accuracy (in %) for every perturbation step.

    Args:
        args: parsed options; reads ``COCO_path``, ``method``, ``num_samples``,
            ``is_positive_pert`` and ``is_text_pert``.
    """
    method_name = args.method
    # Fail fast, before loading any model, on an unknown method name.
    if method_name not in _PERT_METHODS:
        print("Please enter a valid method name")
        return

    model_pert = ModelPert(args.COCO_path, use_lrp=True)
    generators = {
        "ours": GeneratorOurs(model_pert),
        "baselines": GeneratorBaselines(model_pert),
        "ours_no_agg": GeneratorOursAblationNoAggregation(model_pert),
    }
    generator_key, generate_name, generate_kwargs = _PERT_METHODS[method_name]
    generate = getattr(generators[generator_key], generate_name)

    # Reuse the validation split ModelPert already loaded instead of reloading it.
    items = model_pert.vqa_dataset.data
    random.seed(1234)
    sample_order = list(range(len(items)))
    random.shuffle(sample_order)
    iterator = tqdm([items[i] for i in sample_order[:args.num_samples]])

    test_type = "positive" if args.is_positive_pert else "negative"
    modality = "text" if args.is_text_pert else "image"
    print("running {0} pert test for {1} modality with method {2}".format(test_type, modality, method_name))

    perturb = model_pert.perturbation_text if args.is_text_pert else model_pert.perturbation_image
    for index, item in enumerate(iterator):
        R_t_t, R_t_i = generate(item, **generate_kwargs)
        cam_image = _min_max_normalize(R_t_i[0])
        cam_text = _min_max_normalize(R_t_t[0])
        # perturb() returns cumulative sums over all items seen so far.
        curr_pert_result = perturb(item, cam_image, cam_text, args.is_positive_pert)
        curr_pert_result = [round(res / (index + 1) * 100, 2) for res in curr_pert_result]
        iterator.set_description("Acc: {}".format(curr_pert_result))


if __name__ == "__main__":
    main(args)