sanjudebnath commited on
Commit
aad2b57
·
verified ·
1 Parent(s): 0542b04

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -44
app.py CHANGED
@@ -1,70 +1,42 @@
1
  import streamlit as st
2
  import numpy as np
3
  import torch
4
- from transformers import DistilBertTokenizer, DistilBertForMaskedLM
5
-
6
- from qa_model import ReuseQuestionDistilBERT
7
 
8
  @st.cache(allow_output_mutation=True)
9
  def load_model():
10
- mod = DistilBertForMaskedLM.from_pretrained("distilbert-base-uncased").distilbert
11
- m = ReuseQuestionDistilBERT(mod)
12
- m.load_state_dict(torch.load("distilbert_reuse.model", map_location=torch.device('cpu')))
13
- model = m
14
- del mod
15
- del m
16
- tokenizer = DistilBertTokenizer.from_pretrained('qa_tokenizer')
17
  return model, tokenizer
18
 
19
-
20
  def get_answer(question, text, tokenizer, model):
21
- question = [question.strip()]
22
- text = [text.strip()]
23
-
24
- inputs = tokenizer(
25
- question,
26
- text,
27
- max_length=512,
28
- truncation="only_second",
29
- padding="max_length",
30
- )
31
- input_ids = torch.tensor(inputs['input_ids'])
32
- outputs = model(input_ids, attention_mask=torch.tensor(inputs['attention_mask']), start_positions=None, end_positions=None)
33
-
34
- start = torch.argmax(outputs['start_logits'])
35
- end = torch.argmax(outputs['end_logits'])
36
-
37
- ans_tokens = input_ids[0][start: end + 1]
38
-
39
- answer_tokens = tokenizer.convert_ids_to_tokens(ans_tokens, skip_special_tokens=True)
40
- predicted = tokenizer.convert_tokens_to_string(answer_tokens)
41
- return predicted
42
-
43
 
44
  def main():
45
  st.set_page_config(page_title="Question Answering Tool", page_icon=":mag_right:")
46
 
47
  st.write("# Question Answering Tool \n"
48
- "This tool will help you find answers to your questions about the text you provide. \n"
49
- "Please enter your question and the text you want to search in the boxes below.")
50
  model, tokenizer = load_model()
51
 
52
  with st.form("qa_form"):
53
- # define a streamlit textarea
54
- text = st.text_area("Enter your text here", on_change=None)
55
-
56
- # define a streamlit input
57
  question = st.text_input("Enter your question here")
58
 
59
  if st.form_submit_button("Submit"):
60
  data_load_state = st.text('Let me think about that...')
61
- # call the function to get the answer
62
  answer = get_answer(question, text, tokenizer, model)
63
- # display the answer
64
- if answer == "":
65
  data_load_state.text("Sorry but I don't know the answer to that question")
66
  else:
67
  data_load_state.text(answer)
68
 
69
-
70
- main()
 
1
  import streamlit as st
2
  import numpy as np
3
  import torch
4
+ from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering
 
 
5
 
6
  @st.cache(allow_output_mutation=True)
7
  def load_model():
8
+ model = DistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased-distilled-squad")
9
+ tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased-distilled-squad")
 
 
 
 
 
10
  return model, tokenizer
11
 
 
12
  def get_answer(question, text, tokenizer, model):
13
+ inputs = tokenizer(question, text, return_tensors="pt", truncation=True, padding=True)
14
+ with torch.no_grad():
15
+ outputs = model(**inputs)
16
+ start = torch.argmax(outputs.start_logits)
17
+ end = torch.argmax(outputs.end_logits) + 1
18
+ ans_tokens = inputs.input_ids[0][start:end]
19
+ answer = tokenizer.decode(ans_tokens, skip_special_tokens=True)
20
+ return answer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  def main():
23
  st.set_page_config(page_title="Question Answering Tool", page_icon=":mag_right:")
24
 
25
  st.write("# Question Answering Tool \n"
26
+ "This tool will help you find answers to your questions about the text you provide. \n"
27
+ "Please enter your question and the text you want to search in the boxes below.")
28
  model, tokenizer = load_model()
29
 
30
  with st.form("qa_form"):
31
+ text = st.text_area("Enter your text here")
 
 
 
32
  question = st.text_input("Enter your question here")
33
 
34
  if st.form_submit_button("Submit"):
35
  data_load_state = st.text('Let me think about that...')
 
36
  answer = get_answer(question, text, tokenizer, model)
37
+ if answer.strip() == "":
 
38
  data_load_state.text("Sorry but I don't know the answer to that question")
39
  else:
40
  data_load_state.text(answer)
41
 
42
+ main()