sanjudebnath committed on
Commit
bf2494e
·
verified ·
1 Parent(s): 41b56dd

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -0
app.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import torch
4
+ from transformers import DistilBertTokenizer, DistilBertForMaskedLM
5
+
6
+ from qa_model import ReuseQuestionDistilBERT
7
+
8
@st.cache(allow_output_mutation=True)
def load_model():
    """Build the question-answering model and its tokenizer.

    The DistilBERT encoder is taken from the pretrained
    ``distilbert-base-uncased`` masked-LM checkpoint, wrapped in
    ``ReuseQuestionDistilBERT``, and the wrapper's weights are then restored
    from the local ``distilbert_reuse.model`` file onto the CPU. The tokenizer
    is read from the local ``qa_tokenizer`` directory.

    The result is cached by Streamlit so the weights are not reloaded on
    every script rerun.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    # NOTE(review): ``st.cache`` is deprecated in recent Streamlit releases in
    # favour of ``st.cache_resource`` -- confirm the pinned Streamlit version
    # before migrating.
    encoder = DistilBertForMaskedLM.from_pretrained("distilbert-base-uncased").distilbert
    model = ReuseQuestionDistilBERT(encoder)

    # NOTE(review): ``torch.load`` unpickles the file; only ship checkpoints
    # from a trusted source.
    state_dict = torch.load("distilbert_reuse.model", map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)

    # A freshly constructed module starts in training mode, which keeps
    # dropout active; this app only runs inference, so switch to eval mode.
    # (Assumes ReuseQuestionDistilBERT is a torch.nn.Module, as its
    # ``load_state_dict`` call suggests.)
    model.eval()

    tokenizer = DistilBertTokenizer.from_pretrained('qa_tokenizer')
    return model, tokenizer
18
+
19
+
20
def get_answer(question, text, tokenizer, model):
    """Extract the answer to ``question`` from ``text``.

    The question/text pair is tokenized into a single sequence (padded to 512
    tokens, truncating only the text), the model is run once, and the tokens
    between the highest-scoring start and end positions are decoded back into
    a string.

    Args:
        question: The question to answer.
        text: The passage to search for the answer.
        tokenizer: Callable tokenizer that also provides
            ``convert_ids_to_tokens`` and ``convert_tokens_to_string``.
        model: Callable returning a mapping with ``start_logits`` and
            ``end_logits``.

    Returns:
        str: The decoded answer span, or ``""`` when either input is blank or
        the predicted span is empty or inverted.
    """
    question = question.strip()
    text = text.strip()
    if not question or not text:
        # Nothing to ask or nothing to search: skip the forward pass.
        return ""

    encoded = tokenizer(
        [question],
        [text],
        max_length=512,
        truncation="only_second",
        padding="max_length",
    )
    input_ids = torch.tensor(encoded['input_ids'])
    attention_mask = torch.tensor(encoded['attention_mask'])

    # Inference only: do not record an autograd graph.
    with torch.no_grad():
        outputs = model(
            input_ids,
            attention_mask=attention_mask,
            start_positions=None,
            end_positions=None,
        )

    # A single example is encoded, so the flattened argmax is the position
    # within that one sequence.
    start = int(torch.argmax(outputs['start_logits']))
    end = int(torch.argmax(outputs['end_logits']))
    if end < start:
        # Inverted span: the model found no usable answer.
        return ""

    # Hand plain Python ints to the tokenizer rather than a tensor slice.
    span_ids = input_ids[0, start:end + 1].tolist()
    answer_tokens = tokenizer.convert_ids_to_tokens(span_ids, skip_special_tokens=True)
    return tokenizer.convert_tokens_to_string(answer_tokens)
42
+
43
+
44
def main():
    """Render the Streamlit page: intro text, the input form and the answer."""
    st.set_page_config(page_title="Question Answering Tool", page_icon=":mag_right:")

    intro = (
        "# Question Answering Tool \n"
        "This tool will help you find answers to your questions about the text you provide. \n"
        "Please enter your question and the text you want to search in the boxes below."
    )
    st.write(intro)

    model, tokenizer = load_model()

    with st.form("qa_form"):
        # Passage to search, then the question to ask about it.
        text = st.text_area("Enter your text here")
        question = st.text_input("Enter your question here")

        submitted = st.form_submit_button("Submit")
        if submitted:
            # Placeholder element, overwritten with the result below.
            status = st.text('Let me think about that...')
            answer = get_answer(question, text, tokenizer, model)
            if answer != "":
                status.text(answer)
            else:
                status.text("Sorry but I don't know the answer to that question")
68
+
69
+
70
# Launch the UI only when this file is executed as a script, so that importing
# the module (e.g. to reuse get_answer) does not render the page.
if __name__ == "__main__":
    main()