sanjudebnath committed on
Commit 7fdbc52 · verified · 1 Parent(s): 361c536

Update app.py

Files changed (1)
  1. app.py +24 -32
app.py CHANGED
@@ -6,44 +6,34 @@ st.set_page_config(page_title="Question Answering Tool", page_icon=":mag_right:"
 
 @st.cache_resource
 def load_model():
+    """Loads the DistilBERT model and tokenizer for QA."""
     model = DistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased-distilled-squad")
     tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased-distilled-squad")
     return model, tokenizer
 
-def generate_prompt(question, text):
-    """Enhance the input prompt to guide the model better."""
-    return (
-        f"Context: {text}\n\n"
-        f"Instruction: Read the above context carefully and extract the most relevant answer.\n"
-        f"Question: {question}\n"
-        f"Answer:"
-    )
-
 def get_answer(question, text, tokenizer, model):
-    # Special case for bot identity
+    """Extracts the most relevant answer from the given text."""
     if any(phrase in question.lower() for phrase in ["your name", "who are you", "about you"]):
         return "I am Numini, NativUttarMini, created by Sanju Debnath at University of Calcutta."
 
-    # Enhance the input for better response
-    prompt_text = generate_prompt(question, text)
-
-    # Tokenize with truncation for better handling of large text
-    inputs = tokenizer(prompt_text, return_tensors="pt", truncation=True, padding=True, max_length=512)
+    # Tokenize input text and question
+    inputs = tokenizer(question, text, return_tensors="pt", truncation=True, padding=True, max_length=512)
 
     with torch.no_grad():
         outputs = model(**inputs)
 
-    start = torch.argmax(outputs.start_logits)
-    end = torch.argmax(outputs.end_logits) + 1
+    start_idx = torch.argmax(outputs.start_logits)
+    end_idx = torch.argmax(outputs.end_logits) + 1
 
-    # Handle cases where model is uncertain
-    if start >= end:
+    # Validate extracted indices
+    if start_idx >= end_idx or end_idx > inputs.input_ids.shape[1]:
         return "I couldn't find a clear answer in the given text."
 
-    answer = tokenizer.decode(inputs.input_ids[0][start:end], skip_special_tokens=True)
+    # Decode extracted answer
+    answer = tokenizer.decode(inputs.input_ids[0][start_idx:end_idx], skip_special_tokens=True)
 
-    # Post-process answer for better readability
-    if len(answer.split()) < 3:  # Model sometimes returns incomplete answers
+    # Ensure answer is meaningful
+    if len(answer.split()) < 2:
        return "I'm not sure about the exact answer. Can you try rephrasing the question?"
 
     return answer
@@ -58,15 +48,17 @@ def main():
     text = st.text_area("📜 Enter the text/document:", height=200)
     question = st.text_input("❓ Enter your question:")
 
-    if st.form_submit_button("🔍 Get Answer"):
-        if not text.strip():
-            st.warning("⚠️ Please enter some text to analyze.")
-        elif not question.strip():
-            st.warning("⚠️ Please enter a question.")
-        else:
-            st.text("🤖 Thinking...")
-            answer = get_answer(question, text, tokenizer, model)
-            st.success(f"✅ Answer: {answer}")
-
-main()
+    submit = st.form_submit_button("🔍 Get Answer")
+
+    if submit:
+        if not text.strip():
+            st.warning("⚠️ Please enter some text to analyze.")
+        elif not question.strip():
+            st.warning("⚠️ Please enter a question.")
+        else:
+            with st.spinner("🤖 Thinking..."):
+                answer = get_answer(question, text, tokenizer, model)
+                st.success(f"✅ Answer: {answer}")
+
+if __name__ == "__main__":
+    main()
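The core change in get_answer is the switch from encoding a single instruction-style prompt to encoding the (question, text) pair, which is how SQuAD-distilled DistilBERT checkpoints are meant to be queried. A minimal standalone sketch of that flow outside the Streamlit app, using the same checkpoint as the diff; the sample context, question, and variable names below are illustrative only and not part of the commit:

import torch
from transformers import DistilBertForQuestionAnswering, DistilBertTokenizer

model = DistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased-distilled-squad")
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased-distilled-squad")

# Illustrative inputs (not from the commit)
context = "The University of Calcutta was established in 1857 in Kolkata."
question = "When was the University of Calcutta established?"

# Encode the (question, context) pair, as the updated get_answer does
inputs = tokenizer(question, context, return_tensors="pt", truncation=True, max_length=512)

with torch.no_grad():
    outputs = model(**inputs)

# Most likely start/end token positions of the answer span
start_idx = torch.argmax(outputs.start_logits)
end_idx = torch.argmax(outputs.end_logits) + 1

if start_idx < end_idx:
    answer = tokenizer.decode(inputs.input_ids[0][start_idx:end_idx], skip_special_tokens=True)
    print(answer)  # should recover a short span such as "1857"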
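The second hunk only shows the lower half of main(). Since st.form_submit_button is only valid inside an st.form block, the widgets in that hunk presumably sit inside one. A hedged sketch of how the surrounding block could look; the form key, title, and overall layout are assumptions, not taken from the commit:

import streamlit as st

def main():
    st.title("Question Answering Tool")   # assumed; not visible in the diff
    model, tokenizer = load_model()       # load_model() as defined above in app.py

    # form_submit_button must be called inside an st.form block
    with st.form("qa_form"):              # "qa_form" is a hypothetical key
        text = st.text_area("📜 Enter the text/document:", height=200)
        question = st.text_input("❓ Enter your question:")
        submit = st.form_submit_button("🔍 Get Answer")

        if submit:
            # ... input validation, st.spinner, and the get_answer() call as in the hunk above
            pass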