File size: 1,482 Bytes
a7f5ca7
 
 
 
66e49d4
 
a7f5ca7
 
 
835e1e2
a7f5ca7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
835e1e2
a7f5ca7
 
 
 
 
 
 
 
 
 
 
 
 
 
1810440
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import gradio
import torch
from PIL import Image
from transformers import ViltProcessor, ViltForQuestionAnswering

# Checkpoint identifier shared by the processor and the model so the two can
# never drift apart.
# NOTE(review): on the Hugging Face Hub this checkpoint is published as
# "dandelin/vilt-b32-finetuned-vqa"; the bare name below presumably resolves
# only because a local directory (or mirror) with this name exists — confirm.
MODEL_ID = "vilt-b32-finetuned-vqa"

# Processor: turns an (image, question) pair into model input tensors.
processor = ViltProcessor.from_pretrained(MODEL_ID)
# Model: ViLT question answering; its logits index ``model.config.id2label``.
model = ViltForQuestionAnswering.from_pretrained(MODEL_ID)


def predict_answer(image, question):
    """Return the model's top answers to *question* about *image*.

    Args:
        image: The picture from the ``gradio.components.Image`` input.
            ``.astype`` is called on it, so it is presumably a NumPy array
            (Gradio's default Image type); ``None`` when nothing was uploaded.
        question: Free-form question text from the Question textbox.

    Returns:
        A list of exactly five strings, one per "Answer N" output Textbox,
        ordered by decreasing probability and formatted as
        ``"<answer>: <probability as a percentage>"``. Slots whose
        probability is zero are empty strings, so Gradio always receives one
        value per output component.
    """
    num_answers = 5  # must equal the number of output Textboxes

    if image is None:
        # No upload: avoid AttributeError on ``None.astype`` and still fill
        # every output slot.
        return ["Please upload an image."] + [""] * (num_answers - 1)

    # Let Pillow infer the mode from the array shape, then normalise to RGB;
    # forcing mode "RGB" can reject or misread grayscale / RGBA arrays.
    pil_image = Image.fromarray(image.astype("uint8")).convert("RGB")
    encoding = processor(pil_image, question, return_tensors="pt")

    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        logits = model(**encoding).logits

    probs = logits.softmax(dim=-1)[0]
    # topk returns values sorted in descending order and avoids sorting the
    # whole label space; clamp k in case there are fewer labels than slots.
    top_probs, top_indices = probs.topk(min(num_answers, probs.shape[-1]))

    answer_list = [
        f"{model.config.id2label[idx]}: {prob:.2%}"
        for prob, idx in zip(top_probs.tolist(), top_indices.tolist())
        if prob > 0.0  # float32 softmax can underflow to exactly 0
    ]
    # Pad so Gradio always gets one value per output component; a shorter
    # list would not match the five output Textboxes.
    answer_list += [""] * (num_answers - len(answer_list))
    return answer_list


# Gradio input components, passed positionally to predict_answer as
# (image, question).
image_input = gradio.components.Image(label="Image")
question_input = gradio.components.Textbox(
    label="Question",
    placeholder="Enter your question here.",
)
inputs = [image_input, question_input]

# One text box per ranked answer, labelled "Answer 1" through "Answer 5";
# predict_answer's returned list fills them in order.
outputs = [
    gradio.components.Textbox(label=f"Answer {rank}")
    for rank in range(1, 6)
]

title = "Visual Question Answering (vilt-b32-finetuned-vqa)"

# Build the web UI around predict_answer and start serving it. Flagging is
# disabled, and the CSS rule hides the page footer.
demo = gradio.Interface(
    fn=predict_answer,
    inputs=inputs,
    outputs=outputs,
    title=title,
    allow_flagging="never",
    css="footer{display:none !important}",
)
demo.launch()