Spaces:
Sleeping
Sleeping
import streamlit as st | |
import numpy as np | |
import torch | |
from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering | |
# Browser-tab title and icon for the app.
st.set_page_config(
    page_title="Question Answering Tool",
    page_icon=":mag_right:",
)
@st.cache_resource
def load_model(model_name="distilbert-base-uncased-distilled-squad"):
    """Load the DistilBERT question-answering model and its tokenizer.

    Streamlit re-executes the entire script on every widget interaction, so
    without caching the weights would be reloaded from disk on each rerun.
    ``st.cache_resource`` keeps one shared loaded instance per ``model_name``.

    Args:
        model_name: Model identifier (or local path) passed to
            ``from_pretrained`` for both the model and the tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = DistilBertForQuestionAnswering.from_pretrained(model_name)
    tokenizer = DistilBertTokenizer.from_pretrained(model_name)
    return model, tokenizer
def get_answer(question, text, tokenizer, model, max_answer_len=None):
    """Extract the answer to ``question`` from ``text`` with an extractive QA model.

    The question/context pair is encoded and run through ``model``. The
    answer is the highest-scoring token span (start logit + end logit) whose
    start is not after its end and which lies entirely inside the context.

    Args:
        question: The user's question.
        text: The context passage to search for the answer.
        tokenizer: Tokenizer used to encode the pair and decode the answer
            (must expose ``sep_token_id``).
        model: Question-answering model returning ``start_logits`` and
            ``end_logits``.
        max_answer_len: Optional maximum answer length in tokens; ``None``
            (the default) places no limit on the span length.

    Returns:
        The decoded answer string, a fixed identity reply when the question
        asks for "your name", or ``""`` when there is no usable answer.
    """
    # Hard-coded identity reply, answered without running the model.
    if "your name" in question.lower():
        return "My name is Numini, full form NativUttarMini, created by Sanju Debnath at University of Calcutta."
    # Nothing to ask or nothing to search: report "no answer" instead of
    # running the model on an empty pair.
    if not question.strip() or not text.strip():
        return ""

    inputs = tokenizer(question, text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)

    # A single question/context pair is encoded, so use batch element 0.
    input_ids = inputs["input_ids"][0]
    start_logits = outputs.start_logits[0]
    end_logits = outputs.end_logits[0]
    seq_len = input_ids.shape[0]

    # The encoded pair is [CLS] question [SEP] context [SEP]; the answer must
    # come from the context, i.e. strictly between the first and last [SEP].
    sep_positions = (input_ids == tokenizer.sep_token_id).nonzero(as_tuple=True)[0]
    context = torch.zeros(seq_len, dtype=torch.bool)
    if sep_positions.numel() >= 2:
        context[int(sep_positions[0]) + 1 : int(sep_positions[-1])] = True

    # Candidate spans (i, j): both ends inside the context, i <= j, and at
    # most max_answer_len tokens long when a cap is given.
    band = torch.ones(seq_len, seq_len).triu()
    if max_answer_len is not None:
        band = band.tril(max_answer_len - 1)
    valid = band.bool() & context[:, None] & context[None, :]
    if not valid.any():
        return ""

    # Score every span jointly rather than taking two independent argmaxes,
    # which can pick an end before the start or land inside the question.
    scores = (start_logits[:, None] + end_logits[None, :]).masked_fill(~valid, float("-inf"))
    start, end = divmod(int(torch.argmax(scores)), seq_len)
    return tokenizer.decode(input_ids[start : end + 1], skip_special_tokens=True)
def main():
    """Render the page: intro text, a text/question form, and the answer."""
    intro = (
        "# Question Answering Tool \n"
        "This tool will help you find answers to your questions about the text you provide. \n"
        "Please enter your question and the text you want to search in the boxes below."
    )
    st.write(intro)
    model, tokenizer = load_model()
    with st.form("qa_form"):
        text = st.text_area("Enter your text here")
        question = st.text_input("Enter your question here")
        submitted = st.form_submit_button("Submit")
        if not submitted:
            return
        # Placeholder element that is overwritten with the result below.
        status = st.text('Let me think about that...')
        reply = get_answer(question, text, tokenizer, model)
        status.text(reply if reply.strip() else "Sorry but I don't know the answer to that question")
# `streamlit run` executes this script as __main__, so the app still starts
# there, while a plain import of the module no longer launches the UI.
if __name__ == "__main__":
    main()