# Page banner captured along with this file -- Spaces status: Runtime error
import json
import os

import gradio as gr
import torch
from PIL import Image
from ruamel import yaml

from data.transforms import ALBEFTextTransform, testing_image_transform
from model import albef_model_for_vqa
# ---------------------------------------------------------------------------
# One-time setup: configuration, model weights, input transforms and the demo
# data. Everything in this section runs when the module is loaded.
# ---------------------------------------------------------------------------

# Directory that holds vqa_data.json and answer_list.json.
data_dir = "./"

# NOTE(review): yaml.Loader can construct arbitrary Python objects. That is
# acceptable only because the config file ships with this app; never point
# this at untrusted YAML.
with open("./configs/vqa.yaml", "r", encoding="utf-8") as config_file:
    config = yaml.load(config_file, Loader=yaml.Loader)

model = albef_model_for_vqa(config)
checkpoint_url = "https://download.pytorch.org/models/multimodal/albef/finetuned_vqa_checkpoint.pt"
# Weights are mapped onto the CPU regardless of the device they were saved on.
checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url, map_location='cpu')
model.load_state_dict(checkpoint)
# The model is only used for inference in this file, so leave training mode:
# modules are created in training mode, where layers such as dropout (if the
# model has any) would make predictions non-deterministic.
model.eval()

image_transform = testing_image_transform()
question_transform = ALBEFTextTransform(add_end_token=False)
answer_transform = ALBEFTextTransform(do_pre_process=False)

# vqa_data: a sequence of records, each with 'image' and 'question' entries.
with open(os.path.join(data_dir, "vqa_data.json"), "r", encoding="utf-8") as vqa_file:
    vqa_data = json.load(vqa_file)

# answer_list: the candidate answers; the model's predicted id indexes into it.
with open(os.path.join(data_dir, "answer_list.json"), "r", encoding="utf-8") as answers_file:
    answer_list = json.load(answers_file)

# One [image, question] pair per record, in the order of the demo's two inputs.
examples = [[data['image'], data['question']] for data in vqa_data]
# Static text rendered on the demo page: heading, introduction (Markdown) and
# a citation footer (a fenced BibTeX entry).
title = 'VQA with ALBEF'

description = (
    'VQA with [ALBEF](https://arxiv.org/abs/2107.07651), '
    'adapted from the [torchmultimodal example notebook]'
    '(https://github.com/facebookresearch/multimodal/blob/main/examples/albef/vqa_with_albef.ipynb).'
)

article = '\n'.join((
    '```bibtex',
    '@article{li2021align,',
    'title={Align before fuse: Vision and language representation learning with momentum distillation},',
    'author={Li, Junnan and Selvaraju, Ramprasaath and Gotmare, Akhilesh and Joty, Shafiq and Xiong, Caiming and Hoi, Steven Chu Hong},',
    'journal={Advances in neural information processing systems},',
    'volume={34},',
    'pages={9694--9705},',
    'year={2021}',
    '}',
    '```',
))
def infer(image, question):
    """Answer a natural-language question about an image with the VQA model.

    Args:
        image: The picture to ask about. The interface defined in this file
            requests a PIL image in RGB mode; ``None`` is rejected (e.g. when
            no image was provided).
        question: The question text.

    Returns:
        The entry of ``answer_list`` selected by the model, which is asked
        for a single candidate (``k=1``).

    Raises:
        ValueError: If ``image`` is ``None``.
    """
    if image is None:
        # Fail early with a clear message instead of an opaque error from
        # inside the image transform.
        raise ValueError("An image is required to answer a question.")

    # The model is called on batches, so wrap the single image and the single
    # question into batches of one.
    image_input = torch.stack([image_transform(image)], dim=0)
    question_input = question_transform([question])
    # Mask is 1 wherever the token id is non-zero
    # (presumably 0 is the padding id -- TODO confirm against the transform).
    question_atts = (question_input != 0).type(torch.long)

    answer_input = answer_transform(answer_list)
    answer_atts = (answer_input != 0).type(torch.long)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        answer_ids, _ = model(
            image_input,
            question_input,
            question_atts,
            answer_input,
            answer_atts,
            k=1,
            is_train=False,
        )

    # First (and only) item of the batch; used directly as an index into the
    # candidate answers.
    predicted_answer_id = answer_ids[0]
    return answer_list[predicted_answer_id]
# Gradio UI: an image plus a question go in, the predicted answer comes out.
# `examples` pre-populates the page with the bundled image/question pairs.
demo = gr.Interface(
    fn=infer,
    inputs=[
        gr.Image(label='image', type='pil', image_mode='RGB'),
        gr.Text(label='question'),
    ],
    outputs=gr.Text(label='answer'),
    examples=examples,
    title=title,
    description=description,
    article=article,
)

# Start the web server only when this file is executed as a script, so that
# importing the module (for tests or reuse of `infer`/`demo`) does not block.
if __name__ == "__main__":
    demo.launch()